Coverage for mindsdb / integrations / libs / vectordatabase_handler.py: 23%
258 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import ast
2import copy
3import hashlib
4from enum import Enum
5from typing import Dict, List, Optional
6import datetime as dt
8import pandas as pd
9from mindsdb_sql_parser.ast import (
10 BinaryOperation,
11 Constant,
12 CreateTable,
13 Delete,
14 DropTables,
15 Insert,
16 Select,
17 Star,
18 Tuple,
19 Update,
20)
21from mindsdb_sql_parser.ast.base import ASTNode
23from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse
24from mindsdb.utilities import log
25from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs
27from mindsdb.integrations.utilities.query_traversal import query_traversal
28from .base import BaseHandler
30LOG = log.getLogger(__name__)
33class VectorHandlerException(Exception): ...
36class TableField(Enum):
37 """
38 Enum for table fields.
39 """
41 ID = "id"
42 CONTENT = "content"
43 EMBEDDINGS = "embeddings"
44 METADATA = "metadata"
45 SEARCH_VECTOR = "search_vector"
46 DISTANCE = "distance"
47 RELEVANCE = "relevance"
50class DistanceFunction(Enum):
51 SQUARED_EUCLIDEAN_DISTANCE = ("<->",)
52 NEGATIVE_DOT_PRODUCT = ("<#>",)
53 COSINE_DISTANCE = "<=>"
56class VectorStoreHandler(BaseHandler):
57 """
58 Base class for handlers associated to vector databases.
59 """
61 SCHEMA = [
62 {
63 "name": TableField.ID.value,
64 "data_type": "string",
65 },
66 {
67 "name": TableField.CONTENT.value,
68 "data_type": "string",
69 },
70 {
71 "name": TableField.EMBEDDINGS.value,
72 "data_type": "list",
73 },
74 {
75 "name": TableField.METADATA.value,
76 "data_type": "json",
77 },
78 {
79 "name": TableField.DISTANCE.value,
80 "data_type": "float",
81 },
82 ]
84 def validate_connection_parameters(self, name, **kwargs):
85 """Create validation for input parameters."""
87 return NotImplementedError()
89 def __del__(self):
90 if self.is_connected is True:
91 self.disconnect()
93 def disconnect(self):
94 pass
96 def _value_or_self(self, value):
97 if isinstance(value, Constant):
98 return value.value
99 else:
100 return value
102 def extract_conditions(self, where_statement) -> Optional[List[FilterCondition]]:
103 conditions = []
104 # parse conditions
105 if where_statement is not None:
106 # dfs to get all binary operators in the where statement
107 def _extract_comparison_conditions(node, **kwargs):
108 if isinstance(node, BinaryOperation):
109 # if the op is and, continue
110 # TODO: need to handle the OR case
111 if node.op.upper() == "AND":
112 return
113 op = FilterOperator(node.op.upper())
114 # unquote the left hand side
115 left_hand = node.args[0].parts[-1].strip("`")
116 if isinstance(node.args[1], Constant):
117 if left_hand == TableField.SEARCH_VECTOR.value:
118 right_hand = ast.literal_eval(node.args[1].value)
119 else:
120 right_hand = node.args[1].value
121 elif isinstance(node.args[1], Tuple):
122 # Constant could be actually a list i.e. [1.2, 3.2]
123 right_hand = [item.value for item in node.args[1].items]
124 else:
125 raise Exception(f"Unsupported right hand side: {node.args[1]}")
126 conditions.append(FilterCondition(column=left_hand, op=op, value=right_hand))
128 query_traversal(where_statement, _extract_comparison_conditions)
130 else:
131 conditions = None
133 return conditions
135 def _convert_metadata_filters(self, conditions, allowed_metadata_columns=None):
136 if conditions is None: 136 ↛ 137line 136 didn't jump to line 137 because the condition on line 136 was never true
137 return
138 # try to treat conditions that are not in TableField as metadata conditions
139 for condition in conditions: 139 ↛ 140line 139 didn't jump to line 140 because the loop on line 139 never started
140 if self._is_metadata_condition(condition):
141 # check restriction
142 if allowed_metadata_columns is not None:
143 # system columns are underscored, skip them
144 if condition.column.lower() not in allowed_metadata_columns and not condition.column.startswith(
145 "_"
146 ):
147 raise ValueError(f"Column is not found: {condition.column}")
149 # convert if required
150 if not condition.column.startswith(TableField.METADATA.value):
151 condition.column = TableField.METADATA.value + "." + condition.column
153 def _is_columns_allowed(self, columns: List[str]) -> bool:
154 """
155 Check if columns are allowed.
156 """
157 allowed_columns = set([col["name"] for col in self.SCHEMA])
158 return set(columns).issubset(allowed_columns)
160 def _is_metadata_condition(self, condition: FilterCondition) -> bool:
161 allowed_field_values = set([field.value for field in TableField])
162 if condition.column in allowed_field_values:
163 return False
164 return True
166 def _dispatch_create_table(self, query: CreateTable):
167 """
168 Dispatch create table query to the appropriate method.
169 """
170 # parse key arguments
171 table_name = query.name.parts[-1]
172 if_not_exists = getattr(query, "if_not_exists", False)
173 return self.create_table(table_name, if_not_exists=if_not_exists)
175 def _dispatch_drop_table(self, query: DropTables):
176 """
177 Dispatch drop table query to the appropriate method.
178 """
179 table_name = query.tables[0].parts[-1]
180 if_exists = getattr(query, "if_exists", False)
182 return self.drop_table(table_name, if_exists=if_exists)
184 def _dispatch_insert(self, query: Insert):
185 """
186 Dispatch insert query to the appropriate method.
187 """
188 # parse key arguments
189 table_name = query.table.parts[-1]
190 columns = [column.name for column in query.columns]
192 if not self._is_columns_allowed(columns):
193 raise Exception(f"Columns {columns} not allowed.Allowed columns are {[col['name'] for col in self.SCHEMA]}")
195 # get content column if it is present
196 if TableField.CONTENT.value in columns:
197 content_col_index = columns.index("content")
198 content = [self._value_or_self(row[content_col_index]) for row in query.values]
199 else:
200 content = None
202 # get id column if it is present
203 ids = None
204 if TableField.ID.value in columns:
205 id_col_index = columns.index("id")
206 ids = [self._value_or_self(row[id_col_index]) for row in query.values]
207 elif TableField.CONTENT.value is None:
208 raise Exception("Content or id is required!")
210 # get embeddings column if it is present
211 if TableField.EMBEDDINGS.value in columns:
212 embeddings_col_index = columns.index("embeddings")
213 embeddings = [ast.literal_eval(self._value_or_self(row[embeddings_col_index])) for row in query.values]
214 else:
215 raise Exception("Embeddings column is required!")
217 if TableField.METADATA.value in columns:
218 metadata_col_index = columns.index("metadata")
219 metadata = [ast.literal_eval(self._value_or_self(row[metadata_col_index])) for row in query.values]
220 else:
221 metadata = None
223 # create dataframe
224 data = {
225 TableField.CONTENT.value: content,
226 TableField.EMBEDDINGS.value: embeddings,
227 TableField.METADATA.value: metadata,
228 }
229 if ids is not None:
230 data[TableField.ID.value] = ids
232 return self.do_upsert(table_name, pd.DataFrame(data))
234 def dispatch_update(self, query: Update, conditions: List[FilterCondition] = None):
235 """
236 Dispatch update query to the appropriate method.
237 """
238 table_name = query.table.parts[-1]
240 row = {}
241 for k, v in query.update_columns.items():
242 k = k.lower()
243 if isinstance(v, Constant):
244 v = v.value
245 if k == TableField.EMBEDDINGS.value and isinstance(v, str):
246 # it could be embeddings in string
247 try:
248 v = ast.literal_eval(v)
249 except Exception:
250 pass
251 row[k] = v
253 if conditions is None:
254 where_statement = query.where
255 conditions = self.extract_conditions(where_statement)
257 for condition in conditions:
258 if condition.op != FilterOperator.EQUAL:
259 raise NotImplementedError
261 row[condition.column] = condition.value
263 # checks
264 if TableField.EMBEDDINGS.value not in row:
265 raise Exception("Embeddings column is required!")
267 if TableField.CONTENT.value not in row:
268 raise Exception("Content is required!")
270 # store
271 df = pd.DataFrame([row])
273 return self.do_upsert(table_name, df)
275 def set_metadata_cur_time(self, df, col_name):
276 metadata_col = TableField.METADATA.value
277 cur_date = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
279 def set_time(meta):
280 meta[col_name] = cur_date
282 df[metadata_col].apply(set_time)
284 def do_upsert(self, table_name, df):
285 """Upsert data into table, handling document updates and deletions.
287 Args:
288 table_name (str): Name of the table
289 df (pd.DataFrame): DataFrame containing the data to upsert
291 The function handles three cases:
292 1. New documents: Insert them
293 2. Updated documents: Delete old chunks and insert new ones
294 """
295 id_col = TableField.ID.value
296 metadata_col = TableField.METADATA.value
297 content_col = TableField.CONTENT.value
299 def gen_hash(v):
300 return hashlib.md5(str(v).encode()).hexdigest()
302 if id_col not in df.columns:
303 # generate for all
304 df[id_col] = df[content_col].apply(gen_hash)
305 else:
306 # generate for empty
307 for i in range(len(df)):
308 if pd.isna(df.loc[i, id_col]):
309 df.loc[i, id_col] = gen_hash(df.loc[i, content_col])
311 # remove duplicated ids
312 df = df.drop_duplicates([TableField.ID.value])
314 # id is string TODO is it ok?
315 df[id_col] = df[id_col].apply(str)
317 # set updated_at
318 self.set_metadata_cur_time(df, "_updated_at")
320 if hasattr(self, "upsert"):
321 self.upsert(table_name, df)
322 return
324 # find existing ids
325 df_existed = self.select(
326 table_name,
327 columns=[id_col, metadata_col],
328 conditions=[FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))],
329 )
330 existed_ids = list(df_existed[id_col])
332 # update existed
333 df_update = df[df[id_col].isin(existed_ids)]
334 df_insert = df[~df[id_col].isin(existed_ids)]
336 if not df_update.empty:
337 # get values of existed `created_at` and return them to metadata
338 origin_id_col = "_original_doc_id"
340 created_dates, ids = {}, {}
341 for _, row in df_existed.iterrows():
342 chunk_id = row[id_col]
343 created_dates[chunk_id] = row[metadata_col].get("_created_at")
344 ids[chunk_id] = row[metadata_col].get(origin_id_col)
346 def keep_created_at(row):
347 val = created_dates.get(row[id_col])
348 if val:
349 row[metadata_col]["_created_at"] = val
350 # keep id column
351 if origin_id_col not in row[metadata_col]:
352 row[metadata_col][origin_id_col] = ids.get(row[id_col])
353 return row
355 df_update.apply(keep_created_at, axis=1)
357 try:
358 self.update(table_name, df_update, [id_col])
359 except NotImplementedError:
360 # not implemented? do it with delete and insert
361 conditions = [FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))]
362 self.delete(table_name, conditions)
363 self.insert(table_name, df_update)
364 if not df_insert.empty:
365 # set created_at
366 self.set_metadata_cur_time(df_insert, "_created_at")
368 self.insert(table_name, df_insert)
370 def dispatch_delete(self, query: Delete, conditions: List[FilterCondition] = None):
371 """
372 Dispatch delete query to the appropriate method.
373 """
374 # parse key arguments
375 table_name = query.table.parts[-1]
376 if conditions is None:
377 where_statement = query.where
378 conditions = self.extract_conditions(where_statement)
379 self._convert_metadata_filters(conditions)
381 # dispatch delete
382 return self.delete(table_name, conditions=conditions)
384 def dispatch_select(
385 self,
386 query: Select,
387 conditions: Optional[List[FilterCondition]] = None,
388 allowed_metadata_columns: List[str] = None,
389 keyword_search_args: Optional[KeywordSearchArgs] = None,
390 ):
391 """
392 Dispatches a select query to the appropriate method, handling both
393 standard selections and keyword searches based on the provided arguments.
394 """
395 # 1. Parse common query arguments
396 table_name = query.from_table.parts[-1]
398 # If targets are a star (*), select all schema columns
399 if isinstance(query.targets[0], Star): 399 ↛ 400line 399 didn't jump to line 400 because the condition on line 399 was never true
400 columns = [col["name"] for col in self.SCHEMA]
401 else:
402 columns = [col.parts[-1] for col in query.targets]
404 # 2. Validate columns
405 if not self._is_columns_allowed(columns): 405 ↛ 406line 405 didn't jump to line 406 because the condition on line 405 was never true
406 allowed_cols = [col["name"] for col in self.SCHEMA]
407 raise Exception(f"Columns {columns} not allowed. Allowed columns are {allowed_cols}")
409 # 3. Extract and process conditions
410 if conditions is None: 410 ↛ 411line 410 didn't jump to line 411 because the condition on line 410 was never true
411 where_statement = query.where
412 conditions = self.extract_conditions(where_statement)
413 else:
414 # it is mutated
415 conditions = copy.deepcopy(conditions)
416 self._convert_metadata_filters(conditions, allowed_metadata_columns=allowed_metadata_columns)
418 # 4. Get offset and limit
419 offset = query.offset.value if query.offset is not None else None
420 limit = query.limit.value if query.limit is not None else None
422 # 5. Conditionally dispatch to the correct select method
423 if keyword_search_args: 423 ↛ 425line 423 didn't jump to line 425 because the condition on line 423 was never true
424 # It's a keyword search
425 return self.keyword_select(
426 table_name,
427 columns=columns,
428 conditions=conditions,
429 offset=offset,
430 limit=limit,
431 keyword_search_args=keyword_search_args,
432 )
433 else:
434 # It's a standard select
435 try:
436 return self.select(
437 table_name,
438 columns=columns,
439 conditions=conditions,
440 offset=offset,
441 limit=limit,
442 )
444 except Exception as e:
445 handler_engine = self.__class__.name
446 raise VectorHandlerException(f"Error in {handler_engine} database: {e}")
448 def _dispatch(self, query: ASTNode) -> HandlerResponse:
449 """
450 Parse and Dispatch query to the appropriate method.
451 """
452 dispatch_router = {
453 CreateTable: self._dispatch_create_table,
454 DropTables: self._dispatch_drop_table,
455 Insert: self._dispatch_insert,
456 Update: self.dispatch_update,
457 Delete: self.dispatch_delete,
458 Select: self.dispatch_select,
459 }
460 if type(query) in dispatch_router:
461 resp = dispatch_router[type(query)](query)
462 if resp is not None:
463 return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=resp)
464 else:
465 return HandlerResponse(resp_type=RESPONSE_TYPE.OK)
467 else:
468 raise NotImplementedError(f"Query type {type(query)} not implemented.")
470 def query(self, query: ASTNode) -> HandlerResponse:
471 """
472 Receive query as AST (abstract syntax tree) and act upon it somehow.
474 Args:
475 query (ASTNode): sql query represented as AST. May be any kind
476 of query: SELECT, INSERT, DELETE, etc
478 Returns:
479 HandlerResponse
480 """
481 return self._dispatch(query)
483 def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse:
484 """Create table
486 Args:
487 table_name (str): table name
488 if_not_exists (bool): if True, do nothing if table exists
490 Returns:
491 HandlerResponse
492 """
493 raise NotImplementedError()
495 def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse:
496 """Drop table
498 Args:
499 table_name (str): table name
500 if_exists (bool): if True, do nothing if table does not exist
502 Returns:
503 HandlerResponse
504 """
505 raise NotImplementedError()
507 def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse:
508 """Insert data into table
510 Args:
511 table_name (str): table name
512 data (pd.DataFrame): data to insert
513 columns (List[str]): columns to insert
515 Returns:
516 HandlerResponse
517 """
518 raise NotImplementedError()
520 def update(self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None):
521 """Update data in table
523 Args:
524 table_name (str): table name
525 data (pd.DataFrame): data to update
526 key_columns (List[str]): key to to update
528 Returns:
529 HandlerResponse
530 """
531 raise NotImplementedError()
533 def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> HandlerResponse:
534 """Delete data from table
536 Args:
537 table_name (str): table name
538 conditions (List[FilterCondition]): conditions to delete
540 Returns:
541 HandlerResponse
542 """
543 raise NotImplementedError()
545 def select(
546 self,
547 table_name: str,
548 columns: List[str] = None,
549 conditions: List[FilterCondition] = None,
550 offset: int = None,
551 limit: int = None,
552 ) -> pd.DataFrame:
553 """Select data from table
555 Args:
556 table_name (str): table name
557 columns (List[str]): columns to select
558 conditions (List[FilterCondition]): conditions to select
560 Returns:
561 HandlerResponse
562 """
563 raise NotImplementedError()
565 def get_columns(self, table_name: str) -> HandlerResponse:
566 # return a fixed set of columns
567 data = pd.DataFrame(self.SCHEMA)
568 data.columns = ["COLUMN_NAME", "DATA_TYPE"]
569 return HandlerResponse(
570 resp_type=RESPONSE_TYPE.TABLE,
571 data_frame=data,
572 )
574 def hybrid_search(
575 self,
576 table_name: str,
577 embeddings: List[float],
578 query: str = None,
579 metadata: Dict[str, str] = None,
580 distance_function=DistanceFunction.COSINE_DISTANCE,
581 **kwargs,
582 ) -> pd.DataFrame:
583 """
584 Executes a hybrid search, combining semantic search and one or both of keyword/metadata search.
586 For insight on the query construction, see: https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.
588 Args:
589 table_name(str): Name of underlying table containing content, embeddings, & metadata
590 embeddings(List[float]): Embedding vector to perform semantic search against
591 query(str): User query to convert into keywords for keyword search
592 metadata(Dict[str, str]): Metadata filters to filter content rows against
593 distance_function(DistanceFunction): Distance function used to compare embeddings vectors for semantic search
595 Returns:
596 df(pd.DataFrame): Hybrid search result, sorted by hybrid search rank
597 """
598 raise NotImplementedError(f"Hybrid search not supported for VectorStoreHandler {self.name}")
600 def check_existing_ids(self, table_name: str, ids: List[str]) -> List[str]:
601 """
602 Check which IDs from the provided list already exist in the table.
604 Args:
605 table_name (str): Name of the table to check
606 ids (List[str]): List of IDs to check for existence
608 Returns:
609 List[str]: List of IDs that already exist in the table
610 """
611 if not ids:
612 return []
614 try:
615 # Query existing IDs
616 df_existing = self.select(
617 table_name,
618 columns=[TableField.ID.value],
619 conditions=[FilterCondition(column=TableField.ID.value, op=FilterOperator.IN, value=ids)],
620 )
621 return list(df_existing[TableField.ID.value]) if not df_existing.empty else []
622 except Exception:
623 # If select fails for any reason, return empty list to be safe
624 return []
626 def create_index(self, *args, **kwargs):
627 """
628 Create an index on the specified table.
629 """
630 raise NotImplementedError(f"create_index not supported for VectorStoreHandler {self.name}")