Coverage for mindsdb / api / a2a / task_manager.py: 0%

270 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import time 

2import logging 

3import asyncio 

4from typing import AsyncIterable, Dict, Union 

5 

6from mindsdb.api.a2a.common.types import ( 

7 SendTaskRequest, 

8 TaskSendParams, 

9 Message, 

10 TaskStatus, 

11 Artifact, 

12 TaskStatusUpdateEvent, 

13 TaskArtifactUpdateEvent, 

14 TaskState, 

15 Task, 

16 SendTaskResponse, 

17 InternalError, 

18 JSONRPCResponse, 

19 SendTaskStreamingRequest, 

20 SendTaskStreamingResponse, 

21 InvalidRequestError, 

22 MessageStreamRequest, 

23 SendStreamingMessageSuccessResponse, 

24) 

25from mindsdb.api.a2a.common.server.task_manager import InMemoryTaskManager 

26from mindsdb.api.a2a.agent import MindsDBAgent 

27from mindsdb.api.a2a.utils import to_serializable, convert_a2a_message_to_qa_format 

28from mindsdb.interfaces.agents.agents_controller import AgentsController 

29 

30 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

32 

33 

def to_question_format(messages):
    """Flatten A2A messages into ``{"question": ...}`` dicts.

    A message that already has a ``"question"`` key is passed through as-is.
    A message whose ``"parts"`` value is a list contributes one
    ``{"question": text}`` entry per part that, after ``to_serializable``,
    has ``type == "text"`` and a ``"text"`` key. All other messages are skipped.

    Args:
        messages: Iterable of message mappings.

    Returns:
        A new list of question dicts, in input order.
    """
    questions = []
    for message in messages:
        if "question" in message:
            questions.append(message)
            continue
        parts = message["parts"] if "parts" in message else None
        if not isinstance(parts, list):
            continue
        for raw_part in parts:
            part = to_serializable(raw_part)
            if part.get("type") == "text" and "text" in part:
                questions.append({"question": part["text"]})
    return questions

46 

47 

class AgentTaskManager(InMemoryTaskManager):
    """A2A task manager that runs tasks against MindsDB agents.

    Tasks live in the in-memory ``self.tasks`` dict; writes to it go through
    ``upsert_task`` / ``_update_store``, which hold ``self.lock``. The agent for
    each request is named in the message metadata (``agent_name`` or
    ``agentName``) and is resolved within ``project_name``.
    """

    def __init__(
        self,
        project_name: str,
        agent_name: str = None,
    ):
        """Set up an empty task store.

        Args:
            project_name: MindsDB project in which agents are resolved.
            agent_name: Stored as ``self.agent_name``; the request handlers in
                this class take the agent name from message metadata instead.
        """
        super().__init__()
        self.project_name = project_name
        self.agent_name = agent_name
        self.tasks = {}  # Task storage, keyed by task id
        self.lock = asyncio.Lock()  # Guards task-store mutations

    @staticmethod
    def _to_dict(obj):
        """Return a plain-dict view of a pydantic model, or ``obj`` unchanged.

        Prefers ``model_dump()`` (pydantic v2) and falls back to the deprecated
        ``dict()`` so both model generations are handled consistently.
        """
        if hasattr(obj, "model_dump"):
            return obj.model_dump()
        if hasattr(obj, "dict"):
            return obj.dict()
        return obj

    def _create_agent(self, user_info: Dict, agent_name: str = None) -> MindsDBAgent:
        """Create a new MindsDBAgent for ``agent_name`` in this manager's project.

        Raises:
            ValueError: If ``agent_name`` is empty or ``None``.
        """
        if not agent_name:
            raise ValueError("Agent name is required but was not provided in the request")

        return MindsDBAgent(
            agent_name=agent_name,
            project_name=self.project_name,
            user_info=user_info,
        )

    @staticmethod
    def _stream_error_result(request_id, task_id, error_text: str, error_type: str) -> dict:
        """Build the error payload yielded by the streaming path for ``error_text``."""
        parts = [to_serializable({"type": "text", "text": error_text})]
        return {
            "id": request_id,
            "error": {
                "id": task_id,
                "artifact": {"parts": parts, "index": 0, "append": False},
                "error_type": error_type,
            },
        }

    @staticmethod
    def _classify_stream_error(exc: Exception) -> tuple[str, str]:
        """Map an exception to ``(user-facing text, error category)`` via substring heuristics."""
        message = str(exc)
        lowered = message.lower()
        if "API key" in message or "authentication" in lowered:
            return f"Authentication error: {message}", "authentication"
        if "404" in message or "not found" in lowered:
            return f"Resource not found: {message}", "not_found"
        if "rate limit" in lowered or "429" in message:
            return f"Rate limit exceeded: {message}", "rate_limit"
        return f"An error occurred while streaming the response: {message}", "general"

    async def _stream_generator(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Run the agent for a streaming task request and yield serialised events.

        The task is upserted first so later store updates find it. If the
        message metadata sets ``streaming`` to a falsy value, the agent is
        invoked once (see ``_respond_once``); otherwise the agent's streaming
        events are relayed (see ``_relay_agent_stream``). Task-creation and
        agent failures are yielded as error payloads; an invalid agent name
        raises ``ValueError`` from ``_create_agent``.
        """
        task_send_params: TaskSendParams = request.params
        query = self._get_user_query(task_send_params)
        params = self._get_task_params(task_send_params)
        agent_name = params["agent_name"]
        streaming = params["streaming"]

        # Create and store the task first to ensure it exists
        try:
            task = await self.upsert_task(task_send_params)
            logger.info(f"Task created/updated with history length: {len(task.history) if task.history else 0}")
        except Exception as e:
            logger.exception("Error creating task:")
            yield to_serializable(
                {
                    "id": request.id,
                    "error": to_serializable(InternalError(message=f"Error creating task: {e}")),
                }
            )
            return

        agent = self._create_agent(user_info, agent_name)

        # History as stored on the task by upsert_task.
        history = task.history if task and task.history else []

        if not streaming:
            async for event in self._respond_once(request, agent, query, history):
                yield event
            return

        async for event in self._relay_agent_stream(request, agent, history):
            yield event

    async def _respond_once(self, request: SendTaskStreamingRequest, agent: MindsDBAgent, query: str, history: list):
        """Invoke the agent once, then yield an artifact event and a final status event.

        On failure, a single error payload carrying an ``InternalError`` is
        yielded instead (same shape as the task-creation error payload).
        """
        task_send_params: TaskSendParams = request.params
        try:
            result = agent.invoke(query, task_send_params.sessionId, history=history)

            # Use the agent's own parts when present, otherwise wrap its text content.
            if "parts" in result:
                parts = list(result["parts"])  # copy so the agent's result is not mutated
            else:
                result_text = result.get("content", "No response from MindsDB")
                parts = [{"type": "text", "text": result_text}]

            # Attach structured data, if any, as an extra JSON data part.
            if "data" in result and result["data"]:
                parts.append(
                    {
                        "type": "data",
                        "data": result["data"],
                        "metadata": {"subtype": "json"},
                    }
                )

            artifact = Artifact(parts=parts, index=0, append=False)
            task_status = TaskStatus(state=TaskState.COMPLETED)
            await self._update_store(task_send_params.id, task_status, [artifact])

            yield to_serializable(
                SendTaskStreamingResponse(
                    id=request.id,
                    result=to_serializable(TaskArtifactUpdateEvent(id=task_send_params.id, artifact=artifact)),
                )
            )
            yield to_serializable(
                SendTaskStreamingResponse(
                    id=request.id,
                    result=to_serializable(
                        TaskStatusUpdateEvent(
                            id=task_send_params.id,
                            status=to_serializable(TaskStatus(state=task_status.state)),
                            final=True,
                        )
                    ),
                )
            )
        except Exception as e:
            logger.exception("Error invoking agent:")
            # The "error" field carries the error object itself, not a nested
            # JSON-RPC response envelope.
            yield to_serializable(
                {
                    "id": request.id,
                    "error": to_serializable(InternalError(message=f"Error invoking agent: {e}")),
                }
            )

    async def _relay_agent_stream(self, request: SendTaskStreamingRequest, agent: MindsDBAgent, history: list):
        """Relay ``agent.streaming_invoke`` events, turning failures into error payloads."""
        task_send_params: TaskSendParams = request.params
        try:
            logger.debug(f"Entering agent.stream() at {time.time()}")
            # Build the A2A message (with history) and convert it to Q&A format.
            a2a_message = task_send_params.message.model_dump()
            logger.debug(f"History: {history}")
            if history:
                a2a_message["history"] = [self._to_dict(msg) for msg in history]

            all_messages = convert_a2a_message_to_qa_format(a2a_message)

            async for item in agent.streaming_invoke(all_messages, timeout=60):
                # Normalise artifact parts to plain dicts before serialising.
                if isinstance(item, dict) and "artifact" in item and "parts" in item["artifact"]:
                    item["artifact"]["parts"] = [to_serializable(p) for p in item["artifact"]["parts"]]
                yield to_serializable(item)
        except (TimeoutError, asyncio.TimeoutError) as e:
            # asyncio.TimeoutError is only an alias of TimeoutError from Python 3.11.
            logger.error(f"Timeout error while streaming the response: {e}")
            yield self._stream_error_result(
                request.id,
                task_send_params.id,
                "The request timed out. The agent is taking longer than expected to respond. Please try again or increase the timeout.",
                "timeout",
            )
        except ConnectionError as e:
            logger.error(f"Connection error while streaming the response: {e}")
            yield self._stream_error_result(
                request.id,
                task_send_params.id,
                "Failed to connect to the agent. Please check if the agent is running and accessible.",
                "connection",
            )
        except Exception as e:
            logger.exception("An error occurred while streaming the response:")
            error_text, error_category = self._classify_stream_error(e)
            yield self._stream_error_result(request.id, task_send_params.id, error_text, error_category)

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create or update a task in the task store.

        A new task is created in the SUBMITTED state with the history supplied
        by the request (``task_send_params.history``, falling back to
        ``task_send_params.message.history``); the current message is not added
        to it. For an existing task, request-level history (if any) replaces the
        stored history, then the current message is appended.

        Args:
            task_send_params: The parameters for the task.

        Returns:
            The created or updated task.
        """
        logger.info(f"Upserting task {task_send_params.id}")
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                # History may be provided at the top level or nested under the message.
                source_history = (
                    getattr(task_send_params, "history", None)
                    or getattr(task_send_params.message, "history", None)
                    or []
                )
                history = [self._to_dict(item) for item in source_history]

                # The current message is deliberately NOT added to history; it is
                # read separately from task_send_params.message.
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=history,
                    artifacts=[],
                )
                self.tasks[task_send_params.id] = task
            else:
                if task.history is None:
                    task.history = []

                # Request-level history, when present, replaces the stored history.
                request_history = getattr(task_send_params, "history", None)
                if request_history:
                    history = []
                    for item in request_history:
                        # Copy so the caller's dicts are not mutated by role injection.
                        item_dict = dict(self._to_dict(item))
                        if "role" not in item_dict:
                            item_dict["role"] = "assistant" if "answer" in item_dict else "user"
                        history.append(item_dict)
                    task.history = history

                # NOTE(review): unlike the new-task branch, the current message is
                # appended to history here, while _stream_generator also sends it
                # separately — confirm whether this duplication is intended.
                task.history.append(self._to_dict(task_send_params.message))
            return task

    def _validate_request(
        self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
    ) -> Union[None, JSONRPCResponse]:
        """Return a JSON-RPC error response if the request is invalid, else ``None``.

        A valid request has params, a message, non-empty message metadata, and
        an agent name under ``agent_name`` or ``agentName`` in that metadata.
        """
        if not hasattr(request, "params") or not request.params:
            return JSONRPCResponse(
                id=request.id,
                error=InvalidRequestError(message="Missing params"),
            )

        if not hasattr(request.params, "message") or not request.params.message:
            return JSONRPCResponse(
                id=request.id,
                error=InvalidRequestError(message="Missing message in params"),
            )

        if not hasattr(request.params.message, "metadata") or not request.params.message.metadata:
            return JSONRPCResponse(
                id=request.id,
                error=InvalidRequestError(message="Missing metadata in message"),
            )

        metadata = request.params.message.metadata
        agent_name = metadata.get("agent_name", metadata.get("agentName"))
        if not agent_name:
            return JSONRPCResponse(
                id=request.id,
                error=InvalidRequestError(
                    message="Agent name is required but was not provided in the request metadata"
                ),
            )

        return None

    async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
        """Validate and handle a ``tasks/send`` request; invalid requests get an error response."""
        error = self._validate_request(request)
        if error:
            return error

        return await self._invoke(request, user_info=user_info)

    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Validate a streaming request and relay events from ``_stream_generator``.

        Validation failures and any exception escaping the generator are
        yielded as error payloads rather than raised.
        """
        error = self._validate_request(request)
        if error:
            logger.info(f"Yielding error at {time.time()} for invalid request: {error}")
            yield to_serializable(SendTaskStreamingResponse(id=request.id, error=to_serializable(error.error)))
            return

        try:
            logger.debug(f"Entering streaming path at {time.time()}")
            async for response in self._stream_generator(request, user_info):
                logger.debug(f"Yielding streaming response at {time.time()} with: {str(response)[:120]}")
                yield response
        except Exception as e:
            logger.exception(f"Error in on_send_task_subscribe: {e}")
            yield to_serializable(
                {
                    "id": request.id,
                    "error": to_serializable(InternalError(message=f"Error processing streaming request: {e}")),
                }
            )

    async def _update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task:
        """Set a task's status and merge in new artifacts.

        If the task is unknown, a placeholder task (session id
        ``"recovery-session"``) is created so the update is not lost. An
        agent-role status message is appended to history with role
        ``"assistant"``. Artifacts flagged ``append`` extend the parts of the
        last stored artifact; others are added as new artifacts.
        """
        async with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                logger.error(f"Task {task_id} not found for updating the task")
                task = Task(
                    id=task_id,
                    sessionId="recovery-session",  # Placeholder: the real session is unknown here
                    status=status,
                    history=[],
                    artifacts=[],  # Explicit so artifacts can be appended below
                )
                self.tasks[task_id] = task

            task.status = status
            if task.artifacts is None:
                task.artifacts = []

            # Record the agent's reply in history under the "assistant" role.
            if status.message and status.message.role == "agent":
                if task.history is None:
                    task.history = []
                message_dict = self._to_dict(status.message)
                message_dict["role"] = "assistant"
                task.history.append(message_dict)

            for artifact in artifacts or []:
                if artifact.append and task.artifacts:
                    # Merge into the most recent artifact.
                    task.artifacts[-1].parts.extend(artifact.parts)
                else:
                    task.artifacts.append(artifact)
            return task

    def _get_user_query(self, task_send_params: TaskSendParams) -> str:
        """Return the text of the first text part of the message, or ``""`` if none."""
        message = task_send_params.message
        return next((part.text for part in (message.parts or []) if part.type == "text"), "")

    def _get_task_params(self, task_send_params: TaskSendParams) -> dict:
        """Extract agent name, streaming flag (default ``True``) and session id from the request."""
        metadata = task_send_params.message.metadata or {}
        # Accept both snake_case and camelCase agent name keys.
        agent_name = metadata.get("agent_name", metadata.get("agentName"))
        return {
            "agent_name": agent_name,
            "streaming": metadata.get("streaming", True),
            "session_id": task_send_params.sessionId,
        }

    async def _invoke(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
        """Handle a non-subscribing send request.

        The task is registered in the store if absent. When ``streaming`` is
        enabled in the metadata (the default), the task is only marked WORKING
        and returned. Otherwise the output of ``agent.stream`` is collected into
        one COMPLETED task. Errors mark the task FAILED with the error text as
        its message and artifact.
        """
        task_send_params: TaskSendParams = request.params
        query = self._get_user_query(task_send_params)
        params = self._get_task_params(task_send_params)
        agent = self._create_agent(user_info, params["agent_name"])

        try:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                # Register the task (as the streaming path does) so later store
                # updates keep its real sessionId instead of a recovery placeholder.
                task = await self.upsert_task(task_send_params)
            history = task.history if task and task.history else []

            if params["streaming"]:
                # Streaming mode: just acknowledge; results come via the streaming endpoint.
                task = await self._update_store(task_send_params.id, TaskStatus(state=TaskState.WORKING), [])
                return to_serializable(SendTaskResponse(id=request.id, result=task))

            # Non-streaming mode: collect all chunks into a single response.
            all_parts = []
            final_metadata = {}
            async for chunk in agent.stream(query, task_send_params.sessionId, history=history):
                if "parts" in chunk and chunk["parts"]:
                    all_parts.extend(chunk["parts"])
                elif "content" in chunk:
                    all_parts.append({"type": "text", "text": chunk["content"]})

                # Guard against chunks whose metadata is present but None.
                if chunk.get("metadata"):
                    final_metadata.update(chunk["metadata"])

            if not all_parts:
                all_parts = [{"type": "text", "text": "No response from MindsDB"}]

            task = await self._update_store(
                task_send_params.id,
                TaskStatus(
                    state=TaskState.COMPLETED,
                    message=Message(role="agent", parts=all_parts, metadata=final_metadata),
                ),
                [Artifact(parts=all_parts)],
            )
            return to_serializable(SendTaskResponse(id=request.id, result=task))
        except Exception as e:
            logger.exception("Error invoking agent:")
            parts = [{"type": "text", "text": f"Error invoking agent: {e}"}]
            task = await self._update_store(
                task_send_params.id,
                TaskStatus(state=TaskState.FAILED, message=Message(role="agent", parts=parts)),
                [Artifact(parts=parts)],
            )
            return to_serializable(SendTaskResponse(id=request.id, result=task))

    async def on_message_stream(
        self, request: MessageStreamRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendStreamingMessageSuccessResponse], JSONRPCResponse]:
        """Handle message streaming requests.

        The agent named in the message metadata is looked up in this manager's
        project and asked for a single completion. On success, an async
        generator yielding one final ``TaskStatusUpdateEvent`` is returned. On
        failure, a ``JSONRPCResponse`` carrying an ``InternalError`` is returned.
        """
        logger.info(f"Processing message stream request for session {request.params.sessionId}")

        try:
            query = self._get_user_query(request.params)
            params = self._get_task_params(request.params)
            agent_name = params["agent_name"]
            if not agent_name:
                raise ValueError("Agent name is required but was not provided in the request metadata")

            task_id = f"msg_stream_{request.params.sessionId}_{request.id}"
            context_id = f"ctx_{request.params.sessionId}"
            message_id = f"msg_{request.id}"

            agents_controller = AgentsController()
            # Resolve the agent in this manager's project, as _create_agent does.
            existing_agent = agents_controller.get_agent(agent_name, self.project_name)
            if existing_agent is None:
                raise ValueError(f"Agent '{agent_name}' not found in project '{self.project_name}'")
            resp = agents_controller.get_completion(existing_agent, [{"question": query}])
            answer_text = resp["answer"][0]

            response_message = Message(
                role="agent", parts=[{"type": "text", "text": answer_text}], metadata={}, messageId=message_id
            )

            task_status_update = TaskStatusUpdateEvent(
                id=task_id,
                status=TaskStatus(state=TaskState.COMPLETED, message=response_message),
                final=True,
                metadata={"message_stream": True},
                contextId=context_id,
                taskId=task_id,
            )

            async def message_stream_generator():
                yield to_serializable(SendStreamingMessageSuccessResponse(id=request.id, result=task_status_update))

            return message_stream_generator()

        except Exception as e:
            logger.exception(f"Error processing message stream: {e}")
            # An error is not a success response; use the generic JSON-RPC envelope.
            return JSONRPCResponse(
                id=request.id, error=InternalError(message=f"Error processing message stream: {str(e)}")
            )