Coverage for mindsdb / api / a2a / common / server / server.py: 0%
85 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
2import time
3from typing import AsyncIterable, Any, Dict
5from starlette.applications import Starlette
6from starlette.middleware.cors import CORSMiddleware
7from starlette.responses import JSONResponse
8from sse_starlette.sse import EventSourceResponse
9from starlette.requests import Request
10from starlette.routing import Route
11from ...common.types import (
12 A2ARequest,
13 JSONRPCResponse,
14 InvalidRequestError,
15 JSONParseError,
16 GetTaskRequest,
17 CancelTaskRequest,
18 SendTaskRequest,
19 SetTaskPushNotificationRequest,
20 GetTaskPushNotificationRequest,
21 InternalError,
22 AgentCard,
23 TaskResubscriptionRequest,
24 SendTaskStreamingRequest,
25 MessageStreamRequest,
26)
27from pydantic import ValidationError
28from ...common.server.task_manager import TaskManager
30from mindsdb.utilities import log
32logger = log.getLogger(__name__)
class A2AServer:
    """Starlette application that exposes an A2A task manager over HTTP.

    Routes registered on ``self.app``:

    * ``POST /`` - JSON-RPC entry point, handled by ``_process_request``.
    * ``GET /.well-known/agent.json`` and ``GET /.well-known/agent-card.json``
      - both return the agent card.
    * ``GET /status`` - basic liveness/uptime information.
    """

    def __init__(
        self,
        agent_card: AgentCard | None = None,
        task_manager: TaskManager | None = None,
    ):
        """Build the Starlette app, register routes and install CORS middleware.

        Args:
            agent_card: Card served by the well-known endpoints and used for
                the name/version fields of ``/status``. May be ``None``.
            task_manager: Object whose ``on_*`` coroutines/generators handle
                the validated JSON-RPC requests.
        """
        self.task_manager = task_manager
        self.agent_card = agent_card
        self.app = Starlette(
            routes=[
                Route("/", self._process_request, methods=["POST"]),
                Route("/.well-known/agent.json", self._get_agent_card, methods=["GET"]),
                Route("/.well-known/agent-card.json", self._get_agent_card, methods=["GET"]),
                Route("/status", self._get_status, methods=["GET"]),
            ]
        )
        # TODO: Remove this when we have a proper CORS policy
        # NOTE(review): wildcard origins together with allow_credentials=True is
        # very permissive - confirm this is acceptable before tightening.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # Reference point for the uptime reported by /status.
        self.start_time = time.time()

    def _get_agent_card(self, request: Request) -> JSONResponse:
        """Return the agent card as JSON, omitting fields that are ``None``.

        When the server was constructed without an agent card, a 404 JSON
        response is returned instead of failing on the ``None`` attribute.
        """
        if self.agent_card is None:
            # agent_card defaults to None in __init__; _get_status already
            # tolerates that, so this endpoint must not crash on it either.
            return JSONResponse({"error": "Agent card is not configured"}, status_code=404)
        return JSONResponse(self.agent_card.model_dump(exclude_none=True))

    def _get_status(self, request: Request) -> JSONResponse:
        """
        Status endpoint that returns basic server information.
        This endpoint can be used by the frontend to check if the A2A server is running.
        """
        uptime_seconds = time.time() - self.start_time

        status_info: Dict[str, Any] = {
            "status": "ok",
            "service": "mindsdb-a2a",
            "uptime_seconds": round(uptime_seconds, 2),
            "agent_name": self.agent_card.name if self.agent_card else None,
            "version": self.agent_card.version if self.agent_card else "unknown",
        }

        return JSONResponse(status_info)

    async def _process_request(self, request: Request) -> JSONResponse | EventSourceResponse:
        """Validate a JSON-RPC request and dispatch it to the task manager.

        The request body is parsed as JSON and validated with
        ``A2ARequest.validate_python``; the concrete request type selects the
        task-manager handler. Selected request headers are forwarded to the
        handlers that accept a ``user_info`` argument.

        Any exception raised while parsing, validating, dispatching or
        building the response is converted into a JSON-RPC error response by
        ``_handle_exception`` rather than propagated.
        """
        try:
            body = await request.json()
            json_rpc_request = A2ARequest.validate_python(body)

            # Caller identity/credentials taken verbatim from the HTTP headers;
            # missing headers are passed along as None.
            user_info = {
                "user-id": request.headers.get("user-id", None),
                "company-id": request.headers.get("company-id", None),
                "user-class": request.headers.get("user-class", None),
                "authorization": request.headers.get("Authorization", None),
            }

            if isinstance(json_rpc_request, GetTaskRequest):
                result = await self.task_manager.on_get_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskRequest):
                result = await self.task_manager.on_send_task(json_rpc_request, user_info)
            elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                # Don't await the async generator, just pass it to _create_response
                result = self.task_manager.on_send_task_subscribe(json_rpc_request, user_info)
            elif isinstance(json_rpc_request, CancelTaskRequest):
                result = await self.task_manager.on_cancel_task(json_rpc_request)
            elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                result = await self.task_manager.on_resubscribe_to_task(json_rpc_request)
            elif isinstance(json_rpc_request, MessageStreamRequest):
                result = await self.task_manager.on_message_stream(json_rpc_request, user_info)
            else:
                logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
                # Report the type of the parsed JSON-RPC request (not the
                # Starlette Request object) so the error matches the log line.
                raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")

            return self._create_response(result)

        except Exception as e:
            return self._handle_exception(e)

    def _handle_exception(self, e: Exception) -> JSONResponse:
        """Map an exception to a JSON-RPC error response.

        * ``json.decoder.JSONDecodeError`` -> ``JSONParseError``
        * pydantic ``ValidationError`` -> ``InvalidRequestError`` carrying the
          validation details as ``data``
        * anything else -> logged with traceback and reported as
          ``InternalError``

        The response always has ``id=None`` and HTTP status 400.
        """
        if isinstance(e, json.decoder.JSONDecodeError):
            json_rpc_error = JSONParseError()
        elif isinstance(e, ValidationError):
            # Round-trip through e.json() to obtain plain JSON-compatible data.
            json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
        else:
            logger.exception("Unhandled exception:")
            json_rpc_error = InternalError()

        response = JSONRPCResponse(id=None, error=json_rpc_error)
        return JSONResponse(response.model_dump(exclude_none=True), status_code=400)

    def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
        """Wrap a task-manager result in the matching HTTP response.

        Args:
            result: An async iterable of events (streamed as server-sent
                events), a ``JSONRPCResponse``, or a plain ``dict``.

        Returns:
            An ``EventSourceResponse`` for async iterables, otherwise a
            ``JSONResponse``.

        Raises:
            ValueError: If ``result`` is none of the supported types.
        """
        if isinstance(result, AsyncIterable):
            # Step 2: Yield actual serialized event as JSON, with timing logs
            async def event_generator(result):
                async for item in result:
                    t0 = time.time()
                    logger.debug(f"[A2AServer] STEP2 serializing item at {t0}: {str(item)[:120]}")
                    try:
                        if hasattr(item, "model_dump_json"):
                            data = item.model_dump_json(exclude_none=True)
                        else:
                            data = json.dumps(item)
                    except Exception as e:
                        # Keep the stream alive: emit an error event instead of
                        # aborting the whole response on one bad item.
                        logger.exception("Serialization error in SSE stream:")
                        data = json.dumps({"error": f"Serialization error: {e}"})
                    yield {"data": data}

            # Add robust SSE headers for compatibility
            sse_headers = {
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache, no-transform",
                "X-Accel-Buffering": "no",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            }
            return EventSourceResponse(event_generator(result), headers=sse_headers)
        elif isinstance(result, JSONRPCResponse):
            return JSONResponse(result.model_dump(exclude_none=True))
        elif isinstance(result, dict):
            logger.warning("Falling back to JSONResponse for result type: dict")
            return JSONResponse(result)
        else:
            logger.error(f"Unexpected result type: {type(result)}")
            raise ValueError(f"Unexpected result type: {type(result)}")