Coverage for mindsdb / integrations / handlers / langchain_handler / tools.py: 0%
135 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import tiktoken
2from typing import Callable
4from mindsdb_sql_parser import parse_sql
5from mindsdb_sql_parser.ast import Insert
6from langchain_community.agent_toolkits.load_tools import load_tools
8from langchain_experimental.utilities import PythonREPL
9from langchain_community.utilities import GoogleSerperAPIWrapper
11from langchain.chains.llm import LLMChain
12from langchain.chains.combine_documents.stuff import StuffDocumentsChain
13from langchain.chains import ReduceDocumentsChain, MapReduceDocumentsChain
15from mindsdb.interfaces.skills.skill_tool import skill_tool
16from mindsdb.utilities import log
17from langchain_core.prompts import PromptTemplate
18from langchain_core.tools import Tool
19from langchain_text_splitters import CharacterTextSplitter
21logger = log.getLogger(__name__)
23# Individual tools
24# Note: all tools are defined in a closure to pass required args (apart from LLM input) through it, as custom tools don't allow custom field assignment. # noqa
def get_exec_call_tool(llm, executor, model_kwargs) -> Callable:
    """Build a tool callable that executes a SQL query via the MindsDB executor.

    Args:
        llm: LLM used to summarize results that exceed the token budget.
        executor: MindsDB command executor (provides `execute_command`).
        model_kwargs: model config dict; `max_tokens` bounds the output size.

    Returns:
        Callable[[str], str]: runs the query and returns rows as
        tab-separated columns joined by newlines, or an error message string.
    """
    def mdb_exec_call_tool(query: str) -> str:
        try:
            ast_query = parse_sql(query.strip('`'))
            ret = executor.execute_command(ast_query)
            if ret.data is None and ret.error_code is None:
                return ''
            rows = ret.data.to_lists()  # list of lists
            # Fix: a plain-string row must be emitted as-is. The previous
            # `'\t'.join(str(row))` iterated the string and joined its
            # characters with tabs.
            data = '\n'.join(  # rows
                row if isinstance(row, str)
                else '\t'.join(str(value) for value in row)  # columns
                for row in rows
            )
        except Exception as e:
            data = f"mindsdb tool failed with error:\n{str(e)}"  # let the agent know

        # summarize output if needed
        data = summarize_if_overflowed(data, llm, model_kwargs['max_tokens'])

        return data
    return mdb_exec_call_tool
def get_exec_metadata_tool(llm, executor, model_kwargs) -> Callable:
    """Build a tool callable that reports metadata for an integration or table.

    Args:
        llm: LLM used to summarize results that exceed the token budget.
        executor: MindsDB command executor (provides the integration session).
        model_kwargs: model config dict; `max_tokens` bounds the output size.

    Returns:
        Callable[[str], str]: accepts `integration` or `integration.table`
        and returns a human-readable metadata description or an error message.
    """
    def mdb_exec_metadata_call(query: str) -> str:
        try:
            parts = query.replace('`', '').split('.')
            assert 1 <= len(parts) <= 2, 'query must be in the format: `integration` or `integration.table`'

            integration = parts[0]
            integrations = executor.session.integration_controller
            handler = integrations.get_data_handler(integration)

            df = handler.get_tables().data_frame
            # Column naming differs across handlers; use whichever is present
            # (previously only the two-part branch had this fallback).
            table_name_col = 'TABLE_NAME' if 'TABLE_NAME' in df.columns else 'table_name'

            if len(parts) == 1:
                data = f'The integration `{integration}` has {df.shape[0]} tables: {", ".join(list(df[table_name_col].values))}'  # noqa
            else:
                table_name = parts[-1]
                try:
                    mdata = df[df[table_name_col] == table_name].iloc[0].to_list()
                    if len(mdata) == 3:
                        _, nrows, table_type = mdata
                        data = f'Metadata for table {table_name}:\n\tRow count: {nrows}\n\tType: {table_type}\n'
                    elif len(mdata) == 2:
                        # Fix: unpack the row count. Previously `nrows = mdata`
                        # bound the whole list, so the list repr was printed.
                        # assumes the first element mirrors the 3-column case
                        # (table name) — TODO confirm against handler output.
                        _, nrows = mdata
                        data = f'Metadata for table {table_name}:\n\tRow count: {nrows}\n'
                    else:
                        data = f'Metadata for table {table_name}:\n'
                    # fetch the columns frame once instead of twice
                    columns_df = handler.get_columns(table_name).data_frame
                    fields = columns_df['Field'].to_list()
                    types = columns_df['Type'].to_list()
                    data += 'List of columns and types:\n'
                    data += '\n'.join([f'\tColumn: `{field}`\tType: `{typ}`' for field, typ in zip(fields, types)])
                except Exception:
                    # narrowed from BaseException so KeyboardInterrupt/SystemExit propagate
                    data = f'Table {table_name} not found.'
        except Exception as e:
            data = f"mindsdb tool failed with error:\n{str(e)}"  # let the agent know

        # summarize output if needed
        data = summarize_if_overflowed(data, llm, model_kwargs['max_tokens'])

        return data
    return mdb_exec_metadata_call
def get_mdb_write_tool(executor) -> Callable:
    """Build a tool callable that runs INSERT statements via the MindsDB executor.

    Args:
        executor: MindsDB command executor (provides `execute_command`).

    Returns:
        Callable[[str], str]: executes the INSERT and returns a status message.
    """
    def mdb_write_call(query: str) -> str:
        try:
            # single strip is enough (was applied twice before)
            ast_query = parse_sql(query.strip('`'))
            if isinstance(ast_query, Insert):
                executor.execute_command(ast_query)
                return "mindsdb write tool executed successfully"
            # Fix: a non-INSERT statement previously fell through and
            # returned None silently; tell the agent what went wrong.
            return "mindsdb write tool failed: only INSERT statements are supported"
        except Exception as e:
            return f"mindsdb write tool failed with error:\n{str(e)}"
    return mdb_write_call
def _setup_standard_tools(tools, llm, model_kwargs):
    """Instantiate the built-in (string-named) tools.

    Args:
        tools: iterable of tool names: 'mindsdb_read', 'mindsdb_write',
            'python_repl', 'serper'.
        llm: LLM passed through to the MindsDB tool closures.
        model_kwargs: model config dict passed through to the closures.

    Returns:
        list[Tool]: ready-to-use langchain tools.

    Raises:
        ValueError: if an unknown tool name is encountered.
    """
    executor = skill_tool.get_command_executor()

    all_standard_tools = []
    for tool in tools:
        # Fix: this must be one if/elif chain. Previously the first branch was
        # a stand-alone `if`, so 'mindsdb_read' fell through to the trailing
        # `else` and raised "Unsupported tool: mindsdb_read".
        if tool == 'mindsdb_read':
            mdb_tool = Tool(
                name="MindsDB",
                func=get_exec_call_tool(llm, executor, model_kwargs),
                description="useful to read from databases or tables connected to the mindsdb machine learning package. the action must be a valid simple SQL query, always ending with a semicolon. For example, you can do `show databases;` to list the available data sources, and `show tables;` to list the available tables within each data source."  # noqa
            )

            mdb_meta_tool = Tool(
                name="MDB-Metadata",
                func=get_exec_metadata_tool(llm, executor, model_kwargs),
                description="useful to get column names from a mindsdb table or metadata from a mindsdb data source. the command should be either 1) a data source name, to list all available tables that it exposes, or 2) a string with the format `data_source_name.table_name` (for example, `files.my_table`), to get the table name, table type, column names, data types per column, and amount of rows of the specified table."  # noqa
            )
            all_standard_tools.append(mdb_tool)
            all_standard_tools.append(mdb_meta_tool)
        elif tool == 'mindsdb_write':
            mdb_write_tool = Tool(
                name="MDB-Write",
                func=get_mdb_write_tool(executor),
                description="useful to write into data sources connected to mindsdb. command must be a valid SQL query with syntax: `INSERT INTO data_source_name.table_name (column_name_1, column_name_2, [...]) VALUES (column_1_value_row_1, column_2_value_row_1, [...]), (column_1_value_row_2, column_2_value_row_2, [...]), [...];`. note the command always ends with a semicolon. order of column names and values for each row must be a perfect match. If write fails, try casting value with a function, passing the value without quotes, or truncating string as needed.`."  # noqa
            )
            all_standard_tools.append(mdb_write_tool)
        elif tool == 'python_repl':
            # Fix: these are fully constructed Tool objects, so append them
            # directly. `load_tools()` expects tool *names* (strings) and
            # raised when given Tool instances.
            all_standard_tools.append(Tool(
                name="python_repl",
                func=PythonREPL().run,
                description="useful for running custom Python code. Note: this is a powerful tool, so use with caution."  # noqa
            ))
        elif tool == 'serper':
            search = GoogleSerperAPIWrapper()
            all_standard_tools.append(Tool(
                name="Intermediate Answer",
                func=search.run,
                description="useful for when you need to ask with search",
            ))
        else:
            raise ValueError(f"Unsupported tool: {tool}")
    return all_standard_tools
157# Collector
def setup_tools(llm, model_kwargs, pred_args, default_agent_tools):
    """Assemble the agent's tool list from standard names and custom functions.

    Args:
        llm: LLM passed through to the standard tool factories.
        model_kwargs: model config; may contain a 'serper_api_key' entry
            (popped here when present) to enable the serper.dev search tool.
        pred_args: prediction args; `pred_args['tools']` overrides the default
            toolkit when not None.
        default_agent_tools: fallback toolkit when no override is given.

    Returns:
        list[Tool]: all instantiated tools.
    """
    toolkit = pred_args['tools'] if pred_args.get('tools') is not None else default_agent_tools

    standard_tools = []
    function_tools = []

    for tool in toolkit:
        if isinstance(tool, str):
            standard_tools.append(tool)
        else:
            # user defined custom functions
            function_tools.append(tool)

    # Fix: removed the dead `tools = []; if len(tools) == 0:` guard — the
    # condition was always true, so the call happens unconditionally.
    tools = _setup_standard_tools(standard_tools, llm, model_kwargs)

    if model_kwargs.get('serper_api_key', False):
        search = GoogleSerperAPIWrapper(serper_api_key=model_kwargs.pop('serper_api_key'))
        tools.append(Tool(
            name="Intermediate Answer (serper.dev)",
            func=search.run,
            description="useful for when you need to search the internet (note: in general, use this as a last resort)"  # noqa
        ))

    for tool in function_tools:
        tools.append(Tool(
            name=tool['name'],
            func=tool['func'],
            description=tool['description'],
        ))

    return tools
195# Helpers
def summarize_if_overflowed(data, llm, max_tokens, budget_multiplier=0.8) -> str:
    """Map-reduce summarize `data` when it exceeds the model's token budget.

    Summarization triggers when the estimated token count exceeds
    `max_tokens * budget_multiplier` (the multiplier keeps the check
    conservative). Otherwise `data` is returned unchanged.

    Args:
        data: text to (maybe) summarize.
        llm: LLM used for the map and reduce passes.
        max_tokens: token limit of the target model.
        budget_multiplier: fraction of the limit that triggers summarization.

    Returns:
        str: the original text, or its consolidated summary.
    """
    # tokenize data for length check
    # note: this is a rough estimate, as the tokenizer used in each LLM may be different
    encoding = tiktoken.get_encoding("gpt2")
    n_tokens = len(encoding.encode(data))

    # map-reduce given token budget
    if n_tokens > max_tokens * budget_multiplier:
        # map
        map_template = """The following is a set of documents
        {docs}
        Based on this list of docs, please identify the main themes
        Helpful Answer:"""
        map_prompt = PromptTemplate.from_template(map_template)
        map_chain = LLMChain(llm=llm, prompt=map_prompt)

        # reduce
        reduce_template = """The following is set of summaries:
        {doc_summaries}
        Take these and distill it into a final, consolidated summary of the main themes.
        Helpful Answer:"""
        reduce_prompt = PromptTemplate.from_template(reduce_template)
        reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_documents_chain,
            collapse_documents_chain=combine_documents_chain,
            # applies for each group of documents; cast to int because the
            # multiplier produces a float and token counts must be integral
            token_max=int(max_tokens * budget_multiplier),
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        # split
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)
        # Fix: create_documents() already chunks the text, so the extra
        # split_documents() pass over its output was redundant.
        split_docs = text_splitter.create_documents([data])

        # run chain
        data = map_reduce_chain.run(split_docs)
    return data
246 return data