Coverage for mindsdb / api / a2a / common / server / server.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2import time 

3from typing import AsyncIterable, Any, Dict 

4 

5from starlette.applications import Starlette 

6from starlette.middleware.cors import CORSMiddleware 

7from starlette.responses import JSONResponse 

8from sse_starlette.sse import EventSourceResponse 

9from starlette.requests import Request 

10from starlette.routing import Route 

11from ...common.types import ( 

12 A2ARequest, 

13 JSONRPCResponse, 

14 InvalidRequestError, 

15 JSONParseError, 

16 GetTaskRequest, 

17 CancelTaskRequest, 

18 SendTaskRequest, 

19 SetTaskPushNotificationRequest, 

20 GetTaskPushNotificationRequest, 

21 InternalError, 

22 AgentCard, 

23 TaskResubscriptionRequest, 

24 SendTaskStreamingRequest, 

25 MessageStreamRequest, 

26) 

27from pydantic import ValidationError 

28from ...common.server.task_manager import TaskManager 

29 

30from mindsdb.utilities import log 

31 

32logger = log.getLogger(__name__) 

33 

34 

class A2AServer:
    """Starlette application exposing the A2A JSON-RPC endpoint plus metadata routes.

    Routes registered on ``self.app``:

    * ``POST /`` -- JSON-RPC entry point, dispatched to the task manager.
    * ``GET /.well-known/agent.json`` and ``GET /.well-known/agent-card.json``
      -- the agent card as JSON.
    * ``GET /status`` -- basic liveness/uptime information.
    """

    def __init__(
        self,
        agent_card: AgentCard | None = None,
        task_manager: TaskManager | None = None,
    ):
        """Build the Starlette app, attach CORS middleware and record the start time.

        Args:
            agent_card: Card served by the ``/.well-known`` routes and summarised
                by ``/status``. May be ``None``.
            task_manager: Object whose ``on_*`` handlers serve JSON-RPC requests.
        """
        self.task_manager = task_manager
        self.agent_card = agent_card
        self.app = Starlette(
            routes=[
                Route("/", self._process_request, methods=["POST"]),
                Route("/.well-known/agent.json", self._get_agent_card, methods=["GET"]),
                Route("/.well-known/agent-card.json", self._get_agent_card, methods=["GET"]),
                Route("/status", self._get_status, methods=["GET"]),
            ]
        )
        # TODO: Remove this when we have a proper CORS policy
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # Reference point for the uptime reported by /status.
        self.start_time = time.time()

    def _get_agent_card(self, request: Request) -> JSONResponse:
        """Return the agent card as JSON, or a 404 when no card was configured."""
        if self.agent_card is None:
            # The constructor allows agent_card=None; answer explicitly instead
            # of failing with AttributeError on the attribute access below.
            return JSONResponse({"error": "Agent card is not configured"}, status_code=404)
        return JSONResponse(self.agent_card.model_dump(exclude_none=True))

    def _get_status(self, request: Request) -> JSONResponse:
        """
        Status endpoint that returns basic server information.
        This endpoint can be used by the frontend to check if the A2A server is running.
        """
        uptime_seconds = time.time() - self.start_time

        status_info: Dict[str, Any] = {
            "status": "ok",
            "service": "mindsdb-a2a",
            "uptime_seconds": round(uptime_seconds, 2),
            "agent_name": self.agent_card.name if self.agent_card else None,
            "version": self.agent_card.version if self.agent_card else "unknown",
        }

        return JSONResponse(status_info)

    async def _process_request(self, request: Request):
        """Parse a JSON-RPC request, dispatch it to the task manager and build the reply.

        Any exception raised while parsing, validating, dispatching or building
        the response is converted into a JSON-RPC error response by
        ``_handle_exception``.
        """
        try:
            body = await request.json()
            json_rpc_request = A2ARequest.validate_python(body)

            # Caller identity headers forwarded to the handlers that accept them.
            user_info = {
                "user-id": request.headers.get("user-id", None),
                "company-id": request.headers.get("company-id", None),
                "user-class": request.headers.get("user-class", None),
                "authorization": request.headers.get("Authorization", None),
            }

            if isinstance(json_rpc_request, GetTaskRequest):
                result = await self.task_manager.on_get_task(json_rpc_request)
            elif isinstance(json_rpc_request, SendTaskRequest):
                result = await self.task_manager.on_send_task(json_rpc_request, user_info)
            elif isinstance(json_rpc_request, SendTaskStreamingRequest):
                # Don't await the async generator, just pass it to _create_response
                result = self.task_manager.on_send_task_subscribe(json_rpc_request, user_info)
            elif isinstance(json_rpc_request, CancelTaskRequest):
                result = await self.task_manager.on_cancel_task(json_rpc_request)
            elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
                result = await self.task_manager.on_set_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
                result = await self.task_manager.on_get_task_push_notification(json_rpc_request)
            elif isinstance(json_rpc_request, TaskResubscriptionRequest):
                # NOTE(review): awaited here, unlike on_send_task_subscribe above --
                # confirm this handler is a coroutine and not an async generator.
                result = await self.task_manager.on_resubscribe_to_task(json_rpc_request)
            elif isinstance(json_rpc_request, MessageStreamRequest):
                # NOTE(review): same await question as the resubscribe branch -- confirm.
                result = await self.task_manager.on_message_stream(json_rpc_request, user_info)
            else:
                logger.warning(f"Unexpected request type: {type(json_rpc_request)}")
                # Report the parsed JSON-RPC object's type, not the HTTP request's.
                raise ValueError(f"Unexpected request type: {type(json_rpc_request)}")

            return self._create_response(result)

        except Exception as e:
            return self._handle_exception(e)

    def _handle_exception(self, e: Exception) -> JSONResponse:
        """Map an exception to a JSON-RPC error response.

        * ``json.JSONDecodeError`` -> ``JSONParseError``
        * pydantic ``ValidationError`` -> ``InvalidRequestError`` carrying the
          validation details
        * anything else -> ``InternalError`` (logged with its traceback)

        The response always uses ``id=None`` and HTTP status 400.
        """
        if isinstance(e, json.decoder.JSONDecodeError):
            json_rpc_error = JSONParseError()
        elif isinstance(e, ValidationError):
            json_rpc_error = InvalidRequestError(data=json.loads(e.json()))
        else:
            # Pass the exception explicitly so the traceback is logged even when
            # this method is not called from inside an ``except`` block.
            logger.error("Unhandled exception:", exc_info=e)
            json_rpc_error = InternalError()

        response = JSONRPCResponse(id=None, error=json_rpc_error)
        return JSONResponse(response.model_dump(exclude_none=True), status_code=400)

    def _create_response(self, result: Any) -> JSONResponse | EventSourceResponse:
        """Wrap a handler result in the matching HTTP response.

        * async iterable -> server-sent-event stream, one JSON payload per item
        * ``JSONRPCResponse`` -> JSON body with ``None`` fields dropped
        * ``dict`` -> JSON body as-is (logged as a fallback)

        Raises:
            ValueError: If ``result`` is none of the above.
        """
        if isinstance(result, AsyncIterable):

            async def event_generator(result):
                """Serialize each streamed item to JSON and yield it as an SSE data event."""
                async for item in result:
                    t0 = time.time()
                    logger.debug(f"[A2AServer] STEP2 serializing item at {t0}: {str(item)[:120]}")
                    try:
                        if hasattr(item, "model_dump_json"):
                            data = item.model_dump_json(exclude_none=True)
                        else:
                            data = json.dumps(item)
                    except Exception as e:
                        # Keep the stream alive: emit an error payload for this
                        # item instead of aborting the whole response.
                        logger.exception("Serialization error in SSE stream:")
                        data = json.dumps({"error": f"Serialization error: {e}"})
                    yield {"data": data}

            # Add robust SSE headers for compatibility
            sse_headers = {
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache, no-transform",
                "X-Accel-Buffering": "no",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            }
            return EventSourceResponse(event_generator(result), headers=sse_headers)
        elif isinstance(result, JSONRPCResponse):
            return JSONResponse(result.model_dump(exclude_none=True))
        elif isinstance(result, dict):
            logger.warning("Falling back to JSONResponse for result type: dict")
            return JSONResponse(result)
        else:
            logger.error(f"Unexpected result type: {type(result)}")
            raise ValueError(f"Unexpected result type: {type(result)}")