Coverage for mindsdb / integrations / handlers / xata_handler / xata_handler.py: 0%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Optional 

2 

3import pandas as pd 

4import json 

5import xata 

6from xata.helpers import BulkProcessor 

7 

8from mindsdb.integrations.libs.response import RESPONSE_TYPE 

9from mindsdb.integrations.libs.response import HandlerResponse 

10from mindsdb.integrations.libs.response import HandlerResponse as Response 

11from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

12from mindsdb.integrations.libs.vectordatabase_handler import ( 

13 FilterCondition, 

14 FilterOperator, 

15 TableField, 

16 VectorStoreHandler, 

17) 

18from mindsdb.utilities import log 

19 

# Module-level logger; XataHandler uses it to report connection errors.
logger = log.getLogger(__name__)

21 

22 

class XataHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Xata statements.

    Tables created by this handler share a fixed layout (see ``create_table``):
    an ``embeddings`` vector column, a ``content`` text column and a
    ``metadata`` json column; ``id`` is the Xata record id.
    """

    name = "xata"

    def __init__(self, name: str, **kwargs):
        """Store the connection settings from ``connection_data`` and connect.

        Keys read from ``connection_data``: ``db_url``, ``api_key``,
        ``dimension`` (vector size for new tables, default 8) and
        ``similarity_function`` (vector search metric, default
        ``"cosineSimilarity"``).
        """
        super().__init__(name)
        # Assigned first so that ``__del__`` stays safe even if the rest of
        # ``__init__`` raises.
        self._client = None
        self.is_connected = False
        # A missing ``connection_data`` used to crash with AttributeError on
        # ``None.get``; fall back to an empty dict and let ``connect`` log any
        # client construction error.
        self._connection_data = kwargs.get("connection_data") or {}
        self._client_config = {
            "db_url": self._connection_data.get("db_url"),
            "api_key": self._connection_data.get("api_key"),
        }
        self._create_table_params = {
            "dimension": self._connection_data.get("dimension", 8),
        }
        self._select_params = {
            "similarity_function": self._connection_data.get("similarity_function", "cosineSimilarity"),
        }
        self.connect()

    def __del__(self):
        # ``getattr`` guards against a partially-initialised instance.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def connect(self):
        """Connect to a Xata database.

        Returns:
            The ``xata.XataClient`` on success, ``None`` if the client could not
            be created (the error is logged, not raised).
        """
        if self.is_connected is True:
            return self._client
        try:
            self._client = xata.XataClient(**self._client_config)
            self.is_connected = True
            return self._client
        except Exception as e:
            logger.error(f"Error connecting to Xata client: {e}!")
            self.is_connected = False

    def disconnect(self):
        """Close the database connection by dropping the client reference."""
        if self.is_connected is False:
            return
        self._client = None
        self.is_connected = False

    def check_connection(self):
        """Check the connection to the Xata database.

        Xata offers no direct ping, so the current user is fetched instead; a
        failed request means the connection settings are unusable. If the
        handler was not connected beforehand, it connects for the check and
        disconnects again when the check succeeds.

        Returns:
            A ``StatusResponse`` with ``success`` and, on failure,
            ``error_message`` set.
        """
        response_code = StatusResponse(False)
        need_to_close = self.is_connected is False
        try:
            # Previously a disconnected handler dereferenced ``None`` here.
            self.connect()
            if self._client is None:
                raise Exception("Xata client could not be created")
            resp = self._client.users().get()
            if not resp.is_success():
                raise Exception(resp["message"])
            response_code.success = True
        except Exception as e:
            logger.error(f"Error connecting to Xata: {e}!")
            response_code.error_message = str(e)
        finally:
            if response_code.success is True and need_to_close:
                self.disconnect()
            if response_code.success is False and self.is_connected is True:
                self.is_connected = False
        return response_code

    def _list_table_names(self) -> List[str]:
        """Return the names of all tables in the connected Xata branch."""
        details = self._client.branch().get_details()
        return [table_data["name"] for table_data in details["schema"]["tables"]]

    def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse:
        """Create a table with the given name in the Xata database.

        The table gets an ``embeddings`` vector column (dimension taken from
        ``connection_data["dimension"]``), a ``content`` text column and a
        ``metadata`` json column.

        Args:
            table_name: Name of the table to create.
            if_not_exists: When True and the table already exists, do nothing
                instead of failing.

        Raises:
            Exception: If Xata rejects the table creation or the schema update.
        """
        if if_not_exists and table_name in self._list_table_names():
            return
        resp = self._client.table().create(table_name)
        if not resp.is_success():
            raise Exception(f"Unable to create table {table_name}: {resp['message']}")
        resp = self._client.table().set_schema(
            table_name=table_name,
            payload={
                "columns": [
                    {
                        "name": "embeddings",
                        "type": "vector",
                        "vector": {"dimension": self._create_table_params["dimension"]},
                    },
                    {"name": "content", "type": "text"},
                    {"name": "metadata", "type": "json"},
                ]
            },
        )
        if not resp.is_success():
            raise Exception(f"Unable to change schema of table {table_name}: {resp['message']}")

    def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse:
        """Delete a table from the Xata database.

        Args:
            table_name: Name of the table to delete.
            if_exists: When True and the table does not exist, do nothing
                instead of failing.

        Raises:
            Exception: If Xata rejects the deletion.
        """
        if if_exists and table_name not in self._list_table_names():
            return
        resp = self._client.table().delete(table_name)
        if not resp.is_success():
            raise Exception(f"Unable to delete table: {resp['message']}")

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Get columns of the given table.

        Vector-store tables have predefined columns (provided by the base
        class); Xata is still queried so that an invalid table name yields an
        error response.
        """
        try:
            resp = self._client.table().get_columns(table_name)
            if not resp.is_success():
                raise Exception(f"Error getting columns: {resp['message']}")
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"{e}",
            )
        return super().get_columns(table_name)

    def get_tables(self) -> HandlerResponse:
        """Get the list of tables in the Xata database as a ``TABLE_NAME`` frame."""
        try:
            table_names = pd.DataFrame(
                columns=["TABLE_NAME"],
                data=self._list_table_names(),
            )
            return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=table_names)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error getting list of tables: {e}",
            )

    def insert(self, table_name: str, data: pd.DataFrame, columns: List[str] = None):
        """Insert data into the Xata database.

        ``metadata`` values are serialised with ``json.dumps`` before upload.
        Several rows are sent through Xata's ``BulkProcessor``; a single row is
        inserted directly, using its ``id`` as the record id when present.

        Args:
            table_name: Target table.
            data: Rows to insert.
            columns: Optional subset of ``data`` columns to insert; falsy means
                all columns.

        Returns:
            An OK ``Response`` when ``data`` is empty, otherwise ``None``.

        Raises:
            Exception: If Xata rejects a single-row insert (bulk inserts raise
                through ``BulkProcessor(throw_exception=True)``).
        """
        if columns:
            data = data[columns]
        records = data.to_dict("records")
        for row in records:
            if "metadata" in row:
                row["metadata"] = json.dumps(row["metadata"])

        if not records:
            return Response(resp_type=RESPONSE_TYPE.OK)

        if len(records) > 1:
            bulk = BulkProcessor(self._client, throw_exception=True)
            bulk.put_records(table_name, records)
            bulk.flush_queue()
            return None

        record = records[0]
        # ``columns`` may be None (meaning "all columns"); the previous check
        # ``TableField.ID.value in columns`` raised TypeError in that case.
        if "id" in record and (not columns or TableField.ID.value in columns):
            record_id = record["id"]
            payload = {key: value for key, value in record.items() if key != "id"}
            resp = self._client.records().insert_with_id(
                table_name=table_name,
                record_id=record_id,
                payload=payload,
                create_only=True,
                columns=columns,
            )
        else:
            resp = self._client.records().insert(
                table_name=table_name,
                payload=record,
                columns=columns,
            )
        if not resp.is_success():
            raise Exception(resp["message"])

    def update(self, table_name: str, data: pd.DataFrame, columns: List[str] = None) -> HandlerResponse:
        """Update data in the Xata database (not supported; defers to the base class)."""
        return super().update(table_name, data, columns)

    def _get_xata_operator(self, operator: FilterOperator) -> str:
        """Translate SQL operator to operator understood by Xata filter language.

        Raises:
            Exception: If the operator has no Xata equivalent here.
        """
        mapping = {
            FilterOperator.EQUAL: "$is",
            FilterOperator.NOT_EQUAL: "$isNot",
            FilterOperator.LESS_THAN: "$lt",
            FilterOperator.LESS_THAN_OR_EQUAL: "$le",
            FilterOperator.GREATER_THAN: "$gt",
            # Xata's range operators are $gt/$ge/$lt/$le; "$gte" is not one.
            FilterOperator.GREATER_THAN_OR_EQUAL: "$ge",
            FilterOperator.LIKE: "$pattern",
        }
        if operator not in mapping:
            raise Exception(f"Operator '{operator}' is not supported!")
        return mapping[operator]

    def _translate_non_vector_conditions(self, conditions: List[FilterCondition]) -> Optional[dict]:
        """
        Translate a list of FilterCondition objects into a dict that can be used by Xata for filtering.

        ``metadata.<key>`` columns become Xata's ``metadata-><key>`` paths, SQL
        LIKE wildcards (``%``, ``_``) become Xata pattern wildcards (``*``,
        ``?``) and search-vector conditions are skipped. Conditions on the same
        column are merged into one dict. The input conditions are not modified.

        E.g.,
            [
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.LESS_THAN,
                    value=100,
                ),
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.GREATER_THAN,
                    value=10,
                )
            ]
            -->
            {
                "metadata->price": {
                    "$gt": 10,
                    "$lt": 100
                },
            }

        Returns:
            The filter dict, or ``None`` when no filterable condition remains.
        """
        if not conditions:
            return None
        filters = {}
        for condition in conditions:
            # The search vector drives similarity search, not filtering.
            if condition.column == TableField.SEARCH_VECTOR.value:
                continue
            column = condition.column
            if column.startswith(TableField.METADATA.value):
                column = column.replace(".", "->")
            value = condition.value
            if condition.op == FilterOperator.LIKE:
                value = value.replace("%", "*").replace("_", "?")
            # Several conditions on one column share a single operator dict;
            # a repeated operator keeps the last value.
            filters.setdefault(column, {})[self._get_xata_operator(condition.op)] = value
        return filters if filters else None

    def select(self, table_name: str, columns: List[str] = None, conditions: List[FilterCondition] = None,
               offset: int = None, limit: int = None) -> pd.DataFrame:
        """Run general query or a vector similarity search and return results.

        If ``conditions`` contain a search-vector condition, a Xata vector
        search on the ``embeddings`` column is run (``offset`` is not applied)
        and Xata's score is returned in the distance column. Otherwise a plain
        query with ``columns``, filters, ``limit`` and ``offset`` is run.

        Raises:
            Exception: If the Xata request fails.
        """
        if not columns:
            columns = [col["name"] for col in self.SCHEMA]
        filters = self._translate_non_vector_conditions(conditions)
        search_vector = next(
            (
                condition.value
                for condition in (conditions or [])
                if condition.column == TableField.SEARCH_VECTOR.value
            ),
            None,
        )

        if search_vector is not None:
            # Similarity search
            params = {
                "queryVector": search_vector,
                "column": TableField.EMBEDDINGS.value,
                "similarityFunction": self._select_params["similarity_function"],
            }
            if filters:
                params["filter"] = filters
            if limit:
                params["size"] = limit
            results = self._client.data().vector_search(table_name, params)
            if not results.is_success():
                raise Exception(results["message"])
            results_df = pd.DataFrame.from_records(results["records"])
            if "xata" in results_df.columns:
                results_df["xata"] = results_df["xata"].apply(lambda x: x["score"])
                results_df.rename({"xata": TableField.DISTANCE.value}, axis=1, inplace=True)
        else:
            # General get query
            params = {
                "columns": columns if columns else [],
            }
            if filters:
                params["filter"] = filters
            if limit or offset:
                params["page"] = {}
                if limit:
                    params["page"]["size"] = limit
                if offset:
                    params["page"]["offset"] = offset
            results = self._client.data().query(table_name, params)
            if not results.is_success():
                raise Exception(results["message"])
            results_df = pd.DataFrame.from_records(results["records"])
            if "xata" in results_df.columns:
                results_df.drop(["xata"], axis=1, inplace=True)

        return results_df

    def delete(self, table_name: str, conditions: List[FilterCondition] = None):
        """Delete records by id.

        Only ``id = <value>`` conditions are accepted; each matched value is
        deleted as a record id.

        Returns:
            An ERROR ``Response`` for missing or unsupported conditions,
            otherwise ``None``.

        Raises:
            Exception: If Xata rejects a deletion.
        """
        error = Response(
            resp_type=RESPONSE_TYPE.ERROR,
            error_message="You can only delete using '=' operator ID one at a time!",
        )
        # ``None`` used to raise TypeError, and non-id columns were silently
        # treated as record ids.
        if not conditions:
            return error
        record_ids = []
        for condition in conditions:
            if condition.op != FilterOperator.EQUAL or condition.column != TableField.ID.value:
                return error
            record_ids.append(condition.value)

        for record_id in record_ids:
            resp = self._client.records().delete(table_name, record_id)
            if not resp.is_success():
                raise Exception(resp["message"])