Coverage for mindsdb / integrations / handlers / xata_handler / xata_handler.py: 0%
179 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List, Optional
3import pandas as pd
4import json
5import xata
6from xata.helpers import BulkProcessor
8from mindsdb.integrations.libs.response import RESPONSE_TYPE
9from mindsdb.integrations.libs.response import HandlerResponse
10from mindsdb.integrations.libs.response import HandlerResponse as Response
11from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse
12from mindsdb.integrations.libs.vectordatabase_handler import (
13 FilterCondition,
14 FilterOperator,
15 TableField,
16 VectorStoreHandler,
17)
18from mindsdb.utilities import log
logger = log.getLogger(__name__)


class XataHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Xata statements.

    Tables created through this handler hold an ``embeddings`` vector column,
    a ``content`` text column and a ``metadata`` json column; individual
    records are addressed by their Xata ``id``.
    """

    name = "xata"

    # Columns of the MindsDB vector-store schema that are never stored in a
    # table created by ``create_table`` (they are computed or supplied at query
    # time), so they must not be sent to Xata as part of a column projection.
    _VIRTUAL_COLUMNS = (TableField.DISTANCE.value, TableField.SEARCH_VECTOR.value)

    def __init__(self, name: str, **kwargs):
        """Store the connection settings and open the Xata client.

        Args:
            name: Name of this handler instance.
            **kwargs: ``connection_data`` dict with ``db_url`` and ``api_key``,
                plus optional ``dimension`` (vector size used by
                ``create_table``, default 8) and ``similarity_function``
                (used by vector search, default ``"cosineSimilarity"``).
        """
        super().__init__(name)
        # Tolerate a missing ``connection_data`` instead of failing on ``None.get``.
        self._connection_data = kwargs.get("connection_data") or {}
        self._client_config = {
            "db_url": self._connection_data.get("db_url"),
            "api_key": self._connection_data.get("api_key"),
        }
        self._create_table_params = {
            "dimension": self._connection_data.get("dimension", 8),
        }
        self._select_params = {
            "similarity_function": self._connection_data.get("similarity_function", "cosineSimilarity"),
        }
        self._client = None
        self.is_connected = False
        self.connect()

    def __del__(self):
        # ``getattr`` because ``__init__`` may have failed before setting the flag.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def connect(self):
        """Connect to a Xata database.

        Returns:
            The ``xata.XataClient`` instance, or ``None`` if the client could
            not be created (the error is logged).
        """
        if self.is_connected is True:
            return self._client
        try:
            self._client = xata.XataClient(**self._client_config)
            self.is_connected = True
            return self._client
        except Exception as e:
            logger.error(f"Error connecting to Xata client: {e}!")
            self.is_connected = False

    def disconnect(self):
        """Close the database connection by dropping the client reference."""
        if self.is_connected is False:
            return
        self._client = None
        self.is_connected = False

    def check_connection(self):
        """Check the connection to the Xata database.

        Returns:
            StatusResponse whose ``success`` flag tells whether the credentials
            work; on failure ``error_message`` holds the reason.
        """
        response_code = StatusResponse(False)
        need_to_close = self.is_connected is False
        try:
            # Make sure a client exists; previously a disconnected handler
            # failed here with an AttributeError on ``None``.
            client = self.connect()
            if client is None:
                raise Exception("Unable to create a Xata client, check 'db_url' and 'api_key'")
            # NOTE: no direct way to test this
            # try getting the user, if it fails, it means that we are not connected
            resp = client.users().get()
            if not resp.is_success():
                raise Exception(resp["message"])
            response_code.success = True
        except Exception as e:
            logger.error(f"Error connecting to Xata: {e}!")
            response_code.error_message = str(e)
        finally:
            if response_code.success is True and need_to_close:
                self.disconnect()
            if response_code.success is False and self.is_connected is True:
                self.is_connected = False
        return response_code

    def _get_table_names(self) -> List[str]:
        """Return the names of all tables of the branch the client points at.

        Raises:
            Exception: if Xata reports an error for the branch details request.
        """
        details = self._client.branch().get_details()
        if not details.is_success():
            raise Exception(details["message"])
        return [table_data["name"] for table_data in details["schema"]["tables"]]

    def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse:
        """Create a table with the given name in the Xata database.

        The table gets an ``embeddings`` vector column (of the configured
        dimension), a ``content`` text column and a ``metadata`` json column.

        Args:
            table_name: Name of the table to create.
            if_not_exists: When True, do nothing if the table already exists.

        Raises:
            Exception: if Xata rejects the table creation or the schema update.
        """
        if if_not_exists and table_name in self._get_table_names():
            return
        resp = self._client.table().create(table_name)
        if not resp.is_success():
            raise Exception(f"Unable to create table {table_name}: {resp['message']}")
        resp = self._client.table().set_schema(
            table_name=table_name,
            payload={
                "columns": [
                    {
                        "name": "embeddings",
                        "type": "vector",
                        "vector": {"dimension": self._create_table_params["dimension"]},
                    },
                    {"name": "content", "type": "text"},
                    {"name": "metadata", "type": "json"},
                ]
            },
        )
        if not resp.is_success():
            raise Exception(f"Unable to change schema of table {table_name}: {resp['message']}")

    def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse:
        """Delete a table from the Xata database.

        Args:
            table_name: Name of the table to delete.
            if_exists: When True, a missing table is not treated as an error.

        Raises:
            Exception: if Xata rejects the deletion (and the table still exists
                or ``if_exists`` is False).
        """
        resp = self._client.table().delete(table_name)
        if resp.is_success():
            return
        # Only look the table up on failure, so the happy path costs one request.
        if if_exists and table_name not in self._get_table_names():
            return
        raise Exception(f"Unable to delete table: {resp['message']}")

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Get columns of the given table.

        Vector stores expose the predefined vector-store schema; the Xata call
        only verifies that the table is valid.
        """
        try:
            resp = self._client.table().get_columns(table_name)
            if not resp.is_success():
                raise Exception(f"Error getting columns: {resp['message']}")
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"{e}",
            )
        return super().get_columns(table_name)

    def get_tables(self) -> HandlerResponse:
        """Get the list of tables in the Xata database as a ``TABLE_NAME`` frame."""
        try:
            table_names = pd.DataFrame(
                columns=["TABLE_NAME"],
                data=self._get_table_names(),
            )
            return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=table_names)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error getting list of tables: {e}",
            )

    def insert(self, table_name: str, data: pd.DataFrame, columns: List[str] = None):
        """Insert data into the Xata database.

        Several rows go through Xata's ``BulkProcessor``; a single row is
        inserted directly (with its ``id`` when one is supplied).

        Args:
            table_name: Target table.
            data: Rows to insert.
            columns: Optional subset of ``data`` columns to insert.

        Returns:
            An OK Response when ``data`` is empty, otherwise None.

        Raises:
            Exception: if Xata rejects a single-row insert (bulk errors are
                raised by the processor, ``throw_exception=True``).
        """
        if columns:
            data = data[columns]
        records = data.to_dict("records")
        if not records:
            # Skip
            return Response(resp_type=RESPONSE_TYPE.OK)
        # The Xata json column takes a JSON string; values that are already
        # strings are passed through to avoid double-encoding them.
        for record in records:
            if "metadata" in record and not isinstance(record["metadata"], str):
                record["metadata"] = json.dumps(record["metadata"])

        if len(records) > 1:
            # Bulk processing
            bp = BulkProcessor(self._client, throw_exception=True)
            bp.put_records(table_name, records)
            bp.flush_queue()
            return

        record = records[0]
        # ``columns`` may be None here, so look for the id in the record itself
        # (when ``columns`` is given, the record only holds those columns anyway).
        if "id" in record:
            payload = {key: value for key, value in record.items() if key != "id"}
            resp = self._client.records().insert_with_id(
                table_name=table_name,
                record_id=record["id"],
                payload=payload,
                create_only=True,
                columns=columns,
            )
        else:
            resp = self._client.records().insert(
                table_name=table_name,
                payload=record,
                columns=columns,
            )
        if not resp.is_success():
            raise Exception(resp["message"])

    def update(self, table_name: str, data: pd.DataFrame, columns: List[str] = None) -> HandlerResponse:
        """Update data in the Xata database (not supported; defers to the base class)."""
        return super().update(table_name, data, columns)

    def _get_xata_operator(self, operator: FilterOperator) -> str:
        """Translate a SQL operator to the operator understood by the Xata filter language.

        Raises:
            Exception: if the operator has no Xata equivalent here.
        """
        mapping = {
            FilterOperator.EQUAL: "$is",
            FilterOperator.NOT_EQUAL: "$isNot",
            FilterOperator.LESS_THAN: "$lt",
            FilterOperator.LESS_THAN_OR_EQUAL: "$le",
            FilterOperator.GREATER_THAN: "$gt",
            # Xata spells "greater or equal" as ``$ge`` (counterpart of ``$le``).
            FilterOperator.GREATER_THAN_OR_EQUAL: "$ge",
            FilterOperator.LIKE: "$pattern",
        }
        if operator not in mapping:
            raise Exception(f"Operator '{operator}' is not supported!")
        return mapping[operator]

    def _translate_non_vector_conditions(self, conditions: List[FilterCondition]) -> Optional[dict]:
        """
        Translate a list of FilterCondition objects into a dict that can be used by Xata for filtering.

        The search-vector condition is skipped, ``metadata.x`` columns become
        ``metadata->x`` and SQL LIKE wildcards (``%``, ``_``) become Xata
        pattern wildcards (``*``, ``?``). Conditions on the same column are
        merged. The given FilterCondition objects are not modified.

        E.g.,
            [
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.LESS_THAN,
                    value=100,
                ),
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.GREATER_THAN,
                    value=10,
                )
            ]
            -->
            {
                "metadata->price": {
                    "$gt": 10,
                    "$lt": 100
                },
            }

        Returns:
            The filter dict, or None when there is nothing to filter on.
        """
        if not conditions:
            return None
        filters = {}
        for condition in conditions:
            # Skip search vector condition
            if condition.column == TableField.SEARCH_VECTOR.value:
                continue
            column = condition.column
            if column.startswith(TableField.METADATA.value):
                column = column.replace(".", "->")
            value = condition.value
            # Special case LIKE: needs pattern translation
            if condition.op == FilterOperator.LIKE:
                value = value.replace("%", "*").replace("_", "?")
            # Merge with any other condition already placed on the same column
            filters.setdefault(column, {})[self._get_xata_operator(condition.op)] = value
        return filters if filters else None

    def select(self, table_name: str, columns: List[str] = None, conditions: List[FilterCondition] = None,
               offset: int = None, limit: int = None) -> pd.DataFrame:
        """Run general query or a vector similarity search and return results.

        A condition on the search-vector column triggers a Xata vector search
        (``offset`` is not applied there); otherwise a regular query with
        optional paging is executed.

        Raises:
            Exception: if Xata reports an error for the request.
        """
        if not columns:
            columns = [col["name"] for col in self.SCHEMA]
        # Generate filter conditions
        filters = self._translate_non_vector_conditions(conditions)
        # First search-vector condition, if any
        search_vector = next(
            (
                condition.value
                for condition in (conditions or [])
                if condition.column == TableField.SEARCH_VECTOR.value
            ),
            None,
        )

        if search_vector is not None:
            # Similarity
            params = {
                "queryVector": search_vector,
                "column": TableField.EMBEDDINGS.value,
                "similarityFunction": self._select_params["similarity_function"],
            }
            if filters:
                params["filter"] = filters
            if limit:
                params["size"] = limit
            results = self._client.data().vector_search(table_name, params)
            if not results.is_success():
                raise Exception(results["message"])
            if not results["records"]:
                return pd.DataFrame(columns=columns)
            results_df = pd.DataFrame.from_records(results["records"])
            if "xata" in results_df.columns:
                # NOTE(review): this is the score Xata reports for the configured
                # similarity function (e.g. cosineSimilarity) — confirm callers
                # expect that value under the distance column.
                results_df["xata"] = results_df["xata"].apply(lambda x: x["score"])
                results_df.rename({"xata": TableField.DISTANCE.value}, axis=1, inplace=True)
            return results_df

        # General get query; only request columns that are stored in Xata.
        params = {
            "columns": [column for column in columns if column not in self._VIRTUAL_COLUMNS],
        }
        if filters:
            params["filter"] = filters
        if limit or offset:
            params["page"] = {}
            if limit:
                params["page"]["size"] = limit
            if offset:
                params["page"]["offset"] = offset
        results = self._client.data().query(table_name, params)
        if not results.is_success():
            raise Exception(results["message"])
        if not results["records"]:
            return pd.DataFrame(columns=columns)
        results_df = pd.DataFrame.from_records(results["records"])
        if "xata" in results_df.columns:
            results_df.drop(["xata"], axis=1, inplace=True)
        return results_df

    def delete(self, table_name: str, conditions: List[FilterCondition] = None):
        """Delete a single record, selected by ``id = <value>``, from a table.

        Returns:
            An ERROR Response when the conditions are anything other than
            equality on the ``id`` column for exactly one id; otherwise None.

        Raises:
            Exception: if Xata rejects the deletion.
        """
        record_ids = set()
        for condition in conditions or []:
            # Any other column or operator would otherwise be misread as an id.
            if condition.column != TableField.ID.value or condition.op != FilterOperator.EQUAL:
                record_ids = None
                break
            record_ids.add(condition.value)
        # Conditions are ANDed: two different ids can never both match, and no
        # condition at all must not be treated as "delete nothing/everything".
        if not record_ids or len(record_ids) != 1:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message="You can only delete using '=' operator ID one at a time!",
            )
        resp = self._client.records().delete(table_name, record_ids.pop())
        if not resp.is_success():
            raise Exception(resp["message"])