Coverage for mindsdb / integrations / handlers / langchain_handler / tools.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import tiktoken 

2from typing import Callable 

3 

4from mindsdb_sql_parser import parse_sql 

5from mindsdb_sql_parser.ast import Insert 

6from langchain_community.agent_toolkits.load_tools import load_tools 

7 

8from langchain_experimental.utilities import PythonREPL 

9from langchain_community.utilities import GoogleSerperAPIWrapper 

10 

11from langchain.chains.llm import LLMChain 

12from langchain.chains.combine_documents.stuff import StuffDocumentsChain 

13from langchain.chains import ReduceDocumentsChain, MapReduceDocumentsChain 

14 

15from mindsdb.interfaces.skills.skill_tool import skill_tool 

16from mindsdb.utilities import log 

17from langchain_core.prompts import PromptTemplate 

18from langchain_core.tools import Tool 

19from langchain_text_splitters import CharacterTextSplitter 

20 

21logger = log.getLogger(__name__) 

22 

23# Individual tools 

24# Note: all tools are defined in a closure to pass required args (apart from LLM input) through it, as custom tools don't allow custom field assignment. # noqa 

25 

26 

def get_exec_call_tool(llm, executor, model_kwargs) -> Callable:
    """Build a tool function that runs a SQL query through the mindsdb executor.

    Args:
        llm: LLM used to summarize the output if it overflows the token budget.
        executor: mindsdb command executor (provides `execute_command`).
        model_kwargs: dict holding at least 'max_tokens', the output token budget.

    Returns:
        Callable[[str], str]: takes a SQL string, returns the result as
        newline-separated rows with tab-separated columns, or an error
        message the agent can read.
    """
    def mdb_exec_call_tool(query: str) -> str:
        try:
            # Agents often wrap SQL in backticks; strip them before parsing.
            ast_query = parse_sql(query.strip('`'))
            ret = executor.execute_command(ast_query)
            if ret.data is None and ret.error_code is None:
                return ''
            rows = ret.data.to_lists()  # list of lists
            # One line per row, columns joined by tabs.
            # Fix: a row that is already a plain string must be kept whole —
            # the previous `'\t'.join(str(row))` split it into tab-separated
            # characters, since str is iterable char-by-char.
            data = '\n'.join(
                row if isinstance(row, str)
                else '\t'.join(str(value) for value in row)
                for row in rows
            )
        except Exception as e:
            data = f"mindsdb tool failed with error:\n{str(e)}"  # let the agent know

        # summarize output if needed
        data = summarize_if_overflowed(data, llm, model_kwargs['max_tokens'])

        return data
    return mdb_exec_call_tool

48 

49 

def get_exec_metadata_tool(llm, executor, model_kwargs) -> Callable:
    """Build a tool that reports metadata for a mindsdb integration or table.

    Args:
        llm: LLM used to summarize the output if it overflows the token budget.
        executor: mindsdb command executor (provides `session.integration_controller`).
        model_kwargs: dict holding at least 'max_tokens', the output token budget.

    Returns:
        Callable[[str], str]: takes either `integration` or `integration.table`
        (backticks allowed) and returns table listings or per-table metadata.
    """
    def mdb_exec_metadata_call(query: str) -> str:
        try:
            parts = query.replace('`', '').split('.')
            assert 1 <= len(parts) <= 2, 'query must be in the format: `integration` or `integration.table`'

            integration = parts[0]
            integrations = executor.session.integration_controller
            handler = integrations.get_data_handler(integration)

            df = handler.get_tables().data_frame
            # Column casing differs between handlers; detect it once and use it
            # consistently in both branches (the list branch previously assumed
            # upper case only).
            table_name_col = 'TABLE_NAME' if 'TABLE_NAME' in df.columns else 'table_name'

            if len(parts) == 1:
                data = f'The integration `{integration}` has {df.shape[0]} tables: {", ".join(list(df[table_name_col].values))}'  # noqa

            if len(parts) == 2:
                table_name = parts[-1]
                try:
                    mdata = df[df[table_name_col] == table_name].iloc[0].to_list()
                    if len(mdata) == 3:
                        _, nrows, table_type = mdata
                        data = f'Metadata for table {table_name}:\n\tRow count: {nrows}\n\tType: {table_type}\n'
                    elif len(mdata) == 2:
                        # Fix: unpack the row count from the (name, count) row —
                        # previously the whole list was bound to `nrows`.
                        _, nrows = mdata
                        data = f'Metadata for table {table_name}:\n\tRow count: {nrows}\n'
                    else:
                        data = f'Metadata for table {table_name}:\n'
                    # Fetch the columns frame once instead of twice.
                    columns_df = handler.get_columns(table_name).data_frame
                    fields = columns_df['Field'].to_list()
                    types = columns_df['Type'].to_list()
                    data += 'List of columns and types:\n'
                    data += '\n'.join([f'\tColumn: `{field}`\tType: `{typ}`' for field, typ in zip(fields, types)])
                # Fix: Exception instead of BaseException, so KeyboardInterrupt
                # and SystemExit are not swallowed as "table not found".
                except Exception:
                    data = f'Table {table_name} not found.'
        except Exception as e:
            data = f"mindsdb tool failed with error:\n{str(e)}"  # let the agent know

        # summarize output if needed
        data = summarize_if_overflowed(data, llm, model_kwargs['max_tokens'])

        return data
    return mdb_exec_metadata_call

92 

93 

def get_mdb_write_tool(executor) -> Callable:
    """Build a tool that executes INSERT statements through the mindsdb executor.

    Args:
        executor: mindsdb command executor (provides `execute_command`).

    Returns:
        Callable[[str], str]: takes a SQL string and returns a success or
        failure message; only INSERT statements are executed.
    """
    def mdb_write_call(query: str) -> str:
        try:
            # Strip backticks once (the original stripped twice redundantly).
            ast_query = parse_sql(query.strip('`'))
            if isinstance(ast_query, Insert):
                executor.execute_command(ast_query)
                return "mindsdb write tool executed successfully"
            # Fix: previously a non-INSERT statement fell through and returned
            # None despite the `-> str` annotation; tell the agent explicitly.
            return "mindsdb write tool failed: only INSERT statements are supported"
        except Exception as e:
            return f"mindsdb write tool failed with error:\n{str(e)}"
    return mdb_write_call

105 

106 

def _setup_standard_tools(tools, llm, model_kwargs):
    """Instantiate the built-in (string-named) tools for an agent.

    Args:
        tools: iterable of tool names; supported values are 'mindsdb_read',
            'mindsdb_write', 'python_repl' and 'serper'.
        llm: LLM passed to tools that may need to summarize their output.
        model_kwargs: model configuration dict (e.g. 'max_tokens').

    Returns:
        list: instantiated langchain `Tool` objects.

    Raises:
        ValueError: on an unsupported tool name.
    """
    executor = skill_tool.get_command_executor()

    all_standard_tools = []
    langchain_tools = []
    for tool in tools:
        # Fix: this must be ONE if/elif chain. Previously 'mindsdb_read' was a
        # standalone `if` followed by a separate if/elif/else chain, so a
        # 'mindsdb_read' tool fell through to the final `else` and raised
        # ValueError("Unsupported tool: mindsdb_read").
        if tool == 'mindsdb_read':
            mdb_tool = Tool(
                name="MindsDB",
                func=get_exec_call_tool(llm, executor, model_kwargs),
                description="useful to read from databases or tables connected to the mindsdb machine learning package. the action must be a valid simple SQL query, always ending with a semicolon. For example, you can do `show databases;` to list the available data sources, and `show tables;` to list the available tables within each data source."  # noqa
            )

            mdb_meta_tool = Tool(
                name="MDB-Metadata",
                func=get_exec_metadata_tool(llm, executor, model_kwargs),
                description="useful to get column names from a mindsdb table or metadata from a mindsdb data source. the command should be either 1) a data source name, to list all available tables that it exposes, or 2) a string with the format `data_source_name.table_name` (for example, `files.my_table`), to get the table name, table type, column names, data types per column, and amount of rows of the specified table."  # noqa
            )
            all_standard_tools.append(mdb_tool)
            all_standard_tools.append(mdb_meta_tool)
        elif tool == 'mindsdb_write':
            mdb_write_tool = Tool(
                name="MDB-Write",
                func=get_mdb_write_tool(executor),
                description="useful to write into data sources connected to mindsdb. command must be a valid SQL query with syntax: `INSERT INTO data_source_name.table_name (column_name_1, column_name_2, [...]) VALUES (column_1_value_row_1, column_2_value_row_1, [...]), (column_1_value_row_2, column_2_value_row_2, [...]), [...];`. note the command always ends with a semicolon. order of column names and values for each row must be a perfect match. If write fails, try casting value with a function, passing the value without quotes, or truncating string as needed.`."  # noqa
            )
            all_standard_tools.append(mdb_write_tool)
        elif tool == 'python_repl':
            tool = Tool(
                name="python_repl",
                func=PythonREPL().run,
                description="useful for running custom Python code. Note: this is a powerful tool, so use with caution."  # noqa
            )
            langchain_tools.append(tool)
        elif tool == 'serper':
            search = GoogleSerperAPIWrapper()
            tool = Tool(
                name="Intermediate Answer",
                func=search.run,
                description="useful for when you need to ask with search",
            )
            langchain_tools.append(tool)
        else:
            raise ValueError(f"Unsupported tool: {tool}")

    if langchain_tools:
        all_standard_tools += load_tools(langchain_tools)
    return all_standard_tools

155 

156 

157# Collector 

def setup_tools(llm, model_kwargs, pred_args, default_agent_tools):
    """Collect every tool the agent will receive.

    Args:
        llm: LLM passed through to standard tools.
        model_kwargs: model configuration; 'serper_api_key', if present, is
            popped here and used to add a serper.dev search tool.
        pred_args: prediction args; `pred_args['tools']` overrides the default
            toolkit when set. Entries are either strings (standard tools) or
            dicts with 'name'/'func'/'description' (user-defined tools).
        default_agent_tools: toolkit used when `pred_args['tools']` is None.

    Returns:
        list: instantiated langchain `Tool` objects.
    """
    toolkit = pred_args['tools'] if pred_args.get('tools') is not None else default_agent_tools

    standard_tools = []
    function_tools = []

    for tool in toolkit:
        if isinstance(tool, str):
            standard_tools.append(tool)
        else:
            # user defined custom functions
            function_tools.append(tool)

    # Fix: dropped the dead `tools = []; if len(tools) == 0:` guard — the list
    # was always empty there, so the branch was unconditional.
    tools = _setup_standard_tools(standard_tools, llm, model_kwargs)

    if model_kwargs.get('serper_api_key', False):
        search = GoogleSerperAPIWrapper(serper_api_key=model_kwargs.pop('serper_api_key'))
        tools.append(Tool(
            name="Intermediate Answer (serper.dev)",
            func=search.run,
            description="useful for when you need to search the internet (note: in general, use this as a last resort)"  # noqa
        ))

    for tool in function_tools:
        tools.append(Tool(
            name=tool['name'],
            func=tool['func'],
            description=tool['description'],
        ))

    return tools

193 

194 

195# Helpers 

def summarize_if_overflowed(data, llm, max_tokens, budget_multiplier=0.8) -> str:
    """
    This helper retries with a summarized version of the
    output if the previous call fails due to the token limit being exceeded.

    We trigger summarization when the token count exceeds the limit times a multiplier to be conservative.

    Args:
        data (str): text to (maybe) summarize.
        llm: LLM used for the map and reduce summarization steps.
        max_tokens (int): the model's output token limit.
        budget_multiplier (float): safety factor applied to `max_tokens`.

    Returns:
        str: `data` unchanged if it fits the budget, otherwise a map-reduce
        summary of it produced by `llm`.
    """
    # tokenize data for length check
    # note: this is a rough estimate, as the tokenizer used in each LLM may be different
    encoding = tiktoken.get_encoding("gpt2")
    n_tokens = len(encoding.encode(data))

    # map-reduce given token budget
    if n_tokens > max_tokens * budget_multiplier:
        # map: summarize each chunk of documents independently
        map_template = """The following is a set of documents
        {docs}
        Based on this list of docs, please identify the main themes
        Helpful Answer:"""
        map_prompt = PromptTemplate.from_template(map_template)
        map_chain = LLMChain(llm=llm, prompt=map_prompt)

        # reduce: consolidate the per-chunk summaries into one answer
        reduce_template = """The following is set of summaries:
        {doc_summaries}
        Take these and distill it into a final, consolidated summary of the main themes.
        Helpful Answer:"""
        reduce_prompt = PromptTemplate.from_template(reduce_template)
        reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
        # stuffs all summaries into a single reduce prompt
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        # collapse_documents_chain reuses the same chain to re-reduce groups
        # that would themselves overflow the budget
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_documents_chain,
            collapse_documents_chain=combine_documents_chain,
            token_max=max_tokens * budget_multiplier,  # applies for each group of documents
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )
        # split input into ~1000-token chunks (tiktoken-based splitter) before mapping
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)
        docs = text_splitter.create_documents([data])
        split_docs = text_splitter.split_documents(docs)

        # run chain
        data = map_reduce_chain.run(split_docs)
    return data