Coverage for mindsdb / integrations / handlers / weaviate_handler / weaviate_handler.py: 0%

249 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2from datetime import datetime 

3from typing import List, Optional 

4 

5import weaviate 

6from weaviate.embedded import EmbeddedOptions 

7import pandas as pd 

8 

9from mindsdb.integrations.libs.response import RESPONSE_TYPE 

10from mindsdb.integrations.libs.response import HandlerResponse 

11from mindsdb.integrations.libs.response import HandlerResponse as Response 

12from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

13from mindsdb.integrations.libs.vectordatabase_handler import ( 

14 FilterCondition, 

15 FilterOperator, 

16 TableField, 

17 VectorStoreHandler, 

18) 

19from mindsdb.utilities import log 

20from weaviate.util import generate_uuid5 

21 

# Module-level logger, named after this module's import path.
logger = log.getLogger(__name__)

23 

24 

class WeaviateDBHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Weaviate statements.

    Each logical table ``t`` is stored as two Weaviate classes:

    * ``t.capitalize()`` holds the ``content`` text property and the object
      vector;
    * ``t.capitalize() + "_metadata"`` holds the free-form metadata columns.

    Objects of the first class point to their metadata object through the
    ``associatedMetadata`` cross-reference property.
    """

    name = "weaviate"

    # columns ``select`` always returns; they are never requested as plain
    # Weaviate properties
    _ALWAYS_RETURNED_COLUMNS = ("id", "embeddings", "distance", "metadata")

    def __init__(self, name: str, **kwargs):
        super().__init__(name)

        # a missing ``connection_data`` is treated like an empty config so the
        # validation below reports a readable error instead of an AttributeError
        self._connection_data = kwargs.get("connection_data") or {}

        self._client_config = {
            "weaviate_url": self._connection_data.get("weaviate_url"),
            "weaviate_api_key": self._connection_data.get("weaviate_api_key"),
            "persistence_directory": self._connection_data.get(
                "persistence_directory"
            ),
        }

        # client state is initialised before validation so that ``__del__``
        # works even when the constructor raises below
        self._client = None
        self._embedded_options = None
        self.is_connected = False

        if not (
            self._client_config.get("weaviate_url")
            or self._client_config.get("persistence_directory")
        ):
            raise Exception(
                "Either weaviate_url or persistence_directory is required for weaviate connection!"
            )

        self.connect()

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_class_names(table_name: str) -> tuple:
        """Return ``(class_name, metadata_class_name)`` for a table name.

        Every method resolves table names through ``str.capitalize`` so that
        they all address the same pair of Weaviate classes.
        """
        class_name = table_name.capitalize()
        return class_name, class_name + "_metadata"

    @staticmethod
    def _is_null(value) -> bool:
        """Return True for scalar missing values (None, NaN, NaT, pd.NA).

        Non-scalars (lists, dicts, arrays) are never considered missing.
        """
        return pd.api.types.is_scalar(value) and bool(pd.isna(value))

    @staticmethod
    def _first_reference(obj: dict) -> Optional[dict]:
        """Return the first ``associatedMetadata`` entry of an object, or None."""
        references = obj.get("associatedMetadata")
        return references[0] if references else None

    @staticmethod
    def _extract_objects(response: dict, class_name: str) -> list:
        """Return the objects of ``class_name`` from a GraphQL ``Get`` response.

        Raises:
            Exception: if the response reports GraphQL errors.
        """
        if response.get("errors"):
            raise Exception(f"Weaviate query failed: {response['errors']}")
        return response["data"]["Get"][class_name] or []

    @staticmethod
    def _split_conditions(conditions: Optional[List[FilterCondition]]) -> tuple:
        """Split conditions into ``(column, metadata, vector_search)`` lists."""
        vector_columns = (TableField.SEARCH_VECTOR.value, TableField.EMBEDDINGS.value)
        metadata_prefix = TableField.METADATA.value
        column_conditions, metadata_conditions, vector_conditions = [], [], []
        for condition in conditions or []:
            if condition.column.startswith(metadata_prefix):
                metadata_conditions.append(condition)
            elif condition.column in vector_columns:
                vector_conditions.append(condition)
            else:
                column_conditions.append(condition)
        return column_conditions, metadata_conditions, vector_conditions

    # ------------------------------------------------------------------ #
    # connection management
    # ------------------------------------------------------------------ #

    def _get_client(self) -> weaviate.Client:
        """Build a weaviate client from the stored connection config.

        An embedded instance is started when ``persistence_directory`` is set;
        otherwise an HTTP client for ``weaviate_url`` is created, authenticated
        with ``weaviate_api_key`` when one is configured.

        Raises:
            Exception: if neither a url nor a persistence directory is configured.
        """
        config = self._client_config
        if not (
            config
            and (config.get("weaviate_url") or config.get("persistence_directory"))
        ):
            raise Exception("Client config is not set! or missing parameters")

        # decide the client type to be used, either embedded (persistent) or http
        if config.get("persistence_directory"):
            self._embedded_options = EmbeddedOptions(
                persistence_data_path=config.get("persistence_directory")
            )
            return weaviate.Client(embedded_options=self._embedded_options)
        if config.get("weaviate_api_key"):
            return weaviate.Client(
                url=config["weaviate_url"],
                auth_client_secret=weaviate.AuthApiKey(
                    api_key=config["weaviate_api_key"]
                ),
            )
        return weaviate.Client(url=config["weaviate_url"])

    def _close_client(self):
        """Stop the embedded server (if any), close the client and reset state.

        Attributes are read with ``getattr`` because this may run from
        ``__del__`` on a partially constructed instance. The handler state is
        reset even if closing raises.
        """
        client = getattr(self, "_client", None)
        embedded_options = getattr(self, "_embedded_options", None)
        try:
            if client is not None:
                if embedded_options:
                    client._connection.embedded_db.stop()
                client._connection.close()
        finally:
            self._client = None
            self._embedded_options = None
            self.is_connected = False

    def __del__(self):
        try:
            self._close_client()
        except Exception:
            # a finalizer must not raise; nothing useful can be done here
            pass

    def connect(self):
        """Connect to a weaviate database.

        Returns:
            The weaviate client, or None when it could not be created (the
            error is logged and ``is_connected`` stays False).
        """
        if self.is_connected:
            return self._client

        try:
            self._client = self._get_client()
        except Exception as e:
            logger.error(f"Error connecting to weaviate client, {e}!")
            self.is_connected = False
            return None
        self.is_connected = True
        return self._client

    def disconnect(self):
        """Close the database connection (stopping an embedded server if used)."""
        if not self.is_connected:
            return
        self._close_client()

    def check_connection(self):
        """Check the connection to the Weaviate database.

        If no client exists yet (e.g. the constructor's attempt failed), client
        creation is retried so the real error is reported.
        """
        response_code = StatusResponse(False)

        try:
            if self._client is None:
                self._client = self._get_client()
                self.is_connected = True
            if self._client.is_live():
                response_code.success = True
            else:
                response_code.error_message = "Weaviate server is not live!"
        except Exception as e:
            logger.error(f"Error connecting to weaviate , {e}!")
            response_code.error_message = str(e)
        finally:
            if not response_code.success and self.is_connected:
                # drop the unusable client so the next connect() starts fresh
                try:
                    self.disconnect()
                except Exception as close_error:
                    logger.warning(f"Error closing weaviate client, {close_error}!")

        return response_code

    # ------------------------------------------------------------------ #
    # filter translation
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_weaviate_operator(operator: FilterOperator) -> str:
        """Map a FilterOperator to the Weaviate ``where`` operator name.

        Raises:
            Exception: for operators without a Weaviate equivalent.
        """
        mapping = {
            FilterOperator.EQUAL: "Equal",
            FilterOperator.NOT_EQUAL: "NotEqual",
            FilterOperator.LESS_THAN: "LessThan",
            FilterOperator.LESS_THAN_OR_EQUAL: "LessThanEqual",
            FilterOperator.GREATER_THAN: "GreaterThan",
            FilterOperator.GREATER_THAN_OR_EQUAL: "GreaterThanEqual",
            FilterOperator.IS_NULL: "IsNull",
            FilterOperator.LIKE: "Like",
        }

        if operator not in mapping:
            raise Exception(f"Operator {operator} is not supported by weaviate!")

        return mapping[operator]

    @staticmethod
    def _get_weaviate_value_type(value) -> str:
        """Return the Weaviate ``where`` value key for a Python value.

        Raises:
            Exception: for empty lists or unsupported value types.
        """
        # https://github.com/weaviate/weaviate-python-client/blob/c760b1d59b2a222e770d53cc257b1bf993a0a592/weaviate/gql/filter.py#L18
        # order matters: bool is a subclass of int
        if isinstance(value, list):
            if not value:
                raise Exception("Empty list is not supported")
            probe = value[0]
            candidates = (
                (bool, "valueBooleanList"),
                (int, "valueIntList"),
                (float, "valueNumberList"),
                (str, "valueTextList"),
            )
        else:
            probe = value
            candidates = (
                (bool, "valueBoolean"),
                (int, "valueInt"),
                (float, "valueNumber"),
                (datetime, "valueDate"),
                (str, "valueText"),
            )

        for python_type, value_type in candidates:
            if isinstance(probe, python_type):
                return value_type

        raise Exception(f"Value type {type(value)} is not supported by weaviate!")

    def _build_condition(self, path: list, condition: FilterCondition) -> dict:
        """Translate one FilterCondition on ``path`` into a Weaviate operand."""
        operator = self._get_weaviate_operator(condition.op)
        value = condition.value
        if condition.op == FilterOperator.IS_NULL:
            # weaviate expresses IS NULL as ``IsNull`` with a boolean operand
            is_null = value if isinstance(value, bool) else True
            return {"path": path, "operator": operator, "valueBoolean": is_null}
        if condition.op == FilterOperator.LIKE and isinstance(value, str):
            # SQL wildcards (% and _) -> weaviate wildcards (* and ?)
            value = value.replace("%", "*").replace("_", "?")
        return {
            "path": path,
            "operator": operator,
            self._get_weaviate_value_type(value): value,
        }

    def _translate_condition(
        self,
        table_name: str,
        conditions: List[FilterCondition] = None,
        meta_conditions: List[FilterCondition] = None,
    ) -> Optional[dict]:
        """
        Translate a list of FilterCondition objects a dict that can be used by Weaviate.

        Column conditions filter on the object's own properties; metadata
        conditions filter through the ``associatedMetadata`` reference.
        Returns None when there are no conditions.

        E.g., for table ``docs`` and meta_conditions
        [
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.LESS_THAN,
                value="2020-01-01",
            ),
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.GREATER_THAN,
                value="2019-01-01",
            )
        ]
        -->
        {"operator": "And",
            "operands": [
                {
                    "path": ["associatedMetadata", "Docs_metadata", "created_at"],
                    "operator": "LessThan",
                    "valueText": "2020-01-01",
                },
                {
                    "path": ["associatedMetadata", "Docs_metadata", "created_at"],
                    "operator": "GreaterThan",
                    "valueText": "2019-01-01",
                },
            ]}
        """
        if not (conditions or meta_conditions):
            return None

        _, metadata_table_name = self._get_class_names(table_name)

        # conditions on the object's own columns
        weaviate_conditions = [
            self._build_condition([condition.column], condition)
            for condition in conditions or []
        ]
        # conditions on metadata columns go through the cross-reference
        weaviate_conditions.extend(
            self._build_condition(
                [
                    "associatedMetadata",
                    metadata_table_name,
                    condition.column.split(".")[-1],
                ],
                condition,
            )
            for condition in meta_conditions or []
        )

        if len(weaviate_conditions) == 1:
            return weaviate_conditions[0]
        return {"operator": "And", "operands": weaviate_conditions}

    # ------------------------------------------------------------------ #
    # data manipulation
    # ------------------------------------------------------------------ #

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ):
        """Query objects of ``table_name`` and return them as a DataFrame.

        Args:
            table_name: logical table name.
            columns: requested columns; the caller's list is not modified.
                With no extra columns requested, the result holds id, content,
                metadata, embeddings and distance. Otherwise it holds the
                requested columns plus id, distance and metadata, and also
                embeddings if explicitly requested.
            conditions: column / metadata filters; a ``search_vector`` or
                ``embeddings`` condition triggers a near-vector search (only
                the first one is used).
            offset: number of objects to skip.
            limit: maximum number of objects to return.
        """
        table_name, metadata_table = self._get_class_names(table_name)
        column_conditions, metadata_conditions, vector_conditions = (
            self._split_conditions(conditions)
        )
        filters = self._translate_condition(
            table_name, column_conditions or None, metadata_conditions or None
        )

        requested_columns = [
            column
            for column in (columns or [])
            if column not in self._ALWAYS_RETURNED_COLUMNS
        ]
        # content is part of every result, so it is always fetched
        properties = list(requested_columns)
        if TableField.CONTENT.value not in properties:
            properties.append(TableField.CONTENT.value)

        metadata_schema = self._client.schema.get(metadata_table) or {}
        metadata_fields = [
            prop["name"] for prop in metadata_schema.get("properties") or []
        ]
        # an empty inline fragment is invalid GraphQL, so it is only added when
        # the metadata class has properties
        if metadata_fields:
            properties.append(
                f"associatedMetadata {{ ... on {metadata_table} {{ {' '.join(metadata_fields)} }} }}"
            )

        query = self._client.query.get(table_name, properties).with_additional(
            ["id vector distance"]
        )
        if vector_conditions:
            # similarity search; only one vector search per query is supported
            vector = vector_conditions[0].value
            if isinstance(vector, str):
                vector = ast.literal_eval(vector)
            query = query.with_near_vector({"vector": vector})
        if filters:
            query = query.with_where(filters)
        if limit:
            query = query.with_limit(limit)
        if offset:
            query = query.with_offset(offset)

        objects = self._extract_objects(query.do(), table_name)
        additional = [obj.get("_additional") or {} for obj in objects]

        payload = {
            TableField.ID.value: [extra.get("id") for extra in additional],
            TableField.CONTENT.value: [obj.get("content") for obj in objects],
            TableField.METADATA.value: [self._first_reference(obj) for obj in objects],
            TableField.EMBEDDINGS.value: [extra.get("vector") for extra in additional],
            # distances are null for queries without a vector search
            TableField.DISTANCE.value: [extra.get("distance") for extra in additional],
        }

        if requested_columns:
            include_vectors = TableField.EMBEDDINGS.value in columns
            output_columns = [
                column
                for column in requested_columns + list(self._ALWAYS_RETURNED_COLUMNS)
                if include_vectors or column != TableField.EMBEDDINGS.value
            ]
            payload = {
                column: payload[column]
                if column in payload
                else [obj.get(column) for obj in objects]
                for column in output_columns
            }

        return pd.DataFrame(payload)

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """
        Insert data into the Weaviate database.

        Each row becomes one object (``content`` + vector). A non-empty
        ``metadata`` value is stored as a separate metadata object and linked
        through ``associatedMetadata``. Rows without an id get a deterministic
        uuid5 derived from their content.
        """
        table_name, _ = self._get_class_names(table_name)

        # drop columns whose values are all missing; work on a copy so the
        # caller's DataFrame is left untouched
        data = data.dropna(axis=1, how="all")

        # parsing the records one by one as we need to update metadata (which has variable columns)
        for raw_record in data.to_dict(orient="records"):
            # individual missing cells (NaN/NaT) are normalised to None
            record = {
                key: None if self._is_null(value) else value
                for key, value in raw_record.items()
            }
            metadata_data = record.get(TableField.METADATA.value)
            data_object = {"content": record.get(TableField.CONTENT.value)}
            data_obj_id = record.get(TableField.ID.value)
            if data_obj_id is None:
                data_obj_id = generate_uuid5(data_object)

            obj_id = self._client.data_object.create(
                data_object=data_object,
                class_name=table_name,
                vector=record.get(TableField.EMBEDDINGS.value),
                uuid=data_obj_id,
            )
            if metadata_data:
                meta_id = self.add_metadata(metadata_data, table_name)
                self._client.data_object.reference.add(
                    from_uuid=obj_id,
                    from_property_name="associatedMetadata",
                    to_uuid=meta_id,
                )

    def update(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """
        Update data in the weaviate database.

        Every row must contain ``id``. ``metadata.<key>`` columns (or a
        ``metadata`` dict column) update the linked metadata object, creating
        and linking one if the object has none. An ``embeddings`` column
        replaces the object vector. Every other column is written as a
        property of the object itself.

        Raises:
            Exception: when metadata is updated for an id that does not exist.
        """
        table_name, metadata_table_name = self._get_class_names(table_name)
        metadata_prefix = TableField.METADATA.value
        metadata_id_query = f"associatedMetadata {{ ... on {metadata_table_name} {{ _additional {{ id }} }} }}"

        for row in data.to_dict("records"):
            object_id = row["id"]
            object_fields = {}
            metadata_fields = {}
            vector = None
            for key, value in row.items():
                if not key or key == "id":
                    # the id identifies the object; it is not a writable property
                    continue
                if key == metadata_prefix:
                    if isinstance(value, dict):
                        metadata_fields.update(value)
                elif key.startswith(metadata_prefix + "."):
                    metadata_fields[key.split(".")[1]] = value
                elif key == TableField.EMBEDDINGS.value:
                    vector = ast.literal_eval(value) if isinstance(value, str) else value
                else:
                    object_fields[key] = value

            # updating table
            if object_fields or vector is not None:
                vector_kwargs = {"vector": vector} if vector is not None else {}
                self._client.data_object.update(
                    uuid=object_id,
                    class_name=table_name,
                    data_object=object_fields,
                    **vector_kwargs,
                )

            if not metadata_fields:
                continue

            # updating metadata: find the metadata object linked to this row
            response = (
                self._client.query.get(table_name, metadata_id_query)
                .with_additional(["id"])
                .with_where({"path": ["id"], "operator": "Equal", "valueText": object_id})
                .do()
            )
            objects = self._extract_objects(response, table_name)
            if not objects:
                raise Exception(f"Object {object_id} does not exist in {table_name}!")

            metadata_ref = self._first_reference(objects[0])
            if metadata_ref:
                self._client.data_object.update(
                    uuid=metadata_ref["_additional"]["id"],
                    class_name=metadata_table_name,
                    data_object=metadata_fields,
                )
            else:
                # the object was inserted without metadata: create and link it now
                metadata_id = self.add_metadata(metadata_fields, table_name)
                self._client.data_object.reference.add(
                    from_uuid=object_id,
                    from_property_name="associatedMetadata",
                    to_uuid=metadata_id,
                )

    def delete(
        self, table_name: str, conditions: List[FilterCondition] = None
    ):
        """Delete the objects matching ``conditions`` together with their metadata.

        Vector-search conditions are ignored.

        Raises:
            Exception: if no column/metadata condition is given, to avoid
                wiping the whole table.
        """
        table_name, metadata_table_name = self._get_class_names(table_name)
        column_conditions, metadata_conditions, _ = self._split_conditions(conditions)
        filters = self._translate_condition(
            table_name, column_conditions or None, metadata_conditions or None
        )
        if not filters:
            raise Exception("Delete query must have at least one condition!")

        # query to get object ids and their metadata ids
        # NOTE(review): this Get query is subject to weaviate's default result
        # limit; confirm that every matching object is removed for large deletes.
        metadata_query = f"associatedMetadata {{ ... on {metadata_table_name} {{ _additional {{ id }} }} }}"
        response = (
            self._client.query.get(table_name, metadata_query)
            .with_additional(["id"])
            .with_where(filters)
            .do()
        )
        objects = self._extract_objects(response, table_name)

        table_ids = [obj["_additional"]["id"] for obj in objects]
        metadata_ids = [
            reference["_additional"]["id"]
            for reference in map(self._first_reference, objects)
            if reference
        ]

        for class_name, ids in (
            (table_name, table_ids),
            (metadata_table_name, metadata_ids),
        ):
            if not ids:
                # nothing to delete in this class
                continue
            self._client.batch.delete_objects(
                class_name=class_name,
                where={
                    "path": ["id"],
                    "operator": "ContainsAny",
                    "valueTextArray": ids,
                },
            )

    # ------------------------------------------------------------------ #
    # schema management
    # ------------------------------------------------------------------ #

    def create_table(self, table_name: str, if_not_exists=True):
        """
        Create a class with the given name in the weaviate database.

        A companion ``<Table>_metadata`` class is created as well, and the
        ``associatedMetadata`` reference is added when the table class is new.
        An already existing table is left untouched (``if_not_exists`` is not
        consulted).
        """
        # separate metadata table for each table (as different tables will have different metadata columns)
        # this reduces the query time using metadata but increases the insertion time
        table_name, metadata_table_name = self._get_class_names(table_name)
        if not self._client.schema.exists(metadata_table_name):
            self._client.schema.create_class({"class": metadata_table_name})
        if self._client.schema.exists(table_name):
            return

        excluded = {"id", "embeddings", "metadata"}
        self._client.schema.create_class(
            {
                "class": table_name,
                "properties": [
                    {"dataType": ["text"], "name": prop["name"]}
                    for prop in self.SCHEMA
                    if prop["name"] not in excluded
                ],
                "vectorIndexType": "hnsw",
            }
        )
        # link content objects to their metadata objects
        self._client.schema.property.create(
            table_name,
            {"name": "associatedMetadata", "dataType": [metadata_table_name]},
        )

    def drop_table(self, table_name: str, if_exists=True):
        """
        Delete a class (and its metadata class) from the weaviate database.

        Deleting a class also deletes every object stored in it.

        Raises:
            Exception: if the table does not exist and ``if_exists`` is False.
        """
        table_name, metadata_table_name = self._get_class_names(table_name)
        if not self._client.schema.exists(table_name) and not if_exists:
            raise Exception(f"Table {table_name} does not exist!")

        for class_name in (table_name, metadata_table_name):
            if self._client.schema.exists(class_name):
                self._client.schema.delete_class(class_name)

    def get_tables(self) -> HandlerResponse:
        """
        Get the list of tables (weaviate classes, including metadata classes).
        """
        schema = self._client.schema.get() or {}
        tables = [table["class"] for table in schema.get("classes") or []]
        return Response(
            resp_type=RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(columns=["table_name"], data=tables),
        )

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the properties of a table as ``COLUMN_NAME`` / ``DATA_TYPE`` rows."""
        table_name = table_name.capitalize()
        # check if table exists
        try:
            table = (
                self._client.schema.get(table_name)
                if self._client.schema.exists(table_name)
                else None
            )
        except ValueError:
            table = None
        if table is None:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Table {table_name} does not exist!",
            )

        data = pd.DataFrame(
            data=[
                {"COLUMN_NAME": column["name"], "DATA_TYPE": column["dataType"][0]}
                for column in table.get("properties") or []
            ],
            columns=["COLUMN_NAME", "DATA_TYPE"],
        )
        return Response(data_frame=data, resp_type=RESPONSE_TYPE.TABLE)

    @staticmethod
    def _get_metadata_data_type(value) -> str:
        """Pick the weaviate data type for a new metadata property."""
        # order matters: bool is a subclass of int
        if isinstance(value, bool):
            return "boolean"
        if isinstance(value, int):
            return "int"
        if isinstance(value, float):
            return "number"
        probe = value[0] if isinstance(value, (list, tuple)) and value else value
        if isinstance(probe, datetime):
            return "date"
        return "string"

    def add_metadata(self, data: dict, table_name: str):
        """Insert one metadata object, extending the metadata schema first.

        Args:
            data: metadata column -> value mapping.
            table_name: logical table name the metadata belongs to.

        Returns:
            The uuid of the created metadata object.
        """
        _, metadata_table_name = self._get_class_names(table_name)
        # getting existing metadata fields
        schema = self._client.schema.get(metadata_table_name) or {}
        existing_props = {prop["name"] for prop in schema.get("properties") or []}

        # as metadata columns are not fixed, at every entry, a check takes place for the columns
        for prop, value in data.items():
            if prop in existing_props:
                continue
            # when a new column is identified, it is added to the metadata table
            self._client.schema.property.create(
                metadata_table_name,
                {"name": prop, "dataType": [self._get_metadata_data_type(value)]},
            )
            existing_props.add(prop)

        # NOTE(review): datetime values are passed through unchanged; confirm the
        # client serialises them as the RFC3339 strings weaviate expects.
        return self._client.data_object.create(
            data_object=data, class_name=metadata_table_name
        )