Coverage for mindsdb / integrations / handlers / weaviate_handler / weaviate_handler.py: 0%
249 statements
coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import ast
2from datetime import datetime
3from typing import List, Optional
5import weaviate
6from weaviate.embedded import EmbeddedOptions
7import pandas as pd
9from mindsdb.integrations.libs.response import RESPONSE_TYPE
10from mindsdb.integrations.libs.response import HandlerResponse
11from mindsdb.integrations.libs.response import HandlerResponse as Response
12from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse
13from mindsdb.integrations.libs.vectordatabase_handler import (
14 FilterCondition,
15 FilterOperator,
16 TableField,
17 VectorStoreHandler,
18)
19from mindsdb.utilities import log
20from weaviate.util import generate_uuid5
logger = log.getLogger(__name__)


class WeaviateDBHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Weaviate statements.

    Every table is stored as a Weaviate class with a ``content`` property and an
    ``associatedMetadata`` cross-reference pointing to a companion
    ``<Class>_metadata`` class. The companion class's properties are created on
    the fly as new metadata keys are inserted.
    """

    name = "weaviate"

    # Result columns that are not stored as properties of the table's own class.
    _VIRTUAL_COLUMNS = ("id", "embeddings", "distance", "metadata")

    def __init__(self, name: str, **kwargs):
        super().__init__(name)

        # Set cleanup-relevant state first so __del__ is safe even when the
        # validation below raises.
        self._client = None
        self._embedded_options = None
        self.is_connected = False

        self._connection_data = kwargs.get("connection_data") or {}

        self._client_config = {
            "weaviate_url": self._connection_data.get("weaviate_url"),
            "weaviate_api_key": self._connection_data.get("weaviate_api_key"),
            "persistence_directory": self._connection_data.get("persistence_directory"),
        }

        if not (
            self._client_config.get("weaviate_url")
            or self._client_config.get("persistence_directory")
        ):
            raise Exception(
                "Either weaviate_url or persistence_directory is required for weaviate connection!"
            )

        self.connect()

    @staticmethod
    def _to_class_name(table_name: str) -> str:
        """Return the Weaviate class name used for *table_name*.

        Only the first letter is upper-cased. ``str.capitalize`` would also
        lower-case the rest (``myDocs`` -> ``Mydocs``), which would not match
        the class that ``create_table`` registers.
        """
        return table_name[:1].upper() + table_name[1:]

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None and for scalar NaN/NaT values produced by pandas."""
        return value is None or (pd.api.types.is_scalar(value) and pd.isna(value))

    @staticmethod
    def _extract_objects(response: dict, class_name: str) -> list:
        """Return the objects of *class_name* from a GraphQL ``Get`` response.

        Raises:
            Exception: when the response reports GraphQL errors.
        """
        if response.get("errors"):
            raise Exception(f"Weaviate query failed: {response['errors']}")
        return ((response.get("data") or {}).get("Get") or {}).get(class_name) or []

    @staticmethod
    def _split_conditions(conditions: Optional[List[FilterCondition]]):
        """Split *conditions* into ``(property, metadata, vector)`` condition lists.

        Metadata conditions are those whose column starts with ``metadata``.
        Vector conditions target the search-vector or embeddings column.
        Everything else filters on properties of the table's class.
        """
        vector_columns = (TableField.SEARCH_VECTOR.value, TableField.EMBEDDINGS.value)
        property_conditions, metadata_conditions, vector_conditions = [], [], []
        for condition in conditions or []:
            if condition.column.startswith(TableField.METADATA.value):
                metadata_conditions.append(condition)
            elif condition.column in vector_columns:
                vector_conditions.append(condition)
            else:
                property_conditions.append(condition)
        return property_conditions, metadata_conditions, vector_conditions

    def _get_client(self) -> weaviate.Client:
        """Build a weaviate client from ``self._client_config``.

        An embedded instance is used when ``persistence_directory`` is set.
        Otherwise an HTTP client is created for ``weaviate_url``, authenticated
        with ``weaviate_api_key`` when one is provided.
        """
        config = self._client_config
        if not (
            config
            and (config.get("weaviate_url") or config.get("persistence_directory"))
        ):
            raise Exception("Client config is not set! or missing parameters")

        # decide the client type to be used, either embedded (persistent) or http
        if config.get("persistence_directory"):
            self._embedded_options = EmbeddedOptions(
                persistence_data_path=config.get("persistence_directory")
            )
            return weaviate.Client(embedded_options=self._embedded_options)
        if config.get("weaviate_api_key"):
            return weaviate.Client(
                url=config["weaviate_url"],
                auth_client_secret=weaviate.AuthApiKey(api_key=config["weaviate_api_key"]),
            )
        return weaviate.Client(url=config["weaviate_url"])

    def _close_client(self):
        """Stop the embedded server (if any), close the client and reset state.

        Safe to call when no client was ever created.
        """
        client = getattr(self, "_client", None)
        embedded_options = getattr(self, "_embedded_options", None)
        try:
            if client is not None:
                if embedded_options:
                    client._connection.embedded_db.stop()
                client._connection.close()
        finally:
            self._client = None
            self._embedded_options = None
            self.is_connected = False

    def __del__(self):
        # Exceptions raised in __del__ are only printed by the interpreter, so
        # log them instead of letting them escape.
        try:
            self._close_client()
        except Exception as e:
            logger.warning(f"Error while closing weaviate client: {e}")

    def connect(self):
        """Connect to a weaviate database.

        Returns:
            The weaviate client, or None when creating it failed (the error is logged).
        """
        if self.is_connected:
            return self._client

        try:
            self._client = self._get_client()
        except Exception as e:
            logger.error(f"Error connecting to weaviate client, {e}!")
            self.is_connected = False
            return None
        self.is_connected = True
        return self._client

    def disconnect(self):
        """Close the database connection (no-op when not connected)."""
        if not self.is_connected:
            return
        self._close_client()

    def check_connection(self):
        """Check the connection to the Weaviate database.

        Connects first when needed. A handler that was not connected before the
        check is disconnected again after a successful check.

        Returns:
            StatusResponse whose ``success`` is True when the server reports it is live.
        """
        response_code = StatusResponse(False)
        need_to_close = not self.is_connected

        try:
            if not self.is_connected:
                self._client = self._get_client()
                self.is_connected = True
            if self._client.is_live():
                response_code.success = True
            else:
                response_code.error_message = "Weaviate server is not live!"
        except Exception as e:
            logger.error(f"Error connecting to weaviate , {e}!")
            response_code.error_message = str(e)
        finally:
            if response_code.success and need_to_close:
                self.disconnect()
            if not response_code.success and self.is_connected:
                self.is_connected = False

        return response_code

    @staticmethod
    def _get_weaviate_operator(operator: FilterOperator) -> str:
        """Map a FilterOperator to the Weaviate ``where`` operator name.

        Raises:
            Exception: for operators without a Weaviate equivalent.
        """
        mapping = {
            FilterOperator.EQUAL: "Equal",
            FilterOperator.NOT_EQUAL: "NotEqual",
            FilterOperator.LESS_THAN: "LessThan",
            FilterOperator.LESS_THAN_OR_EQUAL: "LessThanEqual",
            FilterOperator.GREATER_THAN: "GreaterThan",
            FilterOperator.GREATER_THAN_OR_EQUAL: "GreaterThanEqual",
            FilterOperator.IS_NULL: "IsNull",
            FilterOperator.LIKE: "Like",
        }

        if operator not in mapping:
            raise Exception(f"Operator {operator} is not supported by weaviate!")

        return mapping[operator]

    @staticmethod
    def _get_weaviate_value_type(value) -> str:
        """Return the ``where``-filter value key matching the Python type of *value*.

        Lists are typed by their first element and use the ``...List`` key.

        Raises:
            Exception: for empty lists or unsupported value types.
        """
        # https://github.com/weaviate/weaviate-python-client/blob/c760b1d59b2a222e770d53cc257b1bf993a0a592/weaviate/gql/filter.py#L18
        # Order matters: bool is a subclass of int, so it must be tested first.
        primitive_types = (
            (bool, "valueBoolean"),
            (int, "valueInt"),
            (float, "valueNumber"),
            (datetime, "valueDate"),
            (str, "valueText"),
        )
        is_list = isinstance(value, list)
        if is_list:
            if not value:
                raise Exception("Empty list is not supported")
            probe = value[0]
        else:
            probe = value

        for python_type, value_type in primitive_types:
            if isinstance(probe, python_type):
                return value_type + "List" if is_list else value_type

        raise Exception(f"Value type {type(value)} is not supported by weaviate!")

    def _translate_condition(
        self,
        table_name: str,
        conditions: List[FilterCondition] = None,
        meta_conditions: List[FilterCondition] = None,
    ) -> Optional[dict]:
        """
        Translate lists of FilterCondition objects into a Weaviate ``where`` dict.

        Conditions on metadata columns filter through the ``associatedMetadata``
        cross-reference into the table's metadata class. Several conditions are
        combined with ``And``. Returns None when there is no condition.

        E.g., for table ``docs``:
        meta_conditions=[
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.LESS_THAN,
                value="2020-01-01",
            ),
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.GREATER_THAN,
                value="2019-01-01",
            )
        ]
        -->
        {"operator": "And",
        "operands": [
            {
                "path": ["associatedMetadata", "Docs_metadata", "created_at"],
                "operator": "LessThan",
                "valueText": "2020-01-01",
            },
            {
                "path": ["associatedMetadata", "Docs_metadata", "created_at"],
                "operator": "GreaterThan",
                "valueText": "2019-01-01",
            },
        ]}
        """
        metadata_class = self._to_class_name(table_name) + "_metadata"

        def to_filter(path: list, condition: FilterCondition) -> dict:
            value_type = self._get_weaviate_value_type(condition.value)
            return {
                "path": path,
                "operator": self._get_weaviate_operator(condition.op),
                value_type: condition.value,
            }

        weaviate_conditions = [
            to_filter([condition.column], condition) for condition in conditions or []
        ]
        weaviate_conditions.extend(
            to_filter(
                ["associatedMetadata", metadata_class, condition.column.split(".")[-1]],
                condition,
            )
            for condition in meta_conditions or []
        )

        if not weaviate_conditions:
            return None
        if len(weaviate_conditions) == 1:
            return weaviate_conditions[0]
        return {"operator": "And", "operands": weaviate_conditions}

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ):
        """Read rows from a table, optionally as a similarity search.

        Args:
            table_name: table (Weaviate class) to read.
            columns: requested columns. When it names no real property, all
                standard columns are returned. Otherwise the requested
                properties plus ``id``, ``distance`` and ``metadata`` are
                returned, and ``embeddings`` only if explicitly requested.
                The list is not modified.
            conditions: filters. A condition on the search-vector/embeddings
                column becomes a ``nearVector`` search (only the first is used).
            offset: number of leading results to skip.
            limit: maximum number of results.

        Returns:
            pandas.DataFrame with one row per returned object.
        """
        class_name = self._to_class_name(table_name)
        metadata_class = class_name + "_metadata"

        requested_columns = list(columns) if columns else []
        # only real properties of the class can be requested in the GraphQL query
        property_columns = [c for c in requested_columns if c not in self._VIRTUAL_COLUMNS]

        property_conditions, metadata_conditions, vector_conditions = self._split_conditions(
            conditions
        )
        filters = self._translate_condition(
            table_name, property_conditions or None, metadata_conditions or None
        )

        # content is always fetched since it is part of the default result
        fetch_properties = list(dict.fromkeys(["content"] + property_columns))
        metadata_fields = [
            prop["name"]
            for prop in self._client.schema.get(metadata_class).get("properties") or []
        ]
        # an empty inline fragment is invalid GraphQL, so only ask for metadata
        # when the metadata class has properties
        if metadata_fields:
            fields = " ".join(metadata_fields)
            fetch_properties.append(
                f"associatedMetadata {{ ... on {metadata_class} {{ {fields} }} }}"
            )

        query = self._client.query.get(class_name, fetch_properties).with_additional(
            ["id vector distance"]
        )
        if vector_conditions:
            # assuming there would be only one vector based search per query
            vector = vector_conditions[0].value
            if isinstance(vector, str):
                vector = ast.literal_eval(vector)
            query = query.with_near_vector({"vector": vector})
        if filters:
            query = query.with_where(filters)
        if limit:
            query = query.with_limit(limit)
        if offset:
            query = query.with_offset(offset)

        objects = self._extract_objects(query.do(), class_name)
        additionals = [obj.get("_additional") or {} for obj in objects]

        payload = {
            TableField.ID.value: [extra.get("id") for extra in additionals],
            TableField.CONTENT.value: [obj.get("content") for obj in objects],
            # objects inserted without metadata have no reference
            TableField.METADATA.value: [
                (obj.get("associatedMetadata") or [None])[0] for obj in objects
            ],
            TableField.EMBEDDINGS.value: [extra.get("vector") for extra in additionals],
            # distances are null for non vector/embedding queries
            TableField.DISTANCE.value: [extra.get("distance") for extra in additionals],
        }
        for column in property_columns:
            if column not in payload:
                payload[column] = [obj.get(column) for obj in objects]

        if property_columns:
            selected = property_columns + [TableField.ID.value]
            if TableField.EMBEDDINGS.value in requested_columns:
                selected.append(TableField.EMBEDDINGS.value)
            selected += [TableField.DISTANCE.value, TableField.METADATA.value]
            payload = {column: payload[column] for column in dict.fromkeys(selected)}

        return pd.DataFrame(payload)

    def _link_metadata(self, class_name: str, object_id: str, metadata_id: str):
        """Point the object's ``associatedMetadata`` reference at *metadata_id*."""
        self._client.data_object.reference.add(
            from_uuid=object_id,
            from_property_name="associatedMetadata",
            to_uuid=metadata_id,
            from_class_name=class_name,
            to_class_name=class_name + "_metadata",
        )

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """
        Insert data into the Weaviate database.

        Each row becomes an object holding ``content`` and the row's vector.
        Rows without an id get a uuid5 derived from their content. A non-empty
        ``metadata`` dict is stored in the metadata class and linked through
        ``associatedMetadata``. The caller's DataFrame is not modified.
        """
        class_name = self._to_class_name(table_name)

        # drop columns whose values are all missing (returns a new frame)
        data = data.dropna(axis=1, how="all")

        # records are handled one by one as metadata columns vary per record
        for record in data.to_dict(orient="records"):
            content = record.get(TableField.CONTENT.value)
            data_object = {"content": None if self._is_missing(content) else content}

            record_id = record.get(TableField.ID.value)
            if self._is_missing(record_id):
                record_id = generate_uuid5(data_object)

            vector = record.get(TableField.EMBEDDINGS.value)
            if self._is_missing(vector):
                vector = None
            elif isinstance(vector, str):
                vector = ast.literal_eval(vector)

            obj_id = self._client.data_object.create(
                data_object=data_object,
                class_name=class_name,
                vector=vector,
                uuid=record_id,
            )

            metadata = record.get(TableField.METADATA.value)
            if not self._is_missing(metadata) and metadata:
                meta_id = self.add_metadata(metadata, table_name)
                self._link_metadata(class_name, obj_id, meta_id)

    def update(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """
        Update data in the weaviate database.

        Rows are matched by ``id``. A ``metadata`` dict column and
        ``metadata.<field>`` columns update the linked metadata object (one is
        created and linked when the object has none). ``embeddings`` replaces
        the object's vector. Every other column except ``id`` is written as a
        property of the object.

        Raises:
            Exception: when a row's id does not match any object.
        """
        class_name = self._to_class_name(table_name)
        metadata_class = class_name + "_metadata"
        metadata_prefix = TableField.METADATA.value + "."
        metadata_id_query = (
            f"associatedMetadata {{ ... on {metadata_class} {{ _additional {{ id }} }} }}"
        )

        for row in data.to_dict("records"):
            object_id = row[TableField.ID.value]
            object_fields, metadata_fields, vector = {}, {}, None
            for key, value in row.items():
                if not key or key == TableField.ID.value:
                    continue
                if key == TableField.METADATA.value:
                    if isinstance(value, dict):
                        metadata_fields.update(value)
                elif key.startswith(metadata_prefix):
                    metadata_fields[key[len(metadata_prefix):]] = value
                elif key == TableField.EMBEDDINGS.value:
                    if not self._is_missing(value):
                        vector = ast.literal_eval(value) if isinstance(value, str) else value
                else:
                    object_fields[key] = value

            response = (
                self._client.query.get(class_name, metadata_id_query)
                .with_additional(["id"])
                .with_where({"path": ["id"], "operator": "Equal", "valueText": object_id})
                .do()
            )
            matches = self._extract_objects(response, class_name)
            if not matches:
                raise Exception(f"Record with id {object_id} not found in table {class_name}!")

            if object_fields or vector is not None:
                self._client.data_object.update(
                    uuid=object_id,
                    class_name=class_name,
                    data_object=object_fields,
                    vector=vector,
                )

            if metadata_fields:
                references = matches[0].get("associatedMetadata") or []
                if references:
                    self._ensure_metadata_properties(metadata_class, metadata_fields)
                    self._client.data_object.update(
                        uuid=references[0]["_additional"]["id"],
                        class_name=metadata_class,
                        data_object=metadata_fields,
                    )
                else:
                    meta_id = self.add_metadata(metadata_fields, table_name)
                    self._link_metadata(class_name, object_id, meta_id)

    def _delete_by_ids(self, class_name: str, ids: List[str]):
        """Batch-delete the objects of *class_name* whose id is in *ids* (no-op when empty)."""
        if not ids:
            return
        self._client.batch.delete_objects(
            class_name=class_name,
            where={
                "path": ["id"],
                "operator": "ContainsAny",
                "valueTextArray": ids,
            },
        )

    def delete(
        self, table_name: str, conditions: List[FilterCondition] = None
    ):
        """Delete the objects matching *conditions* and their linked metadata objects.

        Raises:
            Exception: when no filterable condition is given.
        """
        class_name = self._to_class_name(table_name)
        metadata_class = class_name + "_metadata"

        property_conditions, metadata_conditions, _ = self._split_conditions(conditions)
        filters = self._translate_condition(
            table_name, property_conditions or None, metadata_conditions or None
        )
        if not filters:
            raise Exception("Delete query must have at least one condition!")

        # query to get object ids together with their metadata ids
        metadata_query = (
            f"associatedMetadata {{ ... on {metadata_class} {{ _additional {{ id }} }} }}"
        )
        response = (
            self._client.query.get(class_name, metadata_query)
            .with_additional(["id"])
            .with_where(filters)
            .do()
        )
        # NOTE(review): a Get query without with_limit is capped by the server's
        # default result limit, so very large deletes may need several calls — confirm.
        matches = self._extract_objects(response, class_name)

        table_ids = []
        metadata_ids = []
        for match in matches:
            table_ids.append(match["_additional"]["id"])
            references = match.get("associatedMetadata") or []
            if references:
                metadata_ids.append(references[0]["_additional"]["id"])

        self._delete_by_ids(class_name, table_ids)
        self._delete_by_ids(metadata_class, metadata_ids)

    def create_table(self, table_name: str, if_not_exists=True):
        """
        Create a class with the given name in the weaviate database.

        A companion ``<Class>_metadata`` class is created too (when missing),
        referenced through the ``associatedMetadata`` property.

        Raises:
            Exception: when the table already exists and *if_not_exists* is False.
        """
        class_name = self._to_class_name(table_name)
        # separate metadata table for each table (as different tables will have different metadata columns)
        # this reduces the query time using metadata but increases the insertion time
        metadata_class = class_name + "_metadata"

        table_exists = self._client.schema.exists(class_name)
        if table_exists and not if_not_exists:
            raise Exception(f"Table {table_name} already exists!")

        if not self._client.schema.exists(metadata_class):
            self._client.schema.create_class({"class": metadata_class})
        if table_exists:
            return

        self._client.schema.create_class(
            {
                "class": class_name,
                "properties": [
                    {"dataType": ["text"], "name": prop["name"]}
                    for prop in self.SCHEMA
                    if prop["name"] not in ("id", "embeddings", "metadata")
                ],
                "vectorIndexType": "hnsw",
            }
        )
        self._client.schema.property.create(
            class_name,
            {"name": "associatedMetadata", "dataType": [metadata_class]},
        )

    def drop_table(self, table_name: str, if_exists=True):
        """
        Delete a class and its metadata class from the weaviate database.

        Deleting a Weaviate class also deletes all of its objects.

        Raises:
            Exception: when the table does not exist and *if_exists* is False.
        """
        class_name = self._to_class_name(table_name)
        metadata_class = class_name + "_metadata"

        if self._client.schema.exists(class_name):
            self._client.schema.delete_class(class_name)
        elif not if_exists:
            raise Exception(f"Table {class_name} does not exist!")

        # removed after the class that references it
        if self._client.schema.exists(metadata_class):
            self._client.schema.delete_class(metadata_class)

    def get_tables(self) -> HandlerResponse:
        """
        Get the list of tables in the Weaviate database.

        The internal ``<Class>_metadata`` companion classes are not listed.
        """
        schema = self._client.schema.get() or {}
        class_names = [table["class"] for table in schema.get("classes") or []]
        existing = set(class_names)
        suffix = "_metadata"
        tables = [
            name
            for name in class_names
            if not (name.endswith(suffix) and name[: -len(suffix)] in existing)
        ]
        table_name = pd.DataFrame(
            columns=["table_name"],
            data=tables,
        )
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=table_name)

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the properties of the table's class as COLUMN_NAME / DATA_TYPE rows."""
        class_name = self._to_class_name(table_name)
        # check if table exists
        if not self._client.schema.exists(class_name):
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Table {class_name} does not exist!",
            )
        table = self._client.schema.get(class_name)
        data = pd.DataFrame(
            data=[
                {"COLUMN_NAME": column["name"], "DATA_TYPE": column["dataType"][0]}
                for column in table.get("properties") or []
            ],
            columns=["COLUMN_NAME", "DATA_TYPE"],
        )
        return Response(data_frame=data, resp_type=RESPONSE_TYPE.OK)

    @staticmethod
    def _get_weaviate_data_type(value) -> str:
        """Map a metadata value to a Weaviate property data type.

        Lists map to the array form of their first element's type (``string[]``
        when empty). Anything unrecognised is stored as ``string``.
        """

        def scalar_type(item) -> str:
            # bool must be tested before int (bool is a subclass of int)
            if isinstance(item, bool):
                return "boolean"
            if isinstance(item, int):
                return "int"
            if isinstance(item, float):
                return "number"
            if isinstance(item, datetime):
                # NOTE(review): Weaviate expects RFC3339 dates — confirm the
                # client serialises datetime objects when the object is created.
                return "date"
            return "string"

        if isinstance(value, list):
            return scalar_type(value[0]) + "[]" if value else "string[]"
        return scalar_type(value)

    def _ensure_metadata_properties(self, metadata_class: str, data: dict):
        """Create on *metadata_class* every property of *data* missing from its schema."""
        existing = {
            prop["name"]
            for prop in self._client.schema.get(metadata_class).get("properties") or []
        }
        # as metadata columns are not fixed, the schema grows as new keys show up
        for prop_name, value in data.items():
            if prop_name in existing:
                continue
            self._client.schema.property.create(
                metadata_class,
                {"name": prop_name, "dataType": [self._get_weaviate_data_type(value)]},
            )
            existing.add(prop_name)

    def add_metadata(self, data: dict, table_name: str):
        """Store *data* as a new object in the table's metadata class.

        Missing properties are added to the metadata class schema first.

        Returns:
            The uuid of the created metadata object.
        """
        metadata_class = self._to_class_name(table_name) + "_metadata"
        self._ensure_metadata_properties(metadata_class, data)
        return self._client.data_object.create(data_object=data, class_name=metadata_class)