Coverage for mindsdb / api / a2a / task_manager.py: 0%
270 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import time
2import logging
3import asyncio
4from typing import AsyncIterable, Dict, Union
6from mindsdb.api.a2a.common.types import (
7 SendTaskRequest,
8 TaskSendParams,
9 Message,
10 TaskStatus,
11 Artifact,
12 TaskStatusUpdateEvent,
13 TaskArtifactUpdateEvent,
14 TaskState,
15 Task,
16 SendTaskResponse,
17 InternalError,
18 JSONRPCResponse,
19 SendTaskStreamingRequest,
20 SendTaskStreamingResponse,
21 InvalidRequestError,
22 MessageStreamRequest,
23 SendStreamingMessageSuccessResponse,
24)
25from mindsdb.api.a2a.common.server.task_manager import InMemoryTaskManager
26from mindsdb.api.a2a.agent import MindsDBAgent
27from mindsdb.api.a2a.utils import to_serializable, convert_a2a_message_to_qa_format
28from mindsdb.interfaces.agents.agents_controller import AgentsController
31logger = logging.getLogger(__name__)
def to_question_format(messages):
    """Convert A2A messages to a list of {"question": ...} dicts for agent compatibility."""
    questions = []
    for message in messages:
        # Messages already in Q&A form pass through untouched.
        if "question" in message:
            questions.append(message)
            continue
        parts = message.get("parts")
        if not isinstance(parts, list):
            continue
        # Otherwise lift each text part into its own {"question": ...} entry.
        for raw_part in parts:
            part = to_serializable(raw_part)
            if part.get("type") == "text" and "text" in part:
                questions.append({"question": part["text"]})
    return questions
class AgentTaskManager(InMemoryTaskManager):
    """Task manager that routes A2A send-task requests to MindsDB agents.

    Keeps an in-memory map of Task objects keyed by task id and serializes
    mutations with an asyncio lock.
    """

    def __init__(
        self,
        project_name: str,
        agent_name: str = None,
    ):
        super().__init__()
        self.project_name = project_name
        # Default agent name; requests may still override it via message metadata.
        self.agent_name = agent_name
        # NOTE(review): these may shadow attributes already initialized by
        # InMemoryTaskManager's __init__ — confirm against the base class.
        self.tasks = {}  # Task storage, keyed by task id
        self.lock = asyncio.Lock()  # Guards reads/writes of self.tasks
60 def _create_agent(self, user_info: Dict, agent_name: str = None) -> MindsDBAgent:
61 """Create a new MindsDBAgent instance for the given agent name."""
62 if not agent_name:
63 raise ValueError("Agent name is required but was not provided in the request")
65 return MindsDBAgent(
66 agent_name=agent_name,
67 project_name=self.project_name,
68 user_info=user_info,
69 )
    async def _stream_generator(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Yield serialized streaming events for a sendTaskSubscribe request.

        First upserts the task, then either:
        - streaming disabled: performs one blocking ``agent.invoke`` and yields
          a single artifact event followed by a final status event; or
        - streaming enabled (default): forwards events from
          ``agent.streaming_invoke`` as they arrive.

        Errors are yielded as plain serializable payloads instead of being
        raised, so the transport layer can forward them to the client.
        """
        task_send_params: TaskSendParams = request.params
        query = self._get_user_query(task_send_params)
        params = self._get_task_params(task_send_params)
        agent_name = params["agent_name"]
        streaming = params["streaming"]

        # Create and store the task first to ensure it exists
        try:
            task = await self.upsert_task(task_send_params)
            logger.info(f"Task created/updated with history length: {len(task.history) if task.history else 0}")
        except Exception as e:
            logger.exception("Error creating task:")
            error_result = to_serializable(
                {
                    "id": request.id,
                    "error": to_serializable(InternalError(message=f"Error creating task: {e}")),
                }
            )
            yield error_result
            return  # Early return from generator

        agent = self._create_agent(user_info, agent_name)

        # Get the history from the task object (where it was properly extracted and stored)
        history = task.history if task and task.history else []

        if not streaming:
            # Streaming disabled: do a single blocking invoke and emit the
            # result as one artifact event plus a final status event.
            try:
                result = agent.invoke(query, task_send_params.sessionId, history=history)

                # Use the parts from the agent response if available, or create them
                if "parts" in result:
                    parts = result["parts"]
                else:
                    result_text = result.get("content", "No response from MindsDB")
                    parts = [{"type": "text", "text": result_text}]

                # Attach structured data (if any) as an extra JSON part
                if "data" in result and result["data"]:
                    parts.append(
                        {
                            "type": "data",
                            "data": result["data"],
                            "metadata": {"subtype": "json"},
                        }
                    )

                # Create and yield the final response
                task_state = TaskState.COMPLETED
                artifact = Artifact(parts=parts, index=0, append=False)
                task_status = TaskStatus(state=task_state)

                # Update the task store
                await self._update_store(task_send_params.id, task_status, [artifact])

                # Yield the artifact update
                yield to_serializable(
                    SendTaskStreamingResponse(
                        id=request.id,
                        result=to_serializable(TaskArtifactUpdateEvent(id=task_send_params.id, artifact=artifact)),
                    )
                )

                # Yield the final status update (final=True terminates the stream client-side)
                yield to_serializable(
                    SendTaskStreamingResponse(
                        id=request.id,
                        result=to_serializable(
                            TaskStatusUpdateEvent(
                                id=task_send_params.id,
                                status=to_serializable(TaskStatus(state=task_status.state)),
                                final=True,
                            )
                        ),
                    )
                )
                return

            except Exception as e:
                logger.exception("Error invoking agent:")
                error_result = to_serializable(
                    {
                        "id": request.id,
                        "error": to_serializable(
                            JSONRPCResponse(
                                id=request.id,
                                error=to_serializable(InternalError(message=f"Error invoking agent: {e}")),
                            )
                        ),
                    }
                )
                yield error_result
                return

        # If streaming is enabled (default), use the streaming implementation
        try:
            logger.debug(f"Entering agent.stream() at {time.time()}")
            # Create A2A message structure and convert using centralized utility
            a2a_message = task_send_params.message.model_dump()
            logger.debug(f"History: {history}")
            if history:
                a2a_message["history"] = [msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in history]

            # Convert to Q&A format using centralized utility function
            all_messages = convert_a2a_message_to_qa_format(a2a_message)

            async for item in agent.streaming_invoke(all_messages, timeout=60):
                # Make sure artifact parts are plain dicts before serializing the event
                if isinstance(item, dict) and "artifact" in item and "parts" in item["artifact"]:
                    item["artifact"]["parts"] = [to_serializable(p) for p in item["artifact"]["parts"]]
                yield to_serializable(item)
        except TimeoutError as e:
            logger.error(f"Timeout error while streaming the response: {e}")
            error_text = "The request timed out. The agent is taking longer than expected to respond. Please try again or increase the timeout."
            parts = [{"type": "text", "text": error_text}]
            parts = [to_serializable(part) for part in parts]
            artifact = {
                "parts": parts,
                "index": 0,
                "append": False,
            }
            error_result = {
                "id": request.id,
                "error": {
                    "id": task_send_params.id,
                    "artifact": artifact,
                    "error_type": "timeout",
                },
            }
            yield error_result
        except ConnectionError as e:
            logger.error(f"Connection error while streaming the response: {e}")
            error_text = "Failed to connect to the agent. Please check if the agent is running and accessible."
            parts = [{"type": "text", "text": error_text}]
            parts = [to_serializable(part) for part in parts]
            artifact = {
                "parts": parts,
                "index": 0,
                "append": False,
            }
            error_result = {
                "id": request.id,
                "error": {
                    "id": task_send_params.id,
                    "artifact": artifact,
                    "error_type": "connection",
                },
            }
            yield error_result
        except Exception as e:
            logger.exception("An error occurred while streaming the response:")
            # Classify the error by inspecting its text to give the client a
            # more actionable message and an error_type tag.
            if "API key" in str(e) or "authentication" in str(e).lower():
                error_text = f"Authentication error: {str(e)}"
                error_category = "authentication"
            elif "404" in str(e) or "not found" in str(e).lower():
                error_text = f"Resource not found: {str(e)}"
                error_category = "not_found"
            elif "rate limit" in str(e).lower() or "429" in str(e):
                error_text = f"Rate limit exceeded: {str(e)}"
                error_category = "rate_limit"
            else:
                error_text = f"An error occurred while streaming the response: {str(e)}"
                error_category = "general"

            # Ensure all parts are plain dicts
            parts = [{"type": "text", "text": error_text}]
            parts = [to_serializable(part) for part in parts]
            artifact = {
                "parts": parts,
                "index": 0,
                "append": False,
            }
            error_result = {
                "id": request.id,
                "error": {
                    "id": task_send_params.id,
                    "artifact": artifact,
                    "error_type": error_category,
                },
            }
            yield error_result
258 async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
259 """Create or update a task in the task store.
261 Args:
262 task_send_params: The parameters for the task.
264 Returns:
265 The created or updated task.
266 """
267 logger.info(f"Upserting task {task_send_params.id}")
268 async with self.lock:
269 task = self.tasks.get(task_send_params.id)
270 if task is None:
271 # Convert the message to a dict if it's not already one
272 message = task_send_params.message
273 message_dict = message.dict() if hasattr(message, "dict") else message
275 # Get history from request if available - check both locations
276 history = []
278 # First check if history is at top level (task_send_params.history)
279 if hasattr(task_send_params, "history") and task_send_params.history:
280 # Convert each history item to dict if needed
281 for item in task_send_params.history:
282 item_dict = item.model_dump() if hasattr(item, "model_dump") else item
283 history.append(item_dict)
284 # Also check if history is nested under message (message.history)
285 elif hasattr(task_send_params.message, "history") and task_send_params.message.history:
286 for item in task_send_params.message.history:
287 item_dict = item.model_dump() if hasattr(item, "model_dump") else item
288 history.append(item_dict)
290 # DO NOT add current message to history - it should be processed separately
291 # The current message will be extracted during streaming from task_send_params.message
293 # Create a new task
294 task = Task(
295 id=task_send_params.id,
296 sessionId=task_send_params.sessionId,
297 status=TaskStatus(state=TaskState.SUBMITTED),
298 history=history,
299 artifacts=[],
300 )
301 self.tasks[task_send_params.id] = task
302 else:
303 # Convert the message to a dict if it's not already one
304 message = task_send_params.message
305 message_dict = message.dict() if hasattr(message, "dict") else message
307 # Update the existing task
308 if task.history is None:
309 task.history = []
311 # If we have new history from the request, use it
312 if hasattr(task_send_params, "history") and task_send_params.history:
313 # Convert each history item to dict if needed and ensure proper role
314 history = []
315 for item in task_send_params.history:
316 item_dict = item.dict() if hasattr(item, "dict") else item
317 # Ensure the role is properly set
318 if "role" not in item_dict:
319 item_dict["role"] = "assistant" if "answer" in item_dict else "user"
320 history.append(item_dict)
321 task.history = history
323 # Add current message to history
324 task.history.append(message_dict)
325 return task
327 def _validate_request(
328 self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
329 ) -> Union[None, JSONRPCResponse]:
330 """Validate the request and return an error response if invalid."""
331 # Check if the request has the required parameters
332 if not hasattr(request, "params") or not request.params:
333 return JSONRPCResponse(
334 id=request.id,
335 error=InvalidRequestError(message="Missing params"),
336 )
338 # Check if the request has a message
339 if not hasattr(request.params, "message") or not request.params.message:
340 return JSONRPCResponse(
341 id=request.id,
342 error=InvalidRequestError(message="Missing message in params"),
343 )
345 # Check if the message has metadata
346 if not hasattr(request.params.message, "metadata") or not request.params.message.metadata:
347 return JSONRPCResponse(
348 id=request.id,
349 error=InvalidRequestError(message="Missing metadata in message"),
350 )
352 # Check if the agent name is provided in the metadata
353 metadata = request.params.message.metadata
354 agent_name = metadata.get("agent_name", metadata.get("agentName"))
355 if not agent_name:
356 return JSONRPCResponse(
357 id=request.id,
358 error=InvalidRequestError(
359 message="Agent name is required but was not provided in the request metadata"
360 ),
361 )
363 return None
365 async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
366 error = self._validate_request(request)
367 if error:
368 return error
370 return await self._invoke(request, user_info=user_info)
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Validate and stream responses for a sendTaskSubscribe request.

        Invalid requests yield a single serialized error response and stop.
        Otherwise events from _stream_generator are forwarded as-is; any
        unexpected failure is converted into a serialized InternalError
        payload instead of propagating out of the generator.
        """
        error = self._validate_request(request)
        if error:
            logger.info(f"Yielding error at {time.time()} for invalid request: {error}")
            yield to_serializable(SendTaskStreamingResponse(id=request.id, error=to_serializable(error.error)))
            return

        # We can't await an async generator directly, so we need to use it as is
        try:
            logger.debug(f"Entering streaming path at {time.time()}")
            async for response in self._stream_generator(request, user_info):
                logger.debug(f"Yielding streaming response at {time.time()} with: {str(response)[:120]}")
                yield response
        except Exception as e:
            # If an error occurs, yield an error response
            logger.exception(f"Error in on_send_task_subscribe: {e}")
            error_result = to_serializable(
                {
                    "id": request.id,
                    "error": to_serializable(InternalError(message=f"Error processing streaming request: {e}")),
                }
            )
            yield error_result
    async def _update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task:
        """Update a stored task's status, history and artifacts under the lock.

        If the task id is unknown, a placeholder task is created instead of
        failing so that late or out-of-order updates are not lost.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error(f"Task {task_id} not found for updating the task")
                # Create a new task with the provided ID if it doesn't exist
                # This ensures we don't fail when a task is not found
                # NOTE(review): this constructor passes `messages=[]` while
                # upsert_task passes `artifacts=[]` - confirm the Task model
                # accepts both, otherwise this recovery path (or the artifact
                # append below) may raise.
                task = Task(
                    id=task_id,
                    sessionId="recovery-session",  # Use a placeholder session ID
                    messages=[],  # No messages available
                    status=status,  # Use the provided status
                    history=[],  # No history available
                )
                self.tasks[task_id] = task

            task.status = status

            # Store assistant's response in history if we have a message
            if status.message and status.message.role == "agent":
                if task.history is None:
                    task.history = []
                # Convert message to dict if needed
                message_dict = status.message.dict() if hasattr(status.message, "dict") else status.message
                # History uses user/assistant roles while A2A messages use "agent",
                # so rewrite the role before storing
                message_dict["role"] = "assistant"
                task.history.append(message_dict)

            if artifacts is not None:
                for artifact in artifacts:
                    if artifact.append and len(task.artifacts) > 0:
                        # Append parts to the last artifact
                        last_artifact = task.artifacts[-1]
                        for part in artifact.parts:
                            last_artifact.parts.append(part)
                    else:
                        # Add as a new artifact
                        task.artifacts.append(artifact)
            return task
439 def _get_user_query(self, task_send_params: TaskSendParams) -> str:
440 """Extract the user query from the task parameters."""
441 message = task_send_params.message
442 if not message.parts:
443 return ""
445 # Find the first text part
446 for part in message.parts:
447 if part.type == "text":
448 return part.text
450 # If no text part found, return empty string
451 return ""
453 def _get_task_params(self, task_send_params: TaskSendParams) -> dict:
454 """Extract common parameters from task metadata."""
455 metadata = task_send_params.message.metadata or {}
456 # Check for both agent_name and agentName in the metadata
457 agent_name = metadata.get("agent_name", metadata.get("agentName"))
458 return {
459 "agent_name": agent_name,
460 "streaming": metadata.get("streaming", True),
461 "session_id": task_send_params.sessionId,
462 }
464 async def _invoke(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
465 task_send_params: TaskSendParams = request.params
466 query = self._get_user_query(task_send_params)
467 params = self._get_task_params(task_send_params)
468 agent_name = params["agent_name"]
469 streaming = params["streaming"]
470 agent = self._create_agent(user_info, agent_name)
472 try:
473 # Get the history from the task
474 task = self.tasks.get(task_send_params.id)
475 history = task.history if task and task.history else []
477 # Always use streaming internally, but handle the response differently based on the streaming parameter
478 all_parts = []
479 final_metadata = {}
481 # Create a streaming generator
482 stream_gen = agent.stream(query, task_send_params.sessionId, history=history)
484 if streaming:
485 # For streaming mode, we'll use the streaming endpoint instead
486 # Just create a minimal response to acknowledge the request
487 task_state = TaskState.WORKING
488 task = await self._update_store(task_send_params.id, TaskStatus(state=task_state), [])
489 return to_serializable(SendTaskResponse(id=request.id, result=task))
490 else:
491 # For non-streaming mode, collect all chunks into a single response
492 async for chunk in stream_gen:
493 # Extract parts if they exist
494 if "parts" in chunk and chunk["parts"]:
495 all_parts.extend(chunk["parts"])
496 elif "content" in chunk:
497 all_parts.append({"type": "text", "text": chunk["content"]})
499 # Extract metadata if it exists
500 if "metadata" in chunk:
501 final_metadata.update(chunk["metadata"])
503 # If we didn't get any parts, create a default part
504 if not all_parts:
505 all_parts = [{"type": "text", "text": "No response from MindsDB"}]
507 # Create the final response
508 task_state = TaskState.COMPLETED
509 task = await self._update_store(
510 task_send_params.id,
511 TaskStatus(
512 state=task_state,
513 message=Message(role="agent", parts=all_parts, metadata=final_metadata),
514 ),
515 [Artifact(parts=all_parts)],
516 )
517 return to_serializable(SendTaskResponse(id=request.id, result=task))
518 except Exception as e:
519 logger.exception("Error invoking agent:")
520 result_text = f"Error invoking agent: {e}"
521 parts = [{"type": "text", "text": result_text}]
523 task_state = TaskState.FAILED
524 task = await self._update_store(
525 task_send_params.id,
526 TaskStatus(state=task_state, message=Message(role="agent", parts=parts)),
527 [Artifact(parts=parts)],
528 )
529 return to_serializable(SendTaskResponse(id=request.id, result=task))
531 async def on_message_stream(
532 self, request: MessageStreamRequest, user_info: Dict
533 ) -> Union[AsyncIterable[SendStreamingMessageSuccessResponse], JSONRPCResponse]:
534 """
535 Handle message streaming requests.
536 """
537 logger.info(f"Processing message stream request for session {request.params.sessionId}")
539 query = self._get_user_query(request.params)
540 params = self._get_task_params(request.params)
542 try:
543 task_id = f"msg_stream_{request.params.sessionId}_{request.id}"
544 context_id = f"ctx_{request.params.sessionId}"
545 message_id = f"msg_{request.id}"
547 agents_controller = AgentsController()
548 existing_agent = agents_controller.get_agent(params["agent_name"])
549 resp = agents_controller.get_completion(existing_agent, [{"question": query}])
550 response_message = resp["answer"][0]
552 response_message = Message(
553 role="agent", parts=[{"type": "text", "text": response_message}], metadata={}, messageId=message_id
554 )
556 task_status = TaskStatus(state=TaskState.COMPLETED, message=response_message)
558 task_status_update = TaskStatusUpdateEvent(
559 id=task_id,
560 status=task_status,
561 final=True,
562 metadata={"message_stream": True},
563 contextId=context_id,
564 taskId=task_id,
565 )
567 async def message_stream_generator():
568 yield to_serializable(SendStreamingMessageSuccessResponse(id=request.id, result=task_status_update))
570 return message_stream_generator()
572 except Exception as e:
573 logger.error(f"Error processing message stream: {e}")
574 return SendStreamingMessageSuccessResponse(
575 id=request.id, error=InternalError(message=f"Error processing message stream: {str(e)}")
576 )