Coverage for mindsdb/api/a2a/common/server/task_manager.py: 0%
158 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from abc import ABC, abstractmethod
2from typing import Union, AsyncIterable, List, Dict
3from ...common.types import (
4 Task,
5 JSONRPCResponse,
6 TaskIdParams,
7 TaskQueryParams,
8 GetTaskRequest,
9 TaskNotFoundError,
10 SendTaskRequest,
11 CancelTaskRequest,
12 TaskNotCancelableError,
13 SetTaskPushNotificationRequest,
14 GetTaskPushNotificationRequest,
15 GetTaskResponse,
16 CancelTaskResponse,
17 SendTaskResponse,
18 SetTaskPushNotificationResponse,
19 GetTaskPushNotificationResponse,
20 TaskSendParams,
21 TaskStatus,
22 TaskState,
23 TaskResubscriptionRequest,
24 SendTaskStreamingRequest,
25 SendTaskStreamingResponse,
26 Artifact,
27 PushNotificationConfig,
28 TaskStatusUpdateEvent,
29 JSONRPCError,
30 TaskPushNotificationConfig,
31 InternalError,
32 MessageStreamRequest,
33)
34from ...common.server.utils import new_not_implemented_error
35from mindsdb.utilities import log
36import asyncio
38logger = log.getLogger(__name__)
class TaskManager(ABC):
    """Abstract interface for the A2A server's task request handlers.

    Each ``on_*`` coroutine receives an already-parsed request object and
    returns the matching response object. Streaming handlers return either an
    async iterable of ``SendTaskStreamingResponse`` items or a single
    ``JSONRPCResponse`` (for example an error response).
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the task identified by ``request.params``."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle a cancellation request for the task in ``request.params``."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
        """Handle a non-streaming send-task request.

        NOTE(review): ``user_info`` presumably carries caller identity; its
        schema is defined by the implementations — confirm there.
        """

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a send-task request whose results are streamed back."""

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Register a push notification config for a task."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push notification config registered for a task."""

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Re-attach a client to an existing task's event stream.

        The annotation uses ``SendTaskStreamingResponse`` to match the other
        streaming handlers and the in-memory implementation.
        """

    @abstractmethod
    async def on_message_stream(
        self, request: MessageStreamRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming message request."""
class InMemoryTaskManager(TaskManager):
    """Partial ``TaskManager`` that keeps all state in process memory.

    State:
      * ``tasks`` and ``push_notification_infos`` are guarded by ``lock``.
      * ``task_sse_subscribers`` is guarded by ``subscriber_lock``.

    Subclasses still have to implement ``on_send_task``,
    ``on_send_task_subscribe`` and ``on_message_stream``.
    """

    def __init__(self):
        # task id -> Task
        self.tasks: dict[str, Task] = {}
        # task id -> push notification config registered for that task
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # task id -> one event queue per active SSE consumer of that task
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return a copy of the requested task with history trimmed to ``historyLength``.

        Responds with ``TaskNotFoundError`` when the task id is unknown.
        """
        logger.info("Getting task %s", request.params.id)
        task_query_params: TaskQueryParams = request.params

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            task_result = self.append_task_history(task, task_query_params.historyLength)

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject every cancellation request.

        Unknown task ids get ``TaskNotFoundError``; known tasks get
        ``TaskNotCancelableError``.
        """
        logger.info("Cancelling task %s", request.params.id)
        task_id_params: TaskIdParams = request.params

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
        """Left abstract: concrete managers implement task execution."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Left abstract: concrete managers implement streamed task execution."""

    async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig) -> None:
        """Store ``notification_config`` for an existing task.

        Raises:
            ValueError: if ``task_id`` is not a known task.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")

            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push notification config stored for ``task_id``.

        Raises:
            ValueError: if ``task_id`` is not a known task.
            KeyError: if the task exists but has no stored config.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")

            # Plain indexing keeps the KeyError for a task without a config.
            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return whether a push notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the request's push notification config and echo it back.

        Any failure (for example an unknown task id) is logged and reported as
        ``InternalError`` in a ``SetTaskPushNotificationResponse``. This mirrors
        ``on_get_task_push_notification`` and matches the declared return type.
        """
        logger.info("Setting task push notification %s", request.params.id)
        task_notification_params: TaskPushNotificationConfig = request.params

        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception:
            logger.exception("Error while setting push notification info:")
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(message="An error occurred while setting push notification info"),
            )

        return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push notification config stored for the requested task.

        Any failure (unknown task, or no config stored) is logged and reported
        as ``InternalError``.
        """
        logger.info("Getting task push notification %s", request.params.id)
        task_params: TaskIdParams = request.params

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception:
            logger.exception("Error while getting push notification info:")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(message="An error occurred while getting push notification info"),
            )

        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info),
        )

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create the task in SUBMITTED state, or append the new message to its history.

        Returns the stored ``Task`` object itself, not a copy.
        """
        logger.info("Upserting task %s", task_send_params.id)
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                # Tasks stored by other code paths may have no history list yet.
                if task.history is None:
                    task.history = []
                task.history.append(task_send_params.message)

            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Resubscription is not supported; always returns a not-implemented error."""
        return new_not_implemented_error(request.id)

    async def update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task:
        """Set a task's status and record its status message and artifacts.

        ``status.message`` (if any) is appended to the history, and
        ``artifacts`` (if not None) are appended to the task's artifacts.

        Raises:
            ValueError: if ``task_id`` is not a known task.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error("Task %s not found for updating the task", task_id)
                raise ValueError(f"Task {task_id} not found") from None

            task.status = status

            if status.message is not None:
                if task.history is None:
                    task.history = []
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None):
        """Return a copy of ``task`` that keeps only the last ``historyLength`` history entries.

        A ``None`` or non-positive ``historyLength`` yields an empty history.
        ``model_copy`` is shallow, but ``history`` is rebound to a new list.
        The stored task's history is therefore never modified.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            new_task.history = (new_task.history or [])[-historyLength:]
        else:
            new_task.history = []

        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Create, register and return a new unbounded event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is true and no subscriber list
                exists yet for ``task_id``.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Put ``task_update_event`` on every queue registered for ``task_id``.

        Does nothing if the task has no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            for subscriber in self.task_sse_subscribers[task_id]:
                # Queues from setup_sse_consumer are unbounded, so put() does
                # not wait for space while the lock is held.
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses from ``sse_event_queue`` until a terminal event.

        Stops after yielding a ``JSONRPCError`` (sent as ``error``) or a
        ``TaskStatusUpdateEvent`` whose ``final`` flag is set. On exit,
        including when the generator is closed early, the queue is unregistered
        from ``task_sse_subscribers``.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                subscribers = self.task_sse_subscribers.get(task_id)
                # Guard so an already-unregistered queue does not make
                # list.remove raise ValueError during cleanup.
                if subscribers is not None and sse_event_queue in subscribers:
                    subscribers.remove(sse_event_queue)