Coverage for mindsdb / interfaces / agents / callback_handlers.py: 34%
100 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import io
2import logging
3import contextlib
4from typing import Any, Dict, List, Union, Callable
6from langchain_core.agents import AgentAction, AgentFinish
7from langchain_core.callbacks.base import BaseCallbackHandler
8from langchain_core.messages.base import BaseMessage
9from langchain_core.outputs import LLMResult
10from langchain_core.callbacks import StdOutCallbackHandler
class ContextCaptureCallback(BaseCallbackHandler):
    '''Callback handler that keeps the documents from the latest retriever run.

    After a retriever finishes, ``context`` holds one dict per returned
    document with its ``page_content`` and ``metadata``. Before any retriever
    has finished, ``context`` is ``None``.
    '''

    def __init__(self):
        # No retriever has reported results yet.
        self.context = None

    def on_retriever_end(self, documents: List[Any], *, run_id: str, parent_run_id: Union[str, None] = None, **kwargs: Any) -> Any:
        '''Replace the stored context with the documents just retrieved.

        Args:
            documents: Retrieved documents; each must expose ``page_content``
                and ``metadata`` attributes.
            run_id: Id of the retriever run (unused).
            parent_run_id: Id of the parent run, if any (unused).
        '''
        captured = []
        for document in documents:
            captured.append({
                'page_content': document.page_content,
                'metadata': document.metadata,
            })
        # Assign only after every document was converted, like a comprehension would.
        self.context = captured

    def get_contexts(self):
        '''Return the captured document dicts, or ``None`` if nothing was retrieved.'''
        return self.context
class VerboseLogCallbackHandler(StdOutCallbackHandler):
    '''Route what ``StdOutCallbackHandler`` would print into a logger instead.

    Each overridden hook calls the matching ``StdOutCallbackHandler`` method
    with stdout redirected into an in-memory buffer, then logs the captured
    text at INFO level. When ``verbose`` is falsy, the parent method is not
    called and nothing is logged.

    Args:
        logger: Logger that receives the captured output.
        verbose: Whether the captured output should be produced at all.
    '''

    def __init__(self, logger: logging.Logger, verbose: bool):
        self.logger = logger
        self.verbose = verbose
        super().__init__()

    def __call(self, method: Callable, *args: Any, **kwargs: Any) -> None:
        '''Call ``method(*args, **kwargs)``, capture what it prints and log it.

        NOTE: ``contextlib.redirect_stdout`` swaps the process-wide
        ``sys.stdout`` for the duration of the call, so text printed by other
        threads at the same moment may end up in this buffer instead of the
        real stdout.
        '''
        # Truthiness instead of `is False`, so values like None/0 also disable output.
        if not self.verbose:
            return
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            method(*args, **kwargs)
        output = buffer.getvalue()
        # Don't emit empty log records when the parent printed nothing.
        if output:
            self.logger.info(output)

    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_chain_start, *args, **kwargs)

    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_chain_end, *args, **kwargs)

    def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_agent_action, *args, **kwargs)

    def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_tool_end, *args, **kwargs)

    def on_text(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_text, *args, **kwargs)

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        self.__call(super().on_agent_finish, *args, **kwargs)
class LogCallbackHandler(BaseCallbackHandler):
    '''Langchain callback handler that logs agent and chain executions.

    Besides logging, it records the tool input of the most recent tool call
    whose name starts with ``sql_db_query`` in ``generated_sql``. It also
    forwards several events to a ``VerboseLogCallbackHandler``, which logs
    the stdout-style trace when ``verbose`` is enabled.

    Args:
        logger: Logger used for all messages. Its level is set to DEBUG.
        verbose: Passed to the ``VerboseLogCallbackHandler``.
    '''

    def __init__(self, logger: logging.Logger, verbose: bool = True):
        # NOTE: this mutates the level of the logger object passed in, which
        # affects every other user of that same logger.
        logger.setLevel('DEBUG')
        self.logger = logger
        # Chains that have started but not yet ended or errored.
        self._num_running_chains = 0
        # Tool input of the last `sql_db_query*` action; None until one is seen.
        self.generated_sql = None
        self.verbose_log_handler = VerboseLogCallbackHandler(logger, verbose)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        '''Run when LLM starts running. Logs the first 50 characters of each prompt.'''
        self.logger.debug('LLM started with prompts:')
        for prompt in prompts:
            self.logger.debug(prompt[:50])
        self.verbose_log_handler.on_llm_start(serialized, prompts, **kwargs)

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]], **kwargs: Any
    ) -> Any:
        '''Run when Chat Model starts running. Logs every message of every message list.'''
        self.logger.debug('Chat model started with messages:')
        for message_list in messages:
            for message in message_list:
                self.logger.debug(message.pretty_repr())

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        '''Run on new LLM token. Only available when streaming is enabled.'''
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        '''Run when LLM ends running. Logs ``response.llm_output``.'''
        self.logger.debug('LLM ended with response:')
        self.logger.debug(str(response.llm_output))

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        '''Run when LLM errors.'''
        self.logger.debug('LLM encountered an error: %s', error)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        '''Run when chain starts running. Increments the running-chain counter.'''
        self._num_running_chains += 1
        self.logger.info('Entering new LLM chain (%s total)', self._num_running_chains)
        self.logger.debug('Inputs: %s', inputs)

        self.verbose_log_handler.on_chain_start(serialized=serialized, inputs=inputs, **kwargs)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        '''Run when chain ends running. Decrements the running-chain counter.'''
        self._num_running_chains -= 1
        self.logger.info('Ended LLM chain (%s total)', self._num_running_chains)
        self.logger.debug('Outputs: %s', outputs)

        self.verbose_log_handler.on_chain_end(outputs=outputs, **kwargs)

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        '''Run when chain errors. Decrements the running-chain counter.'''
        self._num_running_chains -= 1
        self.logger.error(
            'LLM chain encountered an error (%s running): %s',
            self._num_running_chains, error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        '''Run when tool starts running.'''
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        '''Run when tool ends running.'''
        self.verbose_log_handler.on_tool_end(output=output, **kwargs)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        '''Run when tool errors. Logs the error (same level as LLM errors).'''
        self.logger.debug('Tool encountered an error: %s', error)

    def on_text(self, text: str, **kwargs: Any) -> Any:
        '''Run on arbitrary text.'''
        self.verbose_log_handler.on_text(text=text, **kwargs)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        '''Run on agent action; sanitizes ``action`` in place.

        - A string ``tool_input`` is cut at the first ``'Observation: '``
          marker (presumably text the model generated past its stop point).
        - Backslashes are removed from ``action.tool``.
        - If the sanitized tool name starts with ``sql_db_query``, the tool
          input is stored in ``self.generated_sql``.
        '''
        self.logger.debug('Running tool %s with input:', action.tool)
        self.logger.debug(action.tool_input)

        stop_block = 'Observation: '
        # tool_input is not guaranteed to be a string; only slice real strings.
        if isinstance(action.tool_input, str) and stop_block in action.tool_input:
            action.tool_input = action.tool_input.split(stop_block, 1)[0]

        # fix for mistral: strip backslashes from the tool name. This must run
        # BEFORE the sql_db_query check, otherwise an escaped name such as
        # "sql\_db\_query" would not match and the SQL would not be captured.
        action.tool = action.tool.replace('\\', '')

        if action.tool.startswith('sql_db_query'):
            # Save the generated SQL query
            self.generated_sql = action.tool_input

        self.verbose_log_handler.on_agent_action(action=action, **kwargs)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        '''Run on agent end. Logs ``finish.return_values``.'''
        self.logger.debug('Agent finished with return values:')
        self.logger.debug(str(finish.return_values))
        self.verbose_log_handler.on_agent_finish(finish=finish, **kwargs)