# Coverage for mindsdb/interfaces/agents/litellm_server.py: 0%
# 163 statements
# coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import argparse
import asyncio
import json
import sys
from contextlib import AsyncExitStack
from typing import Dict, List, Optional

import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from pydantic import BaseModel, Field

from mindsdb.interfaces.agents.mcp_client_agent import create_mcp_agent
from mindsdb.utilities import log
logger = log.getLogger(__name__)

app = FastAPI(title="MindsDB MCP Agent LiteLLM API")

# Cross-origin access is fully open: any origin, any method, any header, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very permissive,
# and this app exposes /direct-sql. Confirm how Starlette's CORSMiddleware treats this
# combination and restrict allow_origins if the server is reachable beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# ---- Module-level state shared by the endpoints ----
# Wrapper returned by create_mcp_agent(); stays None until init_agent() succeeds.
agent_wrapper = None
# Standalone MCP client session opened by /test-mcp-connection; /direct-sql falls back to it.
mcp_session: Optional[ClientSession] = None
# Owns the stdio transport / ClientSession contexts; closed on reconnect and on shutdown.
exit_stack = AsyncExitStack()
class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat conversation."""

    # Role name as sent by the client (e.g. "assistant"); passed through unchanged.
    role: str = Field(...)
    # Plain-text message body.
    content: str = Field(...)
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` (a subset of the OpenAI request schema)."""

    model: str = Field(...)
    messages: List[ChatMessage] = Field(...)
    # When true, the endpoint answers with server-sent events instead of a single JSON body.
    stream: bool = Field(default=False)
    # Accepted for client compatibility; the chat endpoint does not forward these to the agent.
    temperature: Optional[float] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
class ChatCompletionChoice(BaseModel):
    """A single ``choices`` entry of a chat completion response.

    ``message`` carries the full assistant reply. ``delta`` mirrors the OpenAI
    streaming shape but is not populated anywhere in this module.
    """

    index: int = Field(default=0)
    message: Optional[Dict[str, str]] = Field(default=None)
    delta: Optional[Dict[str, str]] = Field(default=None)
    finish_reason: Optional[str] = Field(default="stop")
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion response in the OpenAI ``chat.completion`` shape.

    ``id`` and ``created`` are fixed placeholders. Token usage is not tracked, so every
    count defaults to zero.
    """

    id: str = Field(default="mcp-agent-response")
    object: str = Field(default="chat.completion")
    created: int = Field(default=0)
    model: str = Field(...)
    choices: List[ChatCompletionChoice] = Field(...)
    # A fresh dict per instance so responses never share (and mutate) one usage mapping.
    usage: Dict[str, int] = Field(
        default_factory=lambda: dict(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    )
class DirectSQLRequest(BaseModel):
    """Body of ``POST /direct-sql``: a single SQL statement forwarded to the MCP ``query`` tool."""

    query: str = Field(...)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint backed by the MCP agent.

    Args:
        request: Parsed request. Only ``model``, ``messages`` and ``stream`` are used;
            ``temperature`` and ``max_tokens`` are ignored here.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``request.stream`` is true.
        Each agent chunk is JSON-encoded and the stream is terminated by
        ``data: [DONE]``. Otherwise a ``ChatCompletionResponse`` holding the
        assistant's reply.

    Raises:
        HTTPException: 500 if the agent is not initialized or the completion fails.
    """
    global agent_wrapper

    if agent_wrapper is None:
        raise HTTPException(
            status_code=500,
            detail="Agent not initialized. Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http",
        )

    try:
        # Flatten the pydantic messages into plain dicts for the agent wrapper.
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if request.stream:

            async def generate():
                try:
                    async for chunk in agent_wrapper.acompletion_stream(messages, model=request.model):
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception:
                    # The response has already started, so errors can only be reported in-band.
                    # Encode the payload with json.dumps so clients receive valid JSON.
                    logger.exception("Streaming error:")
                    yield f"data: {json.dumps({'error': 'Streaming failed due to an internal error.'})}\n\n"

            return StreamingResponse(generate(), media_type="text/event-stream")

        # NOTE(review): the streaming path passes model=request.model but this call does not;
        # confirm whether acompletion() accepts/needs the model argument.
        response = await agent_wrapper.acompletion(messages)

        # OpenAI-style messages may carry `content: None`; return "" rather than the text "None".
        content = response["choices"][0]["message"].get("content")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)

        return ChatCompletionResponse(
            model=request.model, choices=[ChatCompletionChoice(message={"role": "assistant", "content": content})]
        )

    except Exception as e:
        logger.exception("Error in chat completion:")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/direct-sql")
async def direct_sql(request: DirectSQLRequest, background_tasks: BackgroundTasks):
    """Execute a direct SQL query via MCP (for testing).

    Prefers the MCP session owned by the agent wrapper and falls back to the
    standalone ``mcp_session`` opened by ``/test-mcp-connection``. The fallback also
    works when no agent has been initialized.

    NOTE(review): this forwards arbitrary SQL from the request body, and no
    authentication is applied in this module. Keep it off untrusted networks.

    Args:
        request: Body containing the SQL text in ``query``.
        background_tasks: Injected by FastAPI; currently unused.

    Returns:
        ``{"result": <content returned by the MCP "query" tool>}``.

    Raises:
        HTTPException: 500 if no MCP session is available or the tool call fails.
    """
    global agent_wrapper, mcp_session

    if agent_wrapper is None and mcp_session is None:
        raise HTTPException(
            status_code=500, detail="No MCP session available. Make sure MindsDB server is running with MCP enabled."
        )

    try:
        # Guard on agent_wrapper first: it may be None while the standalone session exists.
        if agent_wrapper is not None and getattr(agent_wrapper.agent, "session", None):
            session = agent_wrapper.agent.session
        elif mcp_session:
            session = mcp_session
        else:
            raise HTTPException(status_code=500, detail="No MCP session available")

        result = await session.call_tool("query", {"query": request.query})
        return {"result": result.content}

    except HTTPException:
        # Already a proper HTTP error; don't re-wrap it (and its detail) below.
        raise
    except Exception as e:
        logger.exception("Error executing direct SQL:")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/models")
async def list_models():
    """List the single model this server exposes.

    Uses the agent's configured ``model_name`` when an agent is loaded, otherwise
    (or when that key is absent) the placeholder id ``"mcp-agent"``.
    """
    if agent_wrapper is None:
        model_id = "mcp-agent"
    else:
        model_id = agent_wrapper.agent.args.get("model_name", "mcp-agent")

    model_card = {"id": model_id, "object": "model", "created": 0, "owned_by": "mindsdb"}
    return {"object": "list", "data": [model_card]}
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always reports ``status: ok`` and whether an agent is loaded. When one is, it
    also reports whether the agent holds an MCP session, the agent's name, and the
    configured model name (``"unknown"`` if not set).
    """
    wrapper = agent_wrapper
    report = {"status": "ok", "agent_initialized": wrapper is not None}
    if wrapper is None:
        return report

    inner = wrapper.agent
    report["mcp_connected"] = hasattr(inner, "session") and inner.session is not None
    report["agent_name"] = inner.agent.name
    report["model_name"] = inner.args.get("model_name", "unknown")
    return report
@app.get("/test-mcp-connection")
async def test_mcp_connection():
    """Test, and if necessary (re)establish, the standalone MCP client session.

    An existing session is probed with ``list_tools``. If the probe fails, the
    session is closed and replaced. A new session is created by launching
    ``-m mindsdb --api=mcp`` as a stdio subprocess with the current interpreter.

    Returns:
        dict with ``status``, ``message`` and ``tools`` (names of the available MCP tools).

    Raises:
        HTTPException: 500 if the MCP server cannot be started, initialized or queried.
    """
    global mcp_session, exit_stack

    try:
        if mcp_session:
            try:
                tools_response = await mcp_session.list_tools()
                return {
                    "status": "ok",
                    "message": "Successfully connected to MCP server",
                    "tools": [tool.name for tool in tools_response.tools],
                }
            except Exception:
                # The cached session is unusable; tear it down and reconnect below.
                await exit_stack.aclose()
                mcp_session = None

        # Use the running interpreter rather than whatever "python" resolves to on PATH,
        # so the subprocess runs the same Python (and therefore the same mindsdb install).
        server_params = StdioServerParameters(command=sys.executable, args=["-m", "mindsdb", "--api=mcp"], env=None)

        try:
            stdio, write = await exit_stack.enter_async_context(stdio_client(server_params))
            session = await exit_stack.enter_async_context(ClientSession(stdio, write))
            await session.initialize()
        except Exception:
            # Don't leave a half-started subprocess/session registered on the exit stack
            # until shutdown; release it now and let the outer handler report the error.
            await exit_stack.aclose()
            raise

        # Save the session for future use (e.g. as the /direct-sql fallback).
        mcp_session = session

        tools_response = await session.list_tools()
        return {
            "status": "ok",
            "message": "Successfully connected to MCP server",
            "tools": [tool.name for tool in tools_response.tools],
        }
    except Exception as e:
        logger.exception("Error connecting to MCP server:")
        error_detail = f"Error connecting to MCP server: {str(e)}. Make sure MindsDB server is running with HTTP enabled: python -m mindsdb --api=http"
        raise HTTPException(status_code=500, detail=error_detail) from e
async def init_agent(agent_name: str, project_name: str, mcp_host: str, mcp_port: int):
    """Create the module-level MCP agent wrapper.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project that contains the agent.
        mcp_host: MCP server host handed to ``create_mcp_agent``.
        mcp_port: MCP server port handed to ``create_mcp_agent``.

    Returns:
        True once ``agent_wrapper`` has been set. False if anything raised; the
        error is logged rather than propagated.
    """
    global agent_wrapper

    try:
        for line in (
            f"Initializing MCP agent '{agent_name}' in project '{project_name}'",
            f"Connecting to MCP server at {mcp_host}:{mcp_port}",
            "Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http",
        ):
            logger.info(line)

        agent_wrapper = create_mcp_agent(
            agent_name=agent_name,
            project_name=project_name,
            mcp_host=mcp_host,
            mcp_port=mcp_port,
        )
        logger.info("Agent initialized successfully")
    except Exception:
        logger.exception("Failed to initialize agent:")
        return False

    return True
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on server shutdown.

    Runs the agent wrapper's ``cleanup()`` (if an agent was created), then closes the
    exit stack holding the standalone MCP stdio transport/session. The exit stack is
    closed even when ``cleanup()`` raises, so that subprocess is not leaked.
    """
    global agent_wrapper, exit_stack

    try:
        if agent_wrapper:
            await agent_wrapper.cleanup()
    finally:
        await exit_stack.aclose()
async def run_server_async(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize the MCP agent before the HTTP server starts.

    Despite its name, this coroutine does not start the server. ``run_server`` awaits
    it first and then launches uvicorn itself.

    Args:
        agent_name: Name of the agent to load.
        project_name: Project containing the agent.
        mcp_host: Host of the MCP server the agent connects to.
        mcp_port: Port of the MCP server the agent connects to.
        host: Unused; kept for signature compatibility.
        port: Unused; kept for signature compatibility.

    Returns:
        0 if the agent was initialized, 1 otherwise (suitable as a process exit status).
    """
    if not await init_agent(agent_name, project_name, mcp_host, mcp_port):
        logger.error("Failed to initialize agent. Make sure MindsDB server is running with MCP enabled.")
        return 1
    return 0
def run_server(
    agent_name: str,
    project_name: str = "mindsdb",
    mcp_host: str = "127.0.0.1",
    mcp_port: int = 47337,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Initialize MindsDB storage and the agent, then serve the API with uvicorn.

    Returns:
        The non-zero status from agent initialization if it failed (the server is not
        started). Otherwise 0, once ``uvicorn.run`` returns.
    """
    logger.info("Make sure MindsDB server is running with MCP enabled: python -m mindsdb --api=mysql,mcp,http")

    # Deferred import: the storage layer is only needed when actually running the server.
    from mindsdb.interfaces.storage import db

    db.init()

    # NOTE(review): this loop is never closed, and uvicorn.run() presumably drives the
    # app on its own loop. Confirm nothing created during init_agent is bound to this one.
    init_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(init_loop)
    status = init_loop.run_until_complete(run_server_async(agent_name, project_name, mcp_host, mcp_port))

    if status == 0:
        logger.info(f"Starting server on {host}:{port}")
        uvicorn.run(app, host=host, port=port)
        return 0
    return status
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a LiteLLM-compatible API server for MCP agent")
    parser.add_argument("--agent", type=str, required=True, help="Name of the agent to use")
    parser.add_argument("--project", type=str, default="mindsdb", help="Project containing the agent")
    parser.add_argument("--mcp-host", type=str, default="127.0.0.1", help="MCP server host")
    parser.add_argument("--mcp-port", type=int, default=47337, help="MCP server port")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")

    args = parser.parse_args()

    # Propagate run_server's status so a failed agent initialization exits non-zero
    # instead of reporting success to the shell.
    sys.exit(
        run_server(
            agent_name=args.agent,
            project_name=args.project,
            mcp_host=args.mcp_host,
            mcp_port=args.mcp_port,
            host=args.host,
            port=args.port,
        )
    )