Coverage for mindsdb / interfaces / agents / litellm_server.py: 0%

163 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import asyncio 

2import argparse 

3import json 

4 

5from typing import List, Dict, Optional 

6from contextlib import AsyncExitStack 

7 

8import uvicorn 

9from fastapi import FastAPI, HTTPException, BackgroundTasks 

10from fastapi.responses import StreamingResponse 

11from fastapi.middleware.cors import CORSMiddleware 

12from pydantic import BaseModel, Field 

13from mcp import ClientSession, StdioServerParameters 

14from mcp.client.stdio import stdio_client 

15 

16from mindsdb.utilities import log 

17from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent 

18 

# Module-level logger named after this module.
logger = log.getLogger(__name__)

# FastAPI application that every route handler in this file is registered on.
app = FastAPI(title="MindsDB MCP Agent LiteLLM API")

# Configure CORS
# NOTE(review): every origin, method and header is allowed and credentials are
# enabled as well — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Store agent wrapper as a global variable
# (assigned by init_agent(); stays None until create_mcp_agent() succeeds).
agent_wrapper = None
# MCP session for direct SQL queries
# (created on demand by the /test-mcp-connection handler).
mcp_session = None
# Holds the async contexts opened for mcp_session; closed by the shutdown
# handler and when a stale session is replaced.
exit_stack = AsyncExitStack()

37 

38 

class ChatMessage(BaseModel):
    # One entry of the conversation carried in ChatCompletionRequest.messages.
    role: str  # speaker label as supplied by the client; any string is accepted
    content: str  # message text

42 

43 

class ChatCompletionRequest(BaseModel):
    # Request body of POST /v1/chat/completions.
    model: str  # echoed back in the response; forwarded to the agent only when streaming
    messages: List[ChatMessage]  # conversation passed to the agent as role/content dicts
    stream: bool = False  # True selects the text/event-stream response
    # Accepted by the schema, but no handler visible in this file reads them.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

50 

51 

class ChatCompletionChoice(BaseModel):
    # A single entry of ChatCompletionResponse.choices.
    index: int = 0
    # Set by the non-streaming handler to {"role": "assistant", "content": ...}.
    message: Optional[Dict[str, str]] = None
    # Never populated in this file — presumably reserved for streamed chunks; TODO confirm.
    delta: Optional[Dict[str, str]] = None
    finish_reason: Optional[str] = "stop"

57 

58 

class ChatCompletionResponse(BaseModel):
    # Body returned by the non-streaming branch of POST /v1/chat/completions.
    # The handler in this file supplies only `model` and `choices`; the other
    # fields keep the fixed defaults declared below.
    id: str = "mcp-agent-response"
    object: str = "chat.completion"
    created: int = 0
    model: str
    choices: List[ChatCompletionChoice]
    # default_factory builds a fresh all-zero usage dict for each response.
    usage: Dict[str, int] = Field(
        default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    )

68 

69 

class DirectSQLRequest(BaseModel):
    # Request body of POST /direct-sql.
    query: str  # forwarded unchanged as the "query" argument of the MCP "query" tool

72 

73 

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Answer a chat completion request with the initialized MCP agent.

    Args:
        request: Parsed request body. ``messages`` is forwarded to the agent,
            ``stream`` selects the response mode and ``model`` is echoed back
            in the non-streaming response.

    Returns:
        A ``StreamingResponse`` of ``data: ...`` events when ``request.stream``
        is true, otherwise a ``ChatCompletionResponse`` holding one assistant
        message.

    Raises:
        HTTPException: 500 when the agent has not been initialized or when
            building the response fails.
    """
    if agent_wrapper is None:
        raise HTTPException(
            status_code=500,
            detail="Agent not initialized. Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http",
        )

    try:
        # Convert request to messages format
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if request.stream:
            # Return a streaming response
            async def generate():
                try:
                    async for chunk in agent_wrapper.acompletion_stream(messages, model=request.model):
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception:
                    logger.exception("Streaming error:")
                    # The response has already started, so the failure is
                    # reported in-band. Serialize it with json.dumps so the
                    # event carries valid JSON like every other chunk.
                    error_payload = {"error": "Streaming failed due to an internal error."}
                    yield f"data: {json.dumps(error_payload)}\n\n"

            return StreamingResponse(generate(), media_type="text/event-stream")

        # Return a regular response
        # NOTE(review): unlike the streaming branch, request.model is not
        # forwarded to the agent here — confirm this is intended.
        response = await agent_wrapper.acompletion(messages)

        # Ensure the content is a string
        content = response["choices"][0]["message"].get("content", "")
        if not isinstance(content, str):
            content = str(content)

        # Transform to proper OpenAI format
        return ChatCompletionResponse(
            model=request.model, choices=[ChatCompletionChoice(message={"role": "assistant", "content": content})]
        )

    except Exception as e:
        logger.exception("Error in chat completion:")
        raise HTTPException(status_code=500, detail=str(e)) from e

117 

118 

@app.post("/direct-sql")
async def direct_sql(request: DirectSQLRequest, background_tasks: BackgroundTasks):
    """Execute a direct SQL query via MCP (for testing).

    The agent's own MCP session is preferred; the module-level ``mcp_session``
    is used when the agent is missing or has no session.

    Args:
        request: Body carrying the SQL text in ``query``.
        background_tasks: Injected by FastAPI; not used by this handler.

    Returns:
        ``{"result": <content of the MCP "query" tool result>}``.

    Raises:
        HTTPException: 500 when no session is available or the tool call fails.
    """
    if agent_wrapper is None and mcp_session is None:
        raise HTTPException(
            status_code=500, detail="No MCP session available. Make sure MindsDB server is running with MCP enabled."
        )

    try:
        # First try to use the agent's session if available. agent_wrapper may
        # be None while a standalone mcp_session exists, so guard the access
        # instead of dereferencing None.
        agent_session = None
        if agent_wrapper is not None:
            agent_session = getattr(agent_wrapper.agent, "session", None)

        # If agent session not available, use the direct session
        session = agent_session or mcp_session
        if not session:
            raise HTTPException(status_code=500, detail="No MCP session available")

        result = await session.call_tool("query", {"query": request.query})
        return {"result": result.content}

    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it would mangle the detail.
        raise
    except Exception as e:
        logger.exception("Error executing direct SQL:")
        raise HTTPException(status_code=500, detail=str(e)) from e

145 

146 

@app.get("/v1/models")
async def list_models():
    """Return a one-element model list.

    The id is the agent's ``model_name`` argument when an agent is initialized
    and has one; otherwise it is the placeholder ``"mcp-agent"``.
    """
    model_id = "mcp-agent"
    if agent_wrapper is not None:
        # Return the actual model name if available
        model_id = agent_wrapper.agent.args.get("model_name", "mcp-agent")

    entry = {"id": model_id, "object": "model", "created": 0, "owned_by": "mindsdb"}
    return {"object": "list", "data": [entry]}

159 

160 

@app.get("/health")
async def health_check():
    """Report service health.

    Always returns ``status`` and ``agent_initialized``. When an agent exists,
    ``mcp_connected``, ``agent_name`` and ``model_name`` are added.
    """
    initialized = agent_wrapper is not None
    report = {"status": "ok", "agent_initialized": initialized}
    if not initialized:
        return report

    agent = agent_wrapper.agent
    # Connected means the agent exposes a non-None ``session`` attribute.
    report["mcp_connected"] = getattr(agent, "session", None) is not None
    report["agent_name"] = agent.agent.name
    report["model_name"] = agent.args.get("model_name", "unknown")
    return report

179 

180 

@app.get("/test-mcp-connection")
async def test_mcp_connection():
    """Test the connection to the MCP server.

    An existing ``mcp_session`` is probed with ``list_tools()``. If there is
    none, or the probe fails, a new stdio session is opened on ``exit_stack``,
    initialized and stored in ``mcp_session``.

    Returns:
        ``{"status": "ok", "message": ..., "tools": [<tool names>]}``.

    Raises:
        HTTPException: 500 when a session cannot be established or queried.
    """
    global mcp_session

    try:
        # If we already have a session, test it
        if mcp_session:
            try:
                tools_response = await mcp_session.list_tools()
                return {
                    "status": "ok",
                    "message": "Successfully connected to MCP server",
                    "tools": [tool.name for tool in tools_response.tools],
                }
            except Exception:
                # If error, close existing session and create a new one.
                # Drop the reference first so a failing close cannot leave the
                # stale session in place for the next request.
                logger.warning("Existing MCP session failed; reconnecting", exc_info=True)
                mcp_session = None
                await exit_stack.aclose()

        # Create a new MCP session over stdio.
        # NOTE(review): stdio_client is handed the command
        # `python -m mindsdb --api=mcp` — presumably it launches that as a
        # child process rather than attaching to a running server; confirm.
        server_params = StdioServerParameters(command="python", args=["-m", "mindsdb", "--api=mcp"], env=None)

        try:
            stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
            stdio, write = stdio_transport
            session = await exit_stack.enter_async_context(ClientSession(stdio, write))

            await session.initialize()
        except Exception:
            # Contexts already entered would otherwise stay open on the shared
            # stack and pile up with every failed attempt.
            await exit_stack.aclose()
            raise

        # Save the session for future use
        mcp_session = session

        # Get available tools
        tools_response = await session.list_tools()

        return {
            "status": "ok",
            "message": "Successfully connected to MCP server",
            "tools": [tool.name for tool in tools_response.tools],
        }
    except Exception as e:
        logger.exception("Error connecting to MCP server:")
        error_detail = f"Error connecting to MCP server: {str(e)}. Make sure MindsDB server is running with HTTP enabled: python -m mindsdb --api=http"
        raise HTTPException(status_code=500, detail=error_detail) from e

225 

226 

async def init_agent(agent_name: str, project_name: str, mcp_host: str, mcp_port: int):
    """Create the MCP agent and store it in the module-level ``agent_wrapper``.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.

    Returns:
        True when the agent was created, False when any step raised (the
        exception is logged, not propagated).
    """
    global agent_wrapper

    agent_settings = {
        "agent_name": agent_name,
        "project_name": project_name,
        "mcp_host": mcp_host,
        "mcp_port": mcp_port,
    }

    try:
        logger.info(f"Initializing MCP agent '{agent_name}' in project '{project_name}'")
        logger.info(f"Connecting to MCP server at {mcp_host}:{mcp_port}")
        logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

        agent_wrapper = create_mcp_agent(**agent_settings)

        logger.info("Agent initialized successfully")
    except Exception:
        logger.exception("Failed to initialize agent:")
        return False

    return True

245 

246 

@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on server shutdown.

    Runs the agent wrapper's ``cleanup()`` when an agent exists, then closes
    every context held by ``exit_stack``. The stack is closed even if the
    agent cleanup raises; that exception still propagates afterwards.
    """
    try:
        if agent_wrapper:
            await agent_wrapper.cleanup()
    finally:
        # Must run regardless of the cleanup outcome, otherwise the MCP
        # session contexts on the stack are never released.
        await exit_stack.aclose()

256 

257 

async def run_server_async(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize the agent and report the outcome as an exit code.

    This coroutine does not start the HTTP server itself; ``host`` and
    ``port`` are accepted but not used here.

    Returns:
        0 when ``init_agent`` succeeded, 1 otherwise.
    """
    if await init_agent(agent_name, project_name, mcp_host, mcp_port):
        return 0

    logger.error("Failed to initialize agent. Make sure MindsDB server is running with MCP enabled.")
    return 1

274 

275 

def run_server(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize storage and the agent, then serve the FastAPI app.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: Host of the MCP server.
        mcp_port: Port of the MCP server.
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on.

    Returns:
        The non-zero code from ``run_server_async`` when initialization fails
        (the server is not started), otherwise 0 once uvicorn returns.
    """
    logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    # Initialize database (imported here, at call time, as in the original).
    from mindsdb.interfaces.storage import db

    db.init()

    # Agent initialization runs on a dedicated event loop before uvicorn starts.
    init_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(init_loop)
    exit_code = init_loop.run_until_complete(
        run_server_async(agent_name, project_name, mcp_host, mcp_port)
    )

    if exit_code == 0:
        # Run the server
        logger.info(f"Starting server on {host}:{port}")
        uvicorn.run(app, host=host, port=port)
        return 0

    return exit_code

301 

302 

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run_server().
    parser = argparse.ArgumentParser(description="Run a LiteLLM-compatible API server for MCP agent")
    parser.add_argument("--agent", type=str, required=True, help="Name of the agent to use")
    parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")
    parser.add_argument("--mcp-host", type=str, default="127.0.0.1", help="MCP server host")
    parser.add_argument("--mcp-port", type=int, default=47337, help="MCP server port")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")

    args = parser.parse_args()

    # run_server() returns 0 on success and non-zero when agent initialization
    # fails; propagate it so the process exit status reflects the failure.
    raise SystemExit(
        run_server(
            agent_name=args.agent,
            project_name=args.project,
            mcp_host=args.mcp_host,
            mcp_port=args.mcp_port,
            host=args.host,
            port=args.port,
        )
    )

321 )