Coverage for mindsdb / interfaces / agents / mcp_client_agent.py: 0%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

import asyncio
import json
import sys
from contextlib import AsyncExitStack
from typing import Dict, List, Any, Iterator, AsyncIterator, ClassVar

import pandas as pd
from langchain_core.tools import BaseTool
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from mindsdb.utilities import log
from mindsdb.interfaces.agents.langchain_agent import LangchainAgent
from mindsdb.interfaces.storage import db

logger = log.getLogger(__name__)

16 

17 

class MCPQueryTool(BaseTool):
    """LangChain tool that executes SQL through the ``query`` tool of an MCP server.

    The tool does not own the MCP connection: it is handed an already
    initialised ``ClientSession`` and only issues ``list_tools`` / ``call_tool``
    requests on it.
    """

    name: ClassVar[str] = "mcp_query"
    description: ClassVar[str] = "Execute SQL queries against the MindsDB server via MCP protocol"

    # BaseTool is a pydantic model and rejects assignment to undeclared
    # attributes, so the session is declared as a field and passed through the
    # model constructor instead of being set with ``self.session = ...``.
    session: Any = None

    def __init__(self, session: ClientSession):
        """Create the tool around an MCP client session.

        Args:
            session: initialised MCP client session used for every query.
        """
        super().__init__(session=session)

    @staticmethod
    def _extract_payload(content: Any) -> Any:
        """Unwrap the ``content`` of an MCP tool result.

        A list of content blocks is reduced to the text of each block (blocks
        without a string ``text`` attribute fall back to ``str(block)``). A
        single block is JSON-decoded when possible, so a tabular
        ``{"data": ..., "column_names": ...}`` reply can be rendered as a table;
        if it is not JSON, its raw text is returned. Several blocks are returned
        as a list of strings. Non-list content (e.g. a plain dict) is returned
        unchanged.
        """
        if not isinstance(content, list):
            return content

        texts = []
        for item in content:
            text = getattr(item, "text", None)
            texts.append(text if isinstance(text, str) else str(item))

        if len(texts) != 1:
            return texts
        try:
            return json.loads(texts[0])
        except ValueError:
            # Plain (non-JSON) text reply.
            return texts[0]

    async def _arun(self, query: str) -> str:
        """Run ``query`` via the server's ``query`` tool and return a text result.

        Args:
            query: SQL text forwarded as the ``query`` argument of the MCP tool.

        Returns:
            One of the following; an error string is returned rather than raised:
            - the rendered table when the reply has ``data`` and ``column_names``;
            - ``"Query executed successfully: ..."`` for any other successful reply;
            - an ``"Error ..."`` string when the tool is missing, the server flags
              the call with ``isError``, or any exception occurs.
        """
        try:
            logger.info(f"Executing MCP query: {query}")
            # Make sure the server actually exposes a tool named "query".
            tools_response = await self.session.list_tools()
            if not any(tool.name == "query" for tool in tools_response.tools):
                return "Error: No 'query' tool found in the MCP server"

            result = await self.session.call_tool("query", {"query": query})
            payload = self._extract_payload(result.content)

            # Tool-level failures are reported through the result's ``isError``
            # flag rather than by raising; do not present them as successes.
            if getattr(result, "isError", False):
                return f"Error executing query: {payload}"

            if isinstance(payload, dict) and "data" in payload and "column_names" in payload:
                df = pd.DataFrame(payload["data"], columns=payload["column_names"])
                return df.to_string()

            if not isinstance(payload, str):
                payload = json.dumps(payload)
            return f"Query executed successfully: {payload}"

        except Exception as e:
            # The error goes back to the agent as text so the LLM can react to it;
            # the traceback is kept in the log.
            logger.exception("Error executing MCP query:")
            return f"Error executing query: {e}"

    def _run(self, query: str) -> str:
        """Synchronous entry point: drive ``_arun`` to completion.

        NOTE(review): this uses ``asyncio.get_event_loop()`` rather than
        ``asyncio.run()``, presumably because the MCP session must keep running
        on the same loop it was created on. It cannot be called while that loop
        is already running.
        """
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self._arun(query))

64 

65 

# TODO: move instantiation to agent controller
class MCPLangchainAgent(LangchainAgent):
    """LangchainAgent that additionally exposes an MCP query tool.

    On first use it spawns a MindsDB MCP server subprocess over stdio, opens a
    ``ClientSession`` to it, and adds an ``MCPQueryTool`` to the agent's tools.
    """

    def __init__(
        self,
        agent: db.Agents,
        model: dict = None,
        llm_params: dict = None,
        mcp_host: str = "127.0.0.1",
        mcp_port: int = 47337,
    ):
        """Create the agent; no MCP connection is made yet.

        Args:
            agent: stored agent record, forwarded to ``LangchainAgent``.
            model: forwarded unchanged to ``LangchainAgent``.
            llm_params: forwarded unchanged to ``LangchainAgent``.
            mcp_host: passed to the spawned server as ``MCP_HOST``.
            mcp_port: passed to the spawned server as ``MCP_PORT``.
        """
        super().__init__(agent, model, llm_params)
        self.mcp_host = mcp_host
        self.mcp_port = mcp_port
        # Holds the stdio transport and session contexts; closed by cleanup().
        self.exit_stack = AsyncExitStack()
        self.session = None
        self.stdio = None
        self.write = None

    async def connect_to_mcp(self):
        """Spawn the MCP server over stdio and open an initialised session.

        Does nothing when a session already exists. ``self.session`` and the
        stdio streams are only set after ``initialize()`` and a ``list_tools``
        probe have succeeded, so a failed attempt leaves no half-open session
        behind and can be retried.

        Raises:
            ConnectionError: if spawning, initialising or probing the server fails.
        """
        if self.session is not None:
            return

        logger.info(f"Connecting to MCP server at {self.mcp_host}:{self.mcp_port}")
        # The stdio transport always starts a new server process; host and port
        # are only handed to it through the environment. ``sys.executable`` runs
        # the child with this process's interpreter (and thus its installed
        # mindsdb) instead of whatever "python" resolves to on the child's PATH.
        server_params = StdioServerParameters(
            command=sys.executable,
            args=["-m", "mindsdb", "--api=mcp"],
            env={"MCP_HOST": self.mcp_host, "MCP_PORT": str(self.mcp_port)},
        )

        try:
            # Open everything on a temporary stack first: if any step fails, the
            # ``async with`` closes what was opened and ``self`` stays untouched.
            async with AsyncExitStack() as stack:
                stdio, write = await stack.enter_async_context(stdio_client(server_params))
                session = await stack.enter_async_context(ClientSession(stdio, write))
                await session.initialize()

                # Test the connection by listing tools.
                tools_response = await session.list_tools()

                # Success: move the open contexts onto the long-lived stack so
                # they outlive this block and are closed by cleanup().
                await self.exit_stack.enter_async_context(stack.pop_all())
        except Exception as e:
            logger.exception("Failed to connect to MCP server:")
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

        self.stdio, self.write, self.session = stdio, write, session
        logger.info(
            f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools_response.tools]}"
        )

    def _langchain_tools_from_skills(self, llm):
        """Return the parent's skill tools plus an ``MCPQueryTool`` when MCP is reachable.

        MCP failures are logged and swallowed, so the agent still gets its
        regular skill tools.
        """
        tools = super()._langchain_tools_from_skills(llm)

        try:
            # Use the current event loop rather than asyncio.run(), which would
            # close its loop and leave the MCP session unusable afterwards.
            loop = asyncio.get_event_loop()
            if self.session is None:
                loop.run_until_complete(self.connect_to_mcp())

            if self.session:
                tools.append(MCPQueryTool(self.session))
                logger.info("Added MCP query tool to agent tools")
        except Exception:
            logger.exception("Failed to add MCP query tool:")

        return tools

    def get_completion(self, messages, stream: bool = False):
        """Connect to MCP if needed, then delegate to ``LangchainAgent.get_completion``.

        A connection failure is only logged; the completion still runs. If the
        parent returns an object with ``to_string`` (e.g. a DataFrame), it is
        converted to a string.
        """
        try:
            if self.session is None:
                # Same loop reasoning as in _langchain_tools_from_skills.
                loop = asyncio.get_event_loop()
                loop.run_until_complete(self.connect_to_mcp())
        except Exception:
            logger.exception("Failed to connect to MCP server:")

        response = super().get_completion(messages, stream)

        # Ensure response is a string (not a DataFrame).
        if hasattr(response, "to_string"):
            return response.to_string()
        return response

    async def cleanup(self):
        """Close the MCP session and stdio transport and drop the references.

        The references are cleared even if closing raises, so the agent never
        keeps using a closed session.
        """
        try:
            if self.exit_stack:
                await self.exit_stack.aclose()
        finally:
            self.session = None
            self.stdio = None
            self.write = None

166 

167 

class LiteLLMAgentWrapper:
    """Expose an ``MCPLangchainAgent`` through a LiteLLM-style chat-completion interface."""

    def __init__(self, agent: MCPLangchainAgent):
        self.agent = agent

    @staticmethod
    def _format_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Convert chat messages into the agent's question/answer records.

        Each message becomes one record. A ``user`` message fills ``question``
        and an ``assistant`` message fills ``answer``; any other role yields a
        record with both fields empty.
        """
        return [
            {
                "question": msg["content"] if msg["role"] == "user" else "",
                "answer": msg["content"] if msg["role"] == "assistant" else "",
            }
            for msg in messages
        ]

    def _model_name(self, kwargs: Dict[str, Any]) -> str:
        """Return the model name to report in responses.

        Order of preference: the caller's ``model`` kwarg, then the agent's
        ``model_name`` arg, then ``"mcp-agent"``.
        """
        return kwargs.get("model", self.agent.args.get("model_name", "mcp-agent"))

    async def acompletion(self, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Return a LiteLLM-style ``chat.completion`` dict for ``messages``.

        NOTE(review): ``get_completion`` is synchronous, so this call blocks the
        running event loop until the agent finishes.
        """
        response = self.agent.get_completion(self._format_messages(messages))

        # Ensure response is a string (DataFrame-like objects via to_string()).
        if not isinstance(response, str):
            response = response.to_string() if hasattr(response, "to_string") else str(response)

        return {
            "choices": [{"message": {"role": "assistant", "content": response}}],
            # Same fallback as the streaming path; a missing "model_name" no
            # longer raises KeyError.
            "model": self._model_name(kwargs),
            "object": "chat.completion",
        }

    async def acompletion_stream(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> AsyncIterator[Dict[str, Any]]:
        """Yield LiteLLM-style ``chat.completion.chunk`` dicts for ``messages``.

        Only chunks whose ``output`` is a non-empty string are emitted. Errors
        are logged and re-raised.
        """
        formatted_messages = self._format_messages(messages)
        model_name = self._model_name(kwargs)
        try:
            # _get_completion_stream is a synchronous generator.
            for chunk in self.agent._get_completion_stream(formatted_messages):
                content = chunk.get("output", "")
                if content and isinstance(content, str):
                    yield {
                        "choices": [{"delta": {"role": "assistant", "content": content}}],
                        "model": model_name,
                        "object": "chat.completion.chunk",
                    }
                # Allow other tasks to run between chunks.
                await asyncio.sleep(0)
        except Exception:
            logger.exception("Streaming error:")
            raise

    async def cleanup(self):
        """Release the wrapped agent's MCP resources."""
        await self.agent.cleanup()

234 

235 

def create_mcp_agent(
    agent_name: str, project_name: str, mcp_host: str = "127.0.0.1", mcp_port: int = 47337
) -> LiteLLMAgentWrapper:
    """Load a stored agent and return it as an MCP-enabled, LiteLLM-compatible wrapper.

    Args:
        agent_name: name of the stored agent.
        project_name: project that owns the agent.
        mcp_host: forwarded to ``MCPLangchainAgent``.
        mcp_port: forwarded to ``MCPLangchainAgent``.

    Returns:
        A ``LiteLLMAgentWrapper`` around a new ``MCPLangchainAgent``.

    Raises:
        ValueError: if the agent does not exist in the project.
    """
    # Function-level imports, presumably to avoid an import cycle; confirm before hoisting.
    from mindsdb.interfaces.agents.agents_controller import AgentsController
    from mindsdb.interfaces.storage import db

    db.init()

    controller = AgentsController()
    stored_agent = controller.get_agent(agent_name, project_name)
    if stored_agent is None:
        raise ValueError(f"Agent {agent_name} not found in project {project_name}")

    # Controller defaults merged with the agent's own params.
    merged_llm_params = controller.get_agent_llm_params(stored_agent.params)

    return LiteLLMAgentWrapper(
        MCPLangchainAgent(
            stored_agent,
            llm_params=merged_llm_params,
            mcp_host=mcp_host,
            mcp_port=mcp_port,
        )
    )