Coverage for mindsdb / api / a2a / common / types.py: 0%

275 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Union, Any 

2from pydantic import BaseModel, Field, TypeAdapter 

3from typing import Literal, List, Annotated, Optional 

4from datetime import datetime 

5from pydantic import model_validator, ConfigDict, field_serializer 

6from uuid import uuid4 

7from enum import Enum 

8from typing_extensions import Self 

9 

10 

class TaskState(str, Enum):
    """Lifecycle states an A2A task can report.

    The enum mixes in ``str`` so every member compares equal to its string
    value, e.g. ``TaskState.INPUT_REQUIRED == "input-required"``.
    """

    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"  # hyphenated wire value, not snake_case
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"
    UNKNOWN = "unknown"

19 

20 

class TextPart(BaseModel):
    """Message part holding plain text; discriminated by ``type="text"``."""

    type: Literal["text"] = "text"
    text: str
    metadata: Optional[dict[str, Any]] = None

25 

26 

class FileContent(BaseModel):
    """File payload carried inline (``bytes``) or by reference (``uri``).

    Exactly one of ``bytes`` / ``uri`` must hold a non-empty value; an empty
    string is treated the same as a missing value.
    """

    name: Optional[str] = None
    mimeType: Optional[str] = None
    bytes: Optional[str] = None
    uri: Optional[str] = None

    @model_validator(mode="after")
    def check_content(self) -> Self:
        """Reject file data that has neither, or both, content sources.

        Raises:
            ValueError: if both ``bytes`` and ``uri`` are empty/missing, or if
                both are non-empty.
        """
        has_bytes = bool(self.bytes)
        has_uri = bool(self.uri)
        if not has_bytes and not has_uri:
            raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
        if has_bytes and has_uri:
            raise ValueError("Only one of 'bytes' or 'uri' can be present in the file data")
        return self

40 

41 

class FilePart(BaseModel):
    """Message part wrapping a ``FileContent``; discriminated by ``type="file"``."""

    type: Literal["file"] = "file"
    file: FileContent
    metadata: Optional[dict[str, Any]] = None

46 

47 

class DataPart(BaseModel):
    """Message part carrying a dict payload; discriminated by ``type="data"``."""

    type: Literal["data"] = "data"
    data: dict[str, Any]
    metadata: Optional[dict[str, Any]] = None

52 

53 

# Any message/artifact part; pydantic picks the concrete model from the
# value of the "type" key.
Part = Annotated[TextPart | FilePart | DataPart, Field(discriminator="type")]

55 

56 

class Message(BaseModel):
    """One conversational turn: a role plus an ordered list of typed parts."""

    role: Literal["user", "agent", "assistant"]
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None
    # Self-referential: earlier turns may be nested inside a message.
    history: List["Message"] | None = None
    messageId: Optional[str] = None

63 

64 

class FlexibleMessage(BaseModel):
    """Message that can handle both 'type' and 'kind' in parts.

    Parts are first accepted as raw dicts. After field validation,
    ``normalize_parts`` renames a ``kind`` key to ``type`` (when ``type`` is
    absent). It then converts each dict into the matching ``TextPart`` /
    ``FilePart`` / ``DataPart`` model, so ``parts`` holds part models once the
    instance has been built.
    """

    role: Literal["user", "agent", "assistant"]
    # NOTE(review): declared as dicts so 'kind'-tagged parts pass field
    # validation; after normalize_parts runs, the list holds part models, which
    # may make model_dump warn about an unexpected value type — confirm.
    parts: List[dict[str, Any]]  # Raw parts that we'll process manually
    metadata: dict[str, Any] | None = None
    history: Optional[List["FlexibleMessage"]] = None

    @model_validator(mode="after")
    def normalize_parts(self) -> Self:
        """Convert parts with 'kind' to parts with 'type' and validate them.

        Returns:
            This instance, with ``parts`` replaced by validated part models.

        Raises:
            ValueError: if a part's type is not one of "text", "file", "data",
                or if the part fails validation against its model. The original
                validation error is chained as ``__cause__``.
        """
        part_models = {"text": TextPart, "file": FilePart, "data": DataPart}
        normalized_parts = []
        for part in self.parts:
            if not isinstance(part, dict):
                # Non-dict entries are passed through untouched.
                normalized_parts.append(part)
                continue

            # Convert 'kind' to 'type' if needed; work on a copy rather than
            # mutating the stored dict.
            if "kind" in part and "type" not in part:
                normalized_part = part.copy()
                normalized_part["type"] = normalized_part.pop("kind")
            else:
                normalized_part = part

            part_type = normalized_part.get("type")
            # Only strings can name a part type; this also avoids a TypeError
            # from hashing an unhashable value (e.g. a list) in the lookup.
            part_model = part_models.get(part_type) if isinstance(part_type, str) else None
            if part_model is None:
                raise ValueError(f"Invalid part: {normalized_part}, error: Unknown part type: {part_type}")

            try:
                validated = part_model.model_validate(normalized_part)
            except ValueError as e:  # pydantic's ValidationError subclasses ValueError
                raise ValueError(f"Invalid part: {normalized_part}, error: {e}") from e
            normalized_parts.append(validated)

        self.parts = normalized_parts
        return self

103 

104 

class TaskStatus(BaseModel):
    """A task's current state, with an optional message and a timestamp."""

    state: TaskState
    message: Optional[Message] = None
    # When omitted, filled with datetime.now(): local time, timezone-naive.
    timestamp: datetime = Field(default_factory=datetime.now)

    @field_serializer("timestamp")
    def serialize_dt(self, value: datetime, _info):
        """Serialize ``timestamp`` as an ISO-8601 string."""
        return value.isoformat()

113 

114 

class Artifact(BaseModel):
    """Output produced by a task, made of typed parts.

    ``index``, ``append`` and ``lastChunk`` are plain fields here; any chunking
    semantics are up to the producer/consumer.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None
    index: int = 0
    append: Optional[bool] = None
    lastChunk: Optional[bool] = None

123 

124 

class Task(BaseModel):
    """A task: its id, current status, and optional artifacts and history."""

    id: str
    sessionId: Optional[str] = None
    status: TaskStatus
    artifacts: Optional[list[Artifact]] = None
    history: Optional[list[Message]] = None
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None

133 

134 

class TaskStatusUpdateEvent(BaseModel):
    """Event carrying a new ``TaskStatus`` for a task; ``final`` defaults to False."""

    id: str
    status: TaskStatus
    final: bool = False
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None
    taskId: Optional[str] = None

142 

143 

class TaskArtifactUpdateEvent(BaseModel):
    """Event carrying an ``Artifact`` produced for a task."""

    id: str
    artifact: Artifact
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None
    taskId: Optional[str] = None

150 

151 

class AuthenticationInfo(BaseModel):
    """Authentication schemes (and optional credentials) for push notifications."""

    # Keep unrecognised keys on the instance instead of dropping them.
    model_config = ConfigDict(extra="allow")

    schemes: list[str]
    credentials: Optional[str] = None

157 

158 

class PushNotificationConfig(BaseModel):
    """Where to deliver push notifications, with optional token and auth info."""

    url: str
    token: Optional[str] = None
    authentication: Optional[AuthenticationInfo] = None

163 

164 

class TaskIdParams(BaseModel):
    """RPC params that identify a single task by ``id``."""

    id: str
    metadata: Optional[dict[str, Any]] = None

168 

169 

class TaskQueryParams(TaskIdParams):
    """Task-id params plus an optional ``historyLength`` limit."""

    historyLength: Optional[int] = None

172 

173 

class TaskSendParams(BaseModel):
    """Params of the ``tasks/send`` and ``tasks/sendSubscribe`` requests."""

    id: str
    # Omitted session ids are replaced by a fresh random UUID4 hex string.
    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: Message
    acceptedOutputModes: list[str] | None = None
    pushNotification: Optional[PushNotificationConfig] = None
    historyLength: Optional[int] = None
    metadata: Optional[dict[str, Any]] = None

182 

183 

class TaskPushNotificationConfig(BaseModel):
    """Binds a ``PushNotificationConfig`` to the task with the given ``id``."""

    id: str
    pushNotificationConfig: PushNotificationConfig

187 

188 

# RPC Messages

190 

191 

class JSONRPCMessage(BaseModel):
    """Common JSON-RPC 2.0 envelope fields.

    ``id`` defaults to a fresh random UUID4 hex string when not supplied.
    """

    jsonrpc: Literal["2.0"] = "2.0"
    id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)

195 

196 

class JSONRPCRequest(JSONRPCMessage):
    """Generic JSON-RPC request: a method name and optional dict params."""

    method: str
    params: Optional[dict[str, Any]] = None

200 

201 

class JSONRPCError(BaseModel):
    """JSON-RPC error object: numeric code, message, optional extra data."""

    code: int
    message: str
    data: Optional[Any] = None

206 

207 

class JSONRPCResponse(JSONRPCMessage):
    """Generic JSON-RPC response.

    NOTE(review): ``result`` and ``error`` are both optional; nothing here
    enforces that exactly one of them is set.
    """

    result: Optional[Any] = None
    error: Optional[JSONRPCError] = None

211 

212 

class SendTaskRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/send``."""

    method: Literal["tasks/send"] = "tasks/send"
    params: TaskSendParams

216 

217 

class SendTaskResponse(JSONRPCResponse):
    """Response to ``tasks/send``; ``result`` is the resulting ``Task``."""

    result: Optional[Task] = None

220 

221 

class SendTaskStreamingRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/sendSubscribe``."""

    method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
    params: TaskSendParams

225 

226 

class SendTaskStreamingResponse(JSONRPCResponse):
    """Streamed response item: a status-update or artifact-update event."""

    result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None

229 

230 

class MessageStreamParams(BaseModel):
    """Params of ``message/stream``; accepts 'type'- or 'kind'-tagged parts."""

    # Omitted session ids are replaced by a fresh random UUID4 hex string.
    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: FlexibleMessage
    metadata: Optional[dict[str, Any]] = None

235 

236 

class MessageStreamRequest(JSONRPCRequest):
    """JSON-RPC request for ``message/stream``."""

    method: Literal["message/stream"] = "message/stream"
    params: MessageStreamParams

240 

241 

class MessageStreamResponse(JSONRPCResponse):
    """Response whose ``result`` is a single ``Message``."""

    result: Optional[Message] = None

244 

245 

class SendStreamingMessageSuccessResponse(JSONRPCResponse):
    """Successful streaming response: a task, or a status/artifact update event."""

    result: Optional[Union[Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None

248 

249 

class GetTaskRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/get``."""

    method: Literal["tasks/get"] = "tasks/get"
    params: TaskQueryParams

253 

254 

class GetTaskResponse(JSONRPCResponse):
    """Response to ``tasks/get``; ``result`` is the requested ``Task``."""

    result: Optional[Task] = None

257 

258 

class CancelTaskRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/cancel``."""

    method: Literal["tasks/cancel"] = "tasks/cancel"
    params: TaskIdParams

262 

263 

class CancelTaskResponse(JSONRPCResponse):
    """Response to ``tasks/cancel``; ``result`` is the affected ``Task``."""

    result: Optional[Task] = None

266 

267 

class SetTaskPushNotificationRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/pushNotification/set``."""

    method: Literal["tasks/pushNotification/set"] = "tasks/pushNotification/set"
    params: TaskPushNotificationConfig

271 

272 

class SetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to ``tasks/pushNotification/set``."""

    result: Optional[TaskPushNotificationConfig] = None

275 

276 

class GetTaskPushNotificationRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/pushNotification/get``."""

    method: Literal["tasks/pushNotification/get"] = "tasks/pushNotification/get"
    params: TaskIdParams

280 

281 

class GetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to ``tasks/pushNotification/get``."""

    result: Optional[TaskPushNotificationConfig] = None

284 

285 

class TaskResubscriptionRequest(JSONRPCRequest):
    """JSON-RPC request for ``tasks/resubscribe``."""

    method: Literal["tasks/resubscribe"] = "tasks/resubscribe"
    params: TaskIdParams

289 

290 

# Validator for any supported incoming A2A JSON-RPC request. The "method"
# field is the discriminator that selects the concrete request model.
A2ARequest = TypeAdapter(
    Annotated[
        SendTaskRequest
        | GetTaskRequest
        | CancelTaskRequest
        | SetTaskPushNotificationRequest
        | GetTaskPushNotificationRequest
        | TaskResubscriptionRequest
        | SendTaskStreamingRequest
        | MessageStreamRequest,
        Field(discriminator="method"),
    ]
)

306 

# Error types

308 

309 

class JSONParseError(JSONRPCError):
    """Error for a payload that could not be parsed as JSON (code -32700)."""

    code: int = -32700
    message: str = "Invalid JSON payload"
    data: Optional[Any] = None

314 

315 

class InvalidRequestError(JSONRPCError):
    """Error for a request payload that failed validation (code -32600)."""

    code: int = -32600
    message: str = "Request payload validation error"
    data: Optional[Any] = None

320 

321 

class MethodNotFoundError(JSONRPCError):
    """Error for an unknown RPC method (code -32601); carries no data."""

    code: int = -32601
    message: str = "Method not found"
    data: None = None

326 

327 

class InvalidParamsError(JSONRPCError):
    """Error for invalid method parameters (code -32602)."""

    code: int = -32602
    message: str = "Invalid parameters"
    data: Optional[Any] = None

332 

333 

class InternalError(JSONRPCError):
    """Error for an internal server failure (code -32603)."""

    code: int = -32603
    message: str = "Internal error"
    data: Optional[Any] = None

338 

339 

class TaskNotFoundError(JSONRPCError):
    """Error for a task id that does not exist (code -32001); carries no data."""

    code: int = -32001
    message: str = "Task not found"
    data: None = None

344 

345 

class TaskNotCancelableError(JSONRPCError):
    """Error for a task that cannot be canceled (code -32002); carries no data."""

    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None

350 

351 

class PushNotificationNotSupportedError(JSONRPCError):
    """Error when push notifications are unsupported (code -32003); no data."""

    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None

356 

357 

class UnsupportedOperationError(JSONRPCError):
    """Error for an operation that is not supported (code -32004); no data."""

    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None

362 

363 

class ContentTypeNotSupportedError(JSONRPCError):
    """Error for incompatible content types (code -32005); carries no data."""

    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None

368 

369 

class AgentProvider(BaseModel):
    """Organization that provides an agent, with an optional URL."""

    organization: str
    url: Optional[str] = None

373 

374 

class AgentCapabilities(BaseModel):
    """Optional features an agent advertises; every flag defaults to False."""

    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False

379 

380 

class AgentAuthentication(BaseModel):
    """Authentication schemes an agent accepts, plus optional credentials."""

    schemes: list[str]
    credentials: Optional[str] = None

384 

385 

class AgentSkill(BaseModel):
    """A skill listed on an agent card; only ``id`` and ``name`` are required."""

    id: str
    name: str
    description: Optional[str] = None
    tags: Optional[list[str]] = None
    examples: Optional[list[str]] = None
    inputModes: Optional[list[str]] = None
    outputModes: Optional[list[str]] = None

394 

395 

class AgentCard(BaseModel):
    """Descriptor of an agent: identity, endpoint, capabilities and skills."""

    name: str
    description: Optional[str] = None
    url: str
    provider: Optional[AgentProvider] = None
    version: str
    documentationUrl: Optional[str] = None
    capabilities: AgentCapabilities
    authentication: Optional[AgentAuthentication] = None
    # A list literal is safe as a default here: pydantic copies mutable
    # defaults per instance rather than sharing them.
    defaultInputModes: list[str] = ["text"]
    defaultOutputModes: list[str] = ["text"]
    skills: list[AgentSkill]

408 

409 

class A2AClientError(Exception):
    """Root of the A2A client exception hierarchy."""


class A2AClientHTTPError(A2AClientError):
    """HTTP-level failure; exposes ``status_code`` and ``message``."""

    def __init__(self, status_code: int, message: str):
        text = f"HTTP Error {status_code}: {message}"
        self.status_code = status_code
        self.message = message
        super().__init__(text)


class A2AClientJSONError(A2AClientError):
    """JSON handling failure; exposes the raw ``message``."""

    def __init__(self, message: str):
        text = f"JSON Error: {message}"
        self.message = message
        super().__init__(text)

425 

426 

class MissingAPIKeyError(Exception):
    """Signals that a required API key was not provided."""