Coverage for mindsdb / api / a2a / common / server / task_manager.py: 0%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from abc import ABC, abstractmethod 

2from typing import Union, AsyncIterable, List, Dict 

3from ...common.types import ( 

4 Task, 

5 JSONRPCResponse, 

6 TaskIdParams, 

7 TaskQueryParams, 

8 GetTaskRequest, 

9 TaskNotFoundError, 

10 SendTaskRequest, 

11 CancelTaskRequest, 

12 TaskNotCancelableError, 

13 SetTaskPushNotificationRequest, 

14 GetTaskPushNotificationRequest, 

15 GetTaskResponse, 

16 CancelTaskResponse, 

17 SendTaskResponse, 

18 SetTaskPushNotificationResponse, 

19 GetTaskPushNotificationResponse, 

20 TaskSendParams, 

21 TaskStatus, 

22 TaskState, 

23 TaskResubscriptionRequest, 

24 SendTaskStreamingRequest, 

25 SendTaskStreamingResponse, 

26 Artifact, 

27 PushNotificationConfig, 

28 TaskStatusUpdateEvent, 

29 JSONRPCError, 

30 TaskPushNotificationConfig, 

31 InternalError, 

32 MessageStreamRequest, 

33) 

34from ...common.server.utils import new_not_implemented_error 

35from mindsdb.utilities import log 

36import asyncio 

37 

# Module-level logger from mindsdb's logging utilities, named after this module.
logger = log.getLogger(__name__)

39 

40 

class TaskManager(ABC):
    """Abstract interface for the A2A task-request handlers.

    Every handler is a coroutine that receives a typed request object and
    returns the matching typed response (or, for the streaming handlers, an
    async iterable of streaming responses). All methods are abstract, so this
    class cannot be instantiated directly; concrete managers implement each one.
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the task referenced by ``request.params`` or an error response."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Attempt to cancel the task referenced by ``request.params``."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse:
        """Handle a non-streaming send-task request.

        ``user_info`` is caller-supplied context; NOTE(review): its keys are
        defined by the caller/implementation and are not visible here.
        """

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request.

        Returns either an async iterable of streaming responses or a single
        ``JSONRPCResponse``.
        """

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification configuration carried by ``request``."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification configuration for the requested task."""

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Re-attach a client to an existing task's event stream.

        Like ``on_send_task_subscribe``, a resubscription yields *streaming*
        responses (``SendTaskStreamingResponse``), or returns a single
        ``JSONRPCResponse``.
        """

    @abstractmethod
    async def on_message_stream(
        self, request: MessageStreamRequest, user_info: Dict
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a message-stream request.

        Same return shape as ``on_send_task_subscribe``.
        """

83 

84 

85class InMemoryTaskManager(TaskManager): 

86 def __init__(self): 

87 self.tasks: dict[str, Task] = {} 

88 self.push_notification_infos: dict[str, PushNotificationConfig] = {} 

89 self.lock = asyncio.Lock() 

90 self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {} 

91 self.subscriber_lock = asyncio.Lock() 

92 

93 async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: 

94 logger.info(f"Getting task {request.params.id}") 

95 task_query_params: TaskQueryParams = request.params 

96 

97 async with self.lock: 

98 task = self.tasks.get(task_query_params.id) 

99 if task is None: 

100 return GetTaskResponse(id=request.id, error=TaskNotFoundError()) 

101 

102 task_result = self.append_task_history(task, task_query_params.historyLength) 

103 

104 return GetTaskResponse(id=request.id, result=task_result) 

105 

106 async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse: 

107 logger.info(f"Cancelling task {request.params.id}") 

108 task_id_params: TaskIdParams = request.params 

109 

110 async with self.lock: 

111 task = self.tasks.get(task_id_params.id) 

112 if task is None: 

113 return CancelTaskResponse(id=request.id, error=TaskNotFoundError()) 

114 

115 return CancelTaskResponse(id=request.id, error=TaskNotCancelableError()) 

116 

117 @abstractmethod 

118 async def on_send_task(self, request: SendTaskRequest, user_info: Dict) -> SendTaskResponse: 

119 pass 

120 

121 @abstractmethod 

122 async def on_send_task_subscribe( 

123 self, request: SendTaskStreamingRequest, user_info: Dict 

124 ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: 

125 pass 

126 

127 async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig): 

128 async with self.lock: 

129 task = self.tasks.get(task_id) 

130 if task is None: 

131 raise ValueError(f"Task not found for {task_id}") 

132 

133 self.push_notification_infos[task_id] = notification_config 

134 

135 return 

136 

137 async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig: 

138 async with self.lock: 

139 task = self.tasks.get(task_id) 

140 if task is None: 

141 raise ValueError(f"Task not found for {task_id}") 

142 

143 return self.push_notification_infos[task_id] 

144 

145 return 

146 

147 async def has_push_notification_info(self, task_id: str) -> bool: 

148 async with self.lock: 

149 return task_id in self.push_notification_infos 

150 

151 async def on_set_task_push_notification( 

152 self, request: SetTaskPushNotificationRequest 

153 ) -> SetTaskPushNotificationResponse: 

154 logger.info(f"Setting task push notification {request.params.id}") 

155 task_notification_params: TaskPushNotificationConfig = request.params 

156 

157 try: 

158 await self.set_push_notification_info( 

159 task_notification_params.id, 

160 task_notification_params.pushNotificationConfig, 

161 ) 

162 except Exception: 

163 logger.exception("Error while setting push notification info:") 

164 return JSONRPCResponse( 

165 id=request.id, 

166 error=InternalError(message="An error occurred while setting push notification info"), 

167 ) 

168 

169 return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params) 

170 

171 async def on_get_task_push_notification( 

172 self, request: GetTaskPushNotificationRequest 

173 ) -> GetTaskPushNotificationResponse: 

174 logger.info(f"Getting task push notification {request.params.id}") 

175 task_params: TaskIdParams = request.params 

176 

177 try: 

178 notification_info = await self.get_push_notification_info(task_params.id) 

179 except Exception: 

180 logger.exception("Error while getting push notification info:") 

181 return GetTaskPushNotificationResponse( 

182 id=request.id, 

183 error=InternalError(message="An error occurred while getting push notification info"), 

184 ) 

185 

186 return GetTaskPushNotificationResponse( 

187 id=request.id, 

188 result=TaskPushNotificationConfig(id=task_params.id, pushNotificationConfig=notification_info), 

189 ) 

190 

191 async def upsert_task(self, task_send_params: TaskSendParams) -> Task: 

192 logger.info(f"Upserting task {task_send_params.id}") 

193 async with self.lock: 

194 task = self.tasks.get(task_send_params.id) 

195 if task is None: 

196 task = Task( 

197 id=task_send_params.id, 

198 sessionId=task_send_params.sessionId, 

199 messages=[task_send_params.message], 

200 status=TaskStatus(state=TaskState.SUBMITTED), 

201 history=[task_send_params.message], 

202 ) 

203 self.tasks[task_send_params.id] = task 

204 else: 

205 task.history.append(task_send_params.message) 

206 

207 return task 

208 

209 async def on_resubscribe_to_task( 

210 self, request: TaskResubscriptionRequest 

211 ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]: 

212 return new_not_implemented_error(request.id) 

213 

214 async def update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task: 

215 async with self.lock: 

216 try: 

217 task = self.tasks[task_id] 

218 except KeyError: 

219 logger.error(f"Task {task_id} not found for updating the task") 

220 raise ValueError(f"Task {task_id} not found") 

221 

222 task.status = status 

223 

224 if status.message is not None: 

225 task.history.append(status.message) 

226 

227 if artifacts is not None: 

228 if task.artifacts is None: 

229 task.artifacts = [] 

230 task.artifacts.extend(artifacts) 

231 

232 return task 

233 

234 def append_task_history(self, task: Task, historyLength: int | None): 

235 new_task = task.model_copy() 

236 if historyLength is not None and historyLength > 0: 

237 new_task.history = new_task.history[-historyLength:] 

238 else: 

239 new_task.history = [] 

240 

241 return new_task 

242 

243 async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False): 

244 async with self.subscriber_lock: 

245 if task_id not in self.task_sse_subscribers: 

246 if is_resubscribe: 

247 raise ValueError("Task not found for resubscription") 

248 else: 

249 self.task_sse_subscribers[task_id] = [] 

250 

251 sse_event_queue = asyncio.Queue(maxsize=0) # <=0 is unlimited 

252 self.task_sse_subscribers[task_id].append(sse_event_queue) 

253 return sse_event_queue 

254 

255 async def enqueue_events_for_sse(self, task_id, task_update_event): 

256 async with self.subscriber_lock: 

257 if task_id not in self.task_sse_subscribers: 

258 return 

259 

260 current_subscribers = self.task_sse_subscribers[task_id] 

261 for subscriber in current_subscribers: 

262 await subscriber.put(task_update_event) 

263 

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Async generator yielding events from ``sse_event_queue`` as streaming responses.

        Each event is wrapped in a ``SendTaskStreamingResponse`` tagged with
        ``request_id``. The stream ends after yielding an error response (for a
        ``JSONRPCError`` event) or after a ``TaskStatusUpdateEvent`` whose
        ``final`` flag is set. In the ``finally`` block the queue is removed from
        ``task_sse_subscribers[task_id]`` if that entry still exists.

        Args:
            request_id: id placed on every yielded response.
            task_id: key under which ``sse_event_queue`` was registered
                (presumably via ``setup_sse_consumer`` — confirm at call sites).
            sse_event_queue: queue this consumer reads events from.
        """
        try:
            while True:
                # Waits until a producer puts the next event on this queue.
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    # Errors are forwarded once, then terminate the stream.
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                # A final status update is the last event of the stream.
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            # Unregister this consumer's queue so producers stop feeding it.
            # NOTE(review): in the visible code an emptied per-task list is not
            # deleted, and ``remove`` raises ValueError if the queue was already
            # removed — confirm whether cleanup happens elsewhere.
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)