Coverage for mindsdb / interfaces / agents / callback_handlers.py: 34%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import io 

2import logging 

3import contextlib 

4from typing import Any, Dict, List, Union, Callable 

5 

6from langchain_core.agents import AgentAction, AgentFinish 

7from langchain_core.callbacks.base import BaseCallbackHandler 

8from langchain_core.messages.base import BaseMessage 

9from langchain_core.outputs import LLMResult 

10from langchain_core.callbacks import StdOutCallbackHandler 

11 

12 

class ContextCaptureCallback(BaseCallbackHandler):
    '''Callback that captures the documents a retriever returns so they can be read back later.'''

    def __init__(self):
        # Captured document dicts; stays None until a retriever has finished at least once.
        self.context = None

    def on_retriever_end(self, documents: List[Any], *, run_id: str, parent_run_id: Union[str, None] = None, **kwargs: Any) -> Any:
        '''Record page content and metadata for every retrieved document.'''
        captured = []
        for doc in documents:
            captured.append({
                'page_content': doc.page_content,
                'metadata': doc.metadata,
            })
        self.context = captured

    def get_contexts(self):
        '''Return the captured contexts (None if no retrieval has happened yet).'''
        return self.context

25 

26 

class VerboseLogCallbackHandler(StdOutCallbackHandler):
    '''Routes StdOutCallbackHandler's printed output into a logger instead of stdout.

    When verbose is off, every callback becomes a no-op.
    '''

    def __init__(self, logger: logging.Logger, verbose: bool):
        self.logger = logger
        self.verbose = verbose
        super().__init__()

    def _log_captured(self, method: Callable, *args: List[Any], **kwargs: Any) -> Any:
        '''Invoke `method`, capture anything it prints to stdout, and log it at INFO.'''
        # NOTE: deliberately `is False` (not truthiness) to mirror the original gate.
        if self.verbose is False:
            return
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            method(*args, **kwargs)
        self.logger.info(buffer.getvalue())

    def on_chain_start(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_chain_start, *args, **kwargs)

    def on_chain_end(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_chain_end, *args, **kwargs)

    def on_agent_action(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_agent_action, *args, **kwargs)

    def on_tool_end(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_tool_end, *args, **kwargs)

    def on_text(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_text, *args, **kwargs)

    def on_agent_finish(self, *args: List[Any], **kwargs: Any) -> None:
        self._log_captured(super().on_agent_finish, *args, **kwargs)

59 

60 

class LogCallbackHandler(BaseCallbackHandler):
    '''Langchain callback handler that logs agent and chain executions.'''

    def __init__(self, logger: logging.Logger, verbose: bool = True):
        # This handler logs at DEBUG, so the supplied logger is forced down to that level.
        logger.setLevel('DEBUG')
        self.logger = logger
        self._num_running_chains = 0
        # Most recent SQL seen going into a sql_db_query* tool, if any.
        self.generated_sql = None
        self.verbose_log_handler = VerboseLogCallbackHandler(logger, verbose)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
        '''Run when LLM starts running.'''
        self.logger.debug('LLM started with prompts:')
        for prompt in prompts:
            # Only the first 50 characters of each prompt are logged.
            self.logger.debug(prompt[:50])
        self.verbose_log_handler.on_llm_start(serialized, prompts, **kwargs)

    def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
        '''Run when Chat Model starts running.'''
        self.logger.debug('Chat model started with messages:')
        for batch in messages:
            for message in batch:
                self.logger.debug(message.pretty_repr())

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        '''Run on new LLM token. Only available when streaming is enabled.'''
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        '''Run when LLM ends running.'''
        self.logger.debug('LLM ended with response:')
        self.logger.debug(str(response.llm_output))

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        '''Run when LLM errors.'''
        self.logger.debug(f'LLM encountered an error: {str(error)}')

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
        '''Run when chain starts running.'''
        self._num_running_chains += 1
        self.logger.info(f'Entering new LLM chain ({self._num_running_chains} total)')
        self.logger.debug(f'Inputs: {inputs}')

        self.verbose_log_handler.on_chain_start(serialized=serialized, inputs=inputs, **kwargs)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        '''Run when chain ends running.'''
        self._num_running_chains -= 1
        self.logger.info(f'Ended LLM chain ({self._num_running_chains} total)')
        self.logger.debug(f'Outputs: {outputs}')

        self.verbose_log_handler.on_chain_end(outputs=outputs, **kwargs)

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        '''Run when chain errors.'''
        self._num_running_chains -= 1
        self.logger.error(
            f'LLM chain encountered an error ({self._num_running_chains} running): {error}')

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
        '''Run when tool starts running.'''
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        '''Run when tool ends running.'''
        self.verbose_log_handler.on_tool_end(output=output, **kwargs)

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        '''Run when tool errors.'''
        pass

    def on_text(self, text: str, **kwargs: Any) -> Any:
        '''Run on arbitrary text.'''
        self.verbose_log_handler.on_text(text=text, **kwargs)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        '''Run on agent action.'''
        self.logger.debug(f'Running tool {action.tool} with input:')
        self.logger.debug(action.tool_input)

        # Trim anything the model hallucinated after an 'Observation: ' marker.
        stop_block = 'Observation: '
        if stop_block in action.tool_input:
            cut_at = action.tool_input.find(stop_block)
            action.tool_input = action.tool_input[:cut_at]

        if action.tool.startswith("sql_db_query"):
            # Save the generated SQL query
            self.generated_sql = action.tool_input

        # fix for mistral
        action.tool = action.tool.replace('\\', '')

        self.verbose_log_handler.on_agent_action(action=action, **kwargs)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        '''Run on agent end.'''
        self.logger.debug('Agent finished with return values:')
        self.logger.debug(str(finish.return_values))
        self.verbose_log_handler.on_agent_finish(finish=finish, **kwargs)