Coverage for mindsdb / api / a2a / common / types.py: 0%
275 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Union, Any
2from pydantic import BaseModel, Field, TypeAdapter
3from typing import Literal, List, Annotated, Optional
4from datetime import datetime
5from pydantic import model_validator, ConfigDict, field_serializer
6from uuid import uuid4
7from enum import Enum
8from typing_extensions import Self
class TaskState(str, Enum):
    """Lifecycle state of a task.

    Members are ``str`` subclasses, so each compares equal to its wire value
    (e.g. ``TaskState.COMPLETED == "completed"``).
    """

    # Declaration order is preserved because iteration over the enum follows it.
    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"  # hyphenated on the wire
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"
    UNKNOWN = "unknown"
class TextPart(BaseModel):
    """Part that carries plain text; discriminated by ``type == "text"``."""

    type: Literal["text"] = "text"
    text: str
    metadata: Optional[dict[str, Any]] = None
class FileContent(BaseModel):
    """File payload given either inline (``bytes``) or by reference (``uri``).

    Exactly one of ``bytes`` / ``uri`` must be supplied. Presence is tested by
    truthiness, so an empty string counts as "not supplied".
    """

    name: Optional[str] = None
    mimeType: Optional[str] = None
    bytes: Optional[str] = None
    uri: Optional[str] = None

    @model_validator(mode="after")
    def check_content(self) -> Self:
        """Enforce that exactly one of ``bytes`` and ``uri`` is set.

        Raises:
            ValueError: if neither or both are set (pydantic surfaces this as a
                validation error).
        """
        has_inline = bool(self.bytes)
        has_reference = bool(self.uri)
        if not has_inline and not has_reference:
            raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
        if has_inline and has_reference:
            raise ValueError("Only one of 'bytes' or 'uri' can be present in the file data")
        return self
class FilePart(BaseModel):
    """Part that wraps a ``FileContent``; discriminated by ``type == "file"``."""

    type: Literal["file"] = "file"
    file: FileContent
    metadata: Optional[dict[str, Any]] = None
class DataPart(BaseModel):
    """Part that carries a string-keyed dict payload; discriminated by ``type == "data"``."""

    type: Literal["data"] = "data"
    data: dict[str, Any]
    metadata: Optional[dict[str, Any]] = None
# A message/artifact part: pydantic picks TextPart, FilePart or DataPart
# by looking at the "type" field of the incoming data.
Part = Annotated[TextPart | FilePart | DataPart, Field(discriminator="type")]
class Message(BaseModel):
    """A message from a given role, composed of discriminated ``Part`` objects."""

    role: Literal["user", "agent", "assistant"]
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None
    # Optional nested earlier messages (self-referential forward reference).
    history: Optional[List["Message"]] = None
    messageId: Optional[str] = None
class FlexibleMessage(BaseModel):
    """Message that can handle both 'type' and 'kind' in parts.

    Parts arrive as raw dicts so either discriminator key may be used; the
    ``normalize_parts`` validator then turns each one into a ``TextPart``,
    ``FilePart`` or ``DataPart``.
    """

    role: Literal["user", "agent", "assistant"]
    # Raw parts that we'll process manually. NOTE: after validation this list
    # holds TextPart/FilePart/DataPart instances, not dicts, despite the
    # annotation (normalize_parts reassigns it).
    parts: List[dict[str, Any]]
    metadata: dict[str, Any] | None = None
    history: Optional[List["FlexibleMessage"]] = None

    @model_validator(mode="after")
    def normalize_parts(self):
        """Convert parts with 'kind' to parts with 'type' and validate each one.

        A part that has 'kind' but no 'type' is copied with 'kind' renamed to
        'type'; the caller's dict is not mutated. Non-dict entries are passed
        through unchanged.

        Raises:
            ValueError: if a part has an unknown/missing type or fails
                validation against its part model. The underlying error is
                chained as ``__cause__``.
        """
        part_models = {"text": TextPart, "file": FilePart, "data": DataPart}
        normalized_parts = []
        for part in self.parts:
            if not isinstance(part, dict):
                normalized_parts.append(part)
                continue

            if "kind" in part and "type" not in part:
                normalized_part = part.copy()
                normalized_part["type"] = normalized_part.pop("kind")
            else:
                normalized_part = part

            part_type = normalized_part.get("type")
            # Only look up str keys: an unhashable value (e.g. a list) must be
            # reported as an unknown type, not raise TypeError from dict lookup.
            model = part_models.get(part_type) if isinstance(part_type, str) else None
            try:
                if model is None:
                    raise ValueError(f"Unknown part type: {part_type}")
                normalized_parts.append(model.model_validate(normalized_part))
            except Exception as e:
                # Deliberately broad: any failure becomes a ValueError so that
                # pydantic reports it as a validation error for this model.
                raise ValueError(f"Invalid part: {normalized_part}, error: {e}") from e

        self.parts = normalized_parts
        return self
class TaskStatus(BaseModel):
    """Current state of a task, with an optional message and a timestamp."""

    state: TaskState
    message: Optional[Message] = None
    # NOTE(review): datetime.now yields a naive local-time value, so the
    # serialized string carries no UTC offset — confirm consumers expect that.
    timestamp: datetime = Field(default_factory=datetime.now)

    @field_serializer("timestamp")
    def serialize_dt(self, dt: datetime, _info):
        """Serialize ``timestamp`` as an ISO-8601 string."""
        iso_text = dt.isoformat()
        return iso_text
class Artifact(BaseModel):
    """Named output of a task, made of ``Part`` objects.

    ``index``, ``append`` and ``lastChunk`` are carried as plain data; this
    model does not interpret them.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    parts: list[Part]
    metadata: Optional[dict[str, Any]] = None
    index: int = 0
    append: Optional[bool] = None
    lastChunk: Optional[bool] = None
class Task(BaseModel):
    """A task: its id, current status, and optional artifacts/history.

    Returned as the ``result`` of the tasks/send, tasks/get and tasks/cancel
    responses.
    """

    id: str
    sessionId: Optional[str] = None
    status: TaskStatus
    artifacts: Optional[list[Artifact]] = None
    history: Optional[list[Message]] = None
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None
class TaskStatusUpdateEvent(BaseModel):
    """Event reporting a task's new status; used as a streaming response result.

    ``final`` defaults to ``False``.
    """

    id: str
    status: TaskStatus
    final: bool = False
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None
    taskId: Optional[str] = None
class TaskArtifactUpdateEvent(BaseModel):
    """Event delivering an ``Artifact`` for a task; used as a streaming response result."""

    id: str
    artifact: Artifact
    metadata: Optional[dict[str, Any]] = None
    contextId: Optional[str] = None
    taskId: Optional[str] = None
class AuthenticationInfo(BaseModel):
    """Authentication schemes plus optional credentials for push notifications.

    Unknown keys are kept on the instance (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    schemes: list[str]
    credentials: Optional[str] = None
class PushNotificationConfig(BaseModel):
    """Where and how to deliver push notifications: target URL, token, auth."""

    url: str
    token: Optional[str] = None
    authentication: Optional[AuthenticationInfo] = None
class TaskIdParams(BaseModel):
    """Params that identify a single task by ``id`` (cancel, resubscribe, push-get)."""

    id: str
    metadata: Optional[dict[str, Any]] = None
class TaskQueryParams(TaskIdParams):
    """``TaskIdParams`` plus an optional ``historyLength`` (used by tasks/get)."""

    historyLength: Optional[int] = None
class TaskSendParams(BaseModel):
    """Params for tasks/send and tasks/sendSubscribe.

    When ``sessionId`` is omitted, a fresh random hex id (``uuid4().hex``) is
    generated for each instance.
    """

    id: str
    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: Message
    acceptedOutputModes: list[str] | None = None
    pushNotification: Optional[PushNotificationConfig] = None
    historyLength: Optional[int] = None
    metadata: Optional[dict[str, Any]] = None
class TaskPushNotificationConfig(BaseModel):
    """Associates a ``PushNotificationConfig`` with the task identified by ``id``."""

    id: str
    pushNotificationConfig: PushNotificationConfig
# RPC Messages


class JSONRPCMessage(BaseModel):
    """Base for every JSON-RPC 2.0 request/response model in this module.

    If no ``id`` is supplied, a random 32-character hex id is generated.
    """

    jsonrpc: Literal["2.0"] = "2.0"
    id: Optional[Union[int, str]] = Field(default_factory=lambda: uuid4().hex)
class JSONRPCRequest(JSONRPCMessage):
    """Generic JSON-RPC request: a method name and optional dict params."""

    method: str
    params: Optional[dict[str, Any]] = None
class JSONRPCError(BaseModel):
    """JSON-RPC error object: numeric code, message, optional extra data."""

    code: int
    message: str
    data: Optional[Any] = None
class JSONRPCResponse(JSONRPCMessage):
    """Generic JSON-RPC response.

    Both ``result`` and ``error`` are optional; this model does not enforce
    that exactly one of them is set.
    """

    result: Optional[Any] = None
    error: Optional[JSONRPCError] = None
class SendTaskRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/send``."""

    method: Literal["tasks/send"] = "tasks/send"
    params: TaskSendParams
class SendTaskResponse(JSONRPCResponse):
    """Response to ``tasks/send``; ``result`` is the resulting ``Task``."""

    result: Optional[Task] = None
class SendTaskStreamingRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/sendSubscribe``."""

    method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
    params: TaskSendParams
class SendTaskStreamingResponse(JSONRPCResponse):
    """Streaming response; each ``result`` is a status or artifact update event."""

    result: Optional[Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
class MessageStreamParams(BaseModel):
    """Params for ``message/stream``.

    The message is a ``FlexibleMessage``, so its parts may use either "type" or
    "kind". A random hex ``sessionId`` is generated when omitted.
    """

    sessionId: str = Field(default_factory=lambda: uuid4().hex)
    message: FlexibleMessage
    metadata: Optional[dict[str, Any]] = None
class MessageStreamRequest(JSONRPCRequest):
    """JSON-RPC request for method ``message/stream``."""

    method: Literal["message/stream"] = "message/stream"
    params: MessageStreamParams
class MessageStreamResponse(JSONRPCResponse):
    """Response whose ``result`` is a ``Message``."""

    result: Optional[Message] = None
class SendStreamingMessageSuccessResponse(JSONRPCResponse):
    """Streaming success response; ``result`` is a Task or a task update event."""

    result: Optional[Union[Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent]] = None
class GetTaskRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/get``."""

    method: Literal["tasks/get"] = "tasks/get"
    params: TaskQueryParams
class GetTaskResponse(JSONRPCResponse):
    """Response to ``tasks/get``; ``result`` is the requested ``Task``."""

    result: Optional[Task] = None
class CancelTaskRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/cancel``."""

    method: Literal["tasks/cancel"] = "tasks/cancel"
    params: TaskIdParams
class CancelTaskResponse(JSONRPCResponse):
    """Response to ``tasks/cancel``; ``result`` is the affected ``Task``."""

    result: Optional[Task] = None
class SetTaskPushNotificationRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/pushNotification/set``."""

    method: Literal["tasks/pushNotification/set"] = "tasks/pushNotification/set"
    params: TaskPushNotificationConfig
class SetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to ``tasks/pushNotification/set``; echoes the stored config."""

    result: Optional[TaskPushNotificationConfig] = None
class GetTaskPushNotificationRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/pushNotification/get``."""

    method: Literal["tasks/pushNotification/get"] = "tasks/pushNotification/get"
    params: TaskIdParams
class GetTaskPushNotificationResponse(JSONRPCResponse):
    """Response to ``tasks/pushNotification/get``; ``result`` is the task's config."""

    result: Optional[TaskPushNotificationConfig] = None
class TaskResubscriptionRequest(JSONRPCRequest):
    """JSON-RPC request for method ``tasks/resubscribe``."""

    method: Literal["tasks/resubscribe"] = "tasks/resubscribe"
    params: TaskIdParams
# Adapter that validates an incoming request payload into the concrete request
# model selected by its "method" value (a discriminated union).
A2ARequest = TypeAdapter(
    Annotated[
        SendTaskRequest
        | GetTaskRequest
        | CancelTaskRequest
        | SetTaskPushNotificationRequest
        | GetTaskPushNotificationRequest
        | TaskResubscriptionRequest
        | SendTaskStreamingRequest
        | MessageStreamRequest,
        Field(discriminator="method"),
    ]
)
# Error types


class JSONParseError(JSONRPCError):
    """Error for a payload that is not valid JSON (JSON-RPC code -32700)."""

    code: int = -32700
    message: str = "Invalid JSON payload"
    data: Optional[Any] = None
class InvalidRequestError(JSONRPCError):
    """Error for a request payload that fails validation (code -32600)."""

    code: int = -32600
    message: str = "Request payload validation error"
    data: Optional[Any] = None
class MethodNotFoundError(JSONRPCError):
    """Error for an unknown method (code -32601); ``data`` is always ``None``."""

    code: int = -32601
    message: str = "Method not found"
    data: None = None
class InvalidParamsError(JSONRPCError):
    """Error for invalid method parameters (code -32602)."""

    code: int = -32602
    message: str = "Invalid parameters"
    data: Optional[Any] = None
class InternalError(JSONRPCError):
    """Generic server-side failure (code -32603)."""

    code: int = -32603
    message: str = "Internal error"
    data: Optional[Any] = None
class TaskNotFoundError(JSONRPCError):
    """Error for a task id that does not exist (code -32001)."""

    code: int = -32001
    message: str = "Task not found"
    data: None = None
class TaskNotCancelableError(JSONRPCError):
    """Error for a cancel request on a task that cannot be canceled (code -32002)."""

    code: int = -32002
    message: str = "Task cannot be canceled"
    data: None = None
class PushNotificationNotSupportedError(JSONRPCError):
    """Error returned when push notifications are not supported (code -32003)."""

    code: int = -32003
    message: str = "Push Notification is not supported"
    data: None = None
class UnsupportedOperationError(JSONRPCError):
    """Error for an operation the server does not support (code -32004)."""

    code: int = -32004
    message: str = "This operation is not supported"
    data: None = None
class ContentTypeNotSupportedError(JSONRPCError):
    """Error for incompatible content types (code -32005)."""

    code: int = -32005
    message: str = "Incompatible content types"
    data: None = None
class AgentProvider(BaseModel):
    """Organization behind an agent, with an optional URL."""

    organization: str
    url: Optional[str] = None
class AgentCapabilities(BaseModel):
    """Optional features an agent declares; every flag defaults to ``False``."""

    streaming: bool = False
    pushNotifications: bool = False
    stateTransitionHistory: bool = False
class AgentAuthentication(BaseModel):
    """Authentication schemes (and optional credentials) an agent accepts.

    Same fields as ``AuthenticationInfo`` but without ``extra="allow"``.
    """

    schemes: list[str]
    credentials: Optional[str] = None
class AgentSkill(BaseModel):
    """A single skill an agent advertises, with optional tags, examples and I/O modes."""

    id: str
    name: str
    description: Optional[str] = None
    tags: Optional[list[str]] = None
    examples: Optional[list[str]] = None
    inputModes: Optional[list[str]] = None
    outputModes: Optional[list[str]] = None
class AgentCard(BaseModel):
    """Describes an agent: identity, endpoint URL, version, capabilities and skills."""

    name: str
    description: Optional[str] = None
    url: str
    provider: Optional[AgentProvider] = None
    version: str
    documentationUrl: Optional[str] = None
    capabilities: AgentCapabilities
    authentication: Optional[AgentAuthentication] = None
    # Both I/O mode lists default to text only.
    defaultInputModes: list[str] = ["text"]
    defaultOutputModes: list[str] = ["text"]
    skills: list[AgentSkill]
class A2AClientError(Exception):
    """Common base for the A2A client exceptions defined in this module."""
class A2AClientHTTPError(A2AClientError):
    """Client error carrying an HTTP status code and message.

    ``str(err)`` is ``"HTTP Error <status_code>: <message>"``.

    Attributes:
        status_code: The HTTP status code passed to the constructor.
        message: The error text passed to the constructor (without prefix).
    """

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"HTTP Error {status_code}: {message}")

    def __reduce__(self):
        # Default exception pickling/copying re-creates the object as
        # ``cls(*self.args)``. Here ``args`` holds only the formatted string,
        # which does not fit this two-argument ``__init__`` (TypeError on
        # unpickle), so rebuild from the original constructor arguments.
        return (type(self), (self.status_code, self.message), self.__dict__)
class A2AClientJSONError(A2AClientError):
    """Client error for a JSON problem; ``str(err)`` is ``"JSON Error: <message>"``.

    Attributes:
        message: The error text passed to the constructor (without prefix).
    """

    def __init__(self, message: str):
        self.message = message
        super().__init__(f"JSON Error: {message}")

    def __reduce__(self):
        # Default exception pickling/copying calls ``cls(*self.args)``. Since
        # ``args`` already contains the "JSON Error: " prefix, a round trip would
        # double the prefix and corrupt ``.message``; rebuild from the raw text.
        return (type(self), (self.message,), self.__dict__)
class MissingAPIKeyError(Exception):
    """Exception for missing API key.

    Defines no behavior of its own beyond ``Exception``.
    """