Coverage for mindsdb / interfaces / agents / mcp_client_agent.py: 0%
129 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
2import asyncio
3from typing import Dict, List, Any, Iterator, ClassVar
4from contextlib import AsyncExitStack
6import pandas as pd
7from mcp import ClientSession, StdioServerParameters
8from mcp.client.stdio import stdio_client
10from mindsdb.utilities import log
11from mindsdb.interfaces.agents.langchain_agent import LangchainAgent
12from mindsdb.interfaces.storage import db
13from langchain_core.tools import BaseTool
15logger = log.getLogger(__name__)
class MCPQueryTool(BaseTool):
    """LangChain tool that forwards SQL queries to a MindsDB MCP server.

    Wraps an already-initialized ``mcp.ClientSession`` and exposes the
    server's ``query`` tool to a LangChain agent. Tabular payloads
    (``data`` + ``column_names``) are rendered via a pandas DataFrame.
    """

    # ClassVar keeps these out of BaseTool's pydantic model fields
    name: ClassVar[str] = "mcp_query"
    description: ClassVar[str] = "Execute SQL queries against the MindsDB server via MCP protocol"

    def __init__(self, session: ClientSession):
        """Store the active MCP client session used to execute queries."""
        super().__init__()
        self.session = session

    async def _arun(self, query: str) -> str:
        """Execute ``query`` via MCP asynchronously and return a textual result.

        Never raises: failures are returned as an ``"Error ..."`` string so the
        agent can surface them to the LLM instead of aborting the run.
        """
        try:
            logger.info(f"Executing MCP query: {query}")
            # Find the appropriate tool for SQL queries
            tools_response = await self.session.list_tools()
            query_tool = next((tool for tool in tools_response.tools if tool.name == "query"), None)

            if not query_tool:
                return "Error: No 'query' tool found in the MCP server"

            # Call the query tool
            result = await self.session.call_tool("query", {"query": query})

            # Tabular results: render as a DataFrame for readability
            if isinstance(result.content, dict) and "data" in result.content and "column_names" in result.content:
                df = pd.DataFrame(result.content["data"], columns=result.content["column_names"])
                return df.to_string()

            # Return raw result for other types
            return f"Query executed successfully: {json.dumps(result.content)}"

        except Exception as e:
            # logger.exception records the traceback (logger.error would drop it),
            # consistent with the error handling elsewhere in this module.
            logger.exception("Error executing MCP query:")
            return f"Error executing query: {e}"

    def _run(self, query: str) -> str:
        """Synchronous wrapper for the async query implementation."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread (raised on modern Python);
            # create and install one so the sync call path still works.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self._arun(query))
# todo move instantiation to agent controller
class MCPLangchainAgent(LangchainAgent):
    """Extension of LangchainAgent that delegates SQL execution to an MCP server.

    Launches/attaches to a MindsDB MCP server over stdio transport and adds
    an ``MCPQueryTool`` to the agent's tool set.
    """

    def __init__(
        self,
        agent: db.Agents,
        model: dict = None,
        llm_params: dict = None,
        mcp_host: str = "127.0.0.1",
        mcp_port: int = 47337,
    ):
        super().__init__(agent, model, llm_params)
        self.mcp_host = mcp_host
        self.mcp_port = mcp_port
        # AsyncExitStack owns the transport/session lifetimes; see cleanup()
        self.exit_stack = AsyncExitStack()
        self.session = None
        self.stdio = None
        self.write = None

    async def connect_to_mcp(self):
        """Connect to the MCP server using stdio transport (idempotent).

        Raises:
            ConnectionError: if the transport setup or session handshake fails.
        """
        if self.session is not None:
            return
        logger.info(f"Connecting to MCP server at {self.mcp_host}:{self.mcp_port}")
        try:
            # For connecting to an already running MCP server
            # Set up server parameters to connect to existing process
            server_params = StdioServerParameters(
                command="python",
                args=["-m", "mindsdb", "--api=mcp"],
                env={"MCP_HOST": self.mcp_host, "MCP_PORT": str(self.mcp_port)},
            )

            # Connect to the server
            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
            self.stdio, self.write = stdio_transport
            self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

            await self.session.initialize()

            # Test the connection by listing tools
            tools_response = await self.session.list_tools()
            logger.info(
                f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools_response.tools]}"
            )

        except Exception as e:
            logger.exception("Failed to connect to MCP server:")
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    def _connect_sync(self):
        """Run connect_to_mcp() from synchronous code, creating a loop if needed."""
        if self.session is not None:
            return
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread; create and install one.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        loop.run_until_complete(self.connect_to_mcp())

    def _langchain_tools_from_skills(self, llm):
        """Override to add the MCP query tool along with the skill tools."""
        # Get tools from parent implementation
        tools = super()._langchain_tools_from_skills(llm)

        # Best effort: a failed MCP connection must not break the agent,
        # it only loses the mcp_query tool.
        try:
            self._connect_sync()
            if self.session:
                tools.append(MCPQueryTool(self.session))
                logger.info("Added MCP query tool to agent tools")
        except Exception:
            logger.exception("Failed to add MCP query tool:")

        return tools

    def get_completion(self, messages, stream: bool = False):
        """Override to ensure the MCP connection is established before completing."""
        try:
            self._connect_sync()
        except Exception:
            logger.exception("Failed to connect to MCP server:")

        # Call parent implementation to get completion
        response = super().get_completion(messages, stream)

        # Normalize DataFrame responses to strings for callers
        if hasattr(response, "to_string"):  # It's a DataFrame
            return response.to_string()

        return response

    async def cleanup(self):
        """Close the MCP transport/session and reset connection state."""
        if self.exit_stack:
            await self.exit_stack.aclose()
        self.session = None
        self.stdio = None
        self.write = None
class LiteLLMAgentWrapper:
    """Wrapper for MCPLangchainAgent that provides a LiteLLM-compatible interface."""

    def __init__(self, agent: MCPLangchainAgent):
        self.agent = agent

    @staticmethod
    def _format_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Convert role/content chat messages into the agent's question/answer format.

        Roles other than "user"/"assistant" (e.g. "system") produce an entry
        with both fields empty, preserving the original mapping behavior.
        """
        return [
            {
                "question": msg["content"] if msg["role"] == "user" else "",
                "answer": msg["content"] if msg["role"] == "assistant" else "",
            }
            for msg in messages
        ]

    def _model_name(self, kwargs: Dict[str, Any]) -> str:
        """Resolve the model name: explicit kwarg > agent args > generic fallback."""
        return kwargs.get("model", self.agent.args.get("model_name", "mcp-agent"))

    async def acompletion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Async completion interface compatible with LiteLLM.

        Returns a ``chat.completion``-shaped dict with a single assistant choice.
        """
        formatted_messages = self._format_messages(messages)

        # Get completion from agent
        response = self.agent.get_completion(formatted_messages)

        # Ensure response is a string
        if not isinstance(response, str):
            if hasattr(response, "to_string"):  # It's a DataFrame
                response = response.to_string()
            else:
                response = str(response)

        # Format response in LiteLLM expected format
        return {
            "choices": [{"message": {"role": "assistant", "content": response}}],
            "model": self._model_name(kwargs),
            "object": "chat.completion",
        }

    async def acompletion_stream(self, messages: List[Dict[str, str]], **kwargs) -> Iterator[Dict[str, Any]]:
        """Async streaming completion interface compatible with LiteLLM.

        Yields ``chat.completion.chunk``-shaped dicts for each non-empty
        string chunk produced by the agent's stream.
        """
        formatted_messages = self._format_messages(messages)
        model_name = self._model_name(kwargs)
        try:
            # _get_completion_stream is a synchronous generator; re-yield its
            # output chunks in LiteLLM's streaming format.
            for chunk in self.agent._get_completion_stream(formatted_messages):
                content = chunk.get("output", "")
                if content and isinstance(content, str):
                    yield {
                        "choices": [{"delta": {"role": "assistant", "content": content}}],
                        "model": model_name,
                        "object": "chat.completion.chunk",
                    }
                # Allow async context switch between chunks
                await asyncio.sleep(0)
        except Exception:
            logger.exception("Streaming error:")
            raise

    async def cleanup(self):
        """Clean up the wrapped agent's resources."""
        await self.agent.cleanup()
def create_mcp_agent(
    agent_name: str, project_name: str, mcp_host: str = "127.0.0.1", mcp_port: int = 47337
) -> LiteLLMAgentWrapper:
    """Build an MCP-backed agent and wrap it in a LiteLLM-compatible interface.

    Raises:
        ValueError: if no agent with ``agent_name`` exists in ``project_name``.
    """
    from mindsdb.interfaces.agents.agents_controller import AgentsController
    from mindsdb.interfaces.storage import db

    # Make sure the storage layer is ready before looking up the agent
    db.init()

    # Look up the agent record by name within the project
    controller = AgentsController()
    record = controller.get_agent(agent_name, project_name)
    if record is None:
        raise ValueError(f"Agent {agent_name} not found in project {project_name}")

    # Merge the default LLM parameters with the agent's own
    params = controller.get_agent_llm_params(record.params)

    # Construct the MCP agent and expose it through the LiteLLM wrapper
    mcp_agent = MCPLangchainAgent(record, llm_params=params, mcp_host=mcp_host, mcp_port=mcp_port)
    return LiteLLMAgentWrapper(mcp_agent)