Coverage for mindsdb / integrations / libs / vectordatabase_handler.py: 23%

258 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2import copy 

3import hashlib 

4from enum import Enum 

5from typing import Dict, List, Optional 

6import datetime as dt 

7 

8import pandas as pd 

9from mindsdb_sql_parser.ast import ( 

10 BinaryOperation, 

11 Constant, 

12 CreateTable, 

13 Delete, 

14 DropTables, 

15 Insert, 

16 Select, 

17 Star, 

18 Tuple, 

19 Update, 

20) 

21from mindsdb_sql_parser.ast.base import ASTNode 

22 

23from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse 

24from mindsdb.utilities import log 

25from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs 

26 

27from mindsdb.integrations.utilities.query_traversal import query_traversal 

28from .base import BaseHandler 

29 

30LOG = log.getLogger(__name__) 

31 

32 

class VectorHandlerException(Exception):
    """Raised when an operation against the underlying vector database fails."""

34 

35 

class TableField(Enum):
    """Canonical column names exposed by every vector-store table."""

    ID = "id"                          # row/chunk identifier
    CONTENT = "content"                # raw text content
    EMBEDDINGS = "embeddings"          # embedding vector for the content
    METADATA = "metadata"              # free-form metadata document
    SEARCH_VECTOR = "search_vector"    # query vector used for similarity search
    DISTANCE = "distance"              # similarity distance in search results
    RELEVANCE = "relevance"            # relevance score in search results

48 

49 

class DistanceFunction(Enum):
    """Vector comparison operators (pgvector operator syntax).

    Bug fix: the first two members had trailing commas, which made their
    ``.value`` a 1-tuple (``("<->",)``) while COSINE_DISTANCE was a plain
    string — an inconsistency for any code that treats the values uniformly.
    All members are now plain strings.
    """

    SQUARED_EUCLIDEAN_DISTANCE = "<->"
    NEGATIVE_DOT_PRODUCT = "<#>"
    COSINE_DISTANCE = "<=>"

54 

55 

class VectorStoreHandler(BaseHandler):
    """
    Base class for handlers associated to vector databases.

    Translates SQL ASTs (SELECT/INSERT/UPDATE/DELETE/CREATE/DROP) into the
    storage primitives (``create_table``, ``insert``, ``select``, ...) that
    concrete subclasses implement.
    """

    # Fixed logical schema of every vector-store table: each row holds a chunk
    # of content, its embedding vector, a metadata document, and (in search
    # results) a distance score.
    SCHEMA = [
        {
            "name": TableField.ID.value,
            "data_type": "string",
        },
        {
            "name": TableField.CONTENT.value,
            "data_type": "string",
        },
        {
            "name": TableField.EMBEDDINGS.value,
            "data_type": "list",
        },
        {
            "name": TableField.METADATA.value,
            "data_type": "json",
        },
        {
            "name": TableField.DISTANCE.value,
            "data_type": "float",
        },
    ]

83 

def validate_connection_parameters(self, name, **kwargs):
    """Validate input connection parameters.

    Args:
        name: handler/connection name.
        **kwargs: connection arguments to validate.

    Raises:
        NotImplementedError: subclasses must provide their own validation.
    """
    # Bug fix: the original *returned* a NotImplementedError instance, so a
    # subclass forgetting to override this would silently appear to succeed.
    raise NotImplementedError()

88 

def __del__(self):
    """Best-effort cleanup: close the connection when the handler is collected."""
    # NOTE(review): explicit `is True` — presumably guards against
    # `is_connected` holding a truthy non-bool sentinel; confirm with
    # BaseHandler's contract before simplifying.
    if self.is_connected is True:
        self.disconnect()

92 

def disconnect(self):
    """Close the connection. No-op by default; subclasses override to release resources."""
    pass

95 

def _value_or_self(self, value):
    """Unwrap a parser ``Constant`` node; pass any other value through unchanged."""
    return value.value if isinstance(value, Constant) else value

101 

def extract_conditions(self, where_statement) -> Optional[List[FilterCondition]]:
    """Flatten a WHERE-clause AST into a list of FilterCondition.

    Only AND-combined binary comparisons are supported: AND nodes are skipped
    so their child comparisons are collected individually.

    Args:
        where_statement: WHERE clause AST node, or None.

    Returns:
        List of conditions, or None when there is no WHERE clause at all
        (distinct from an empty list).

    Raises:
        Exception: when a comparison's right-hand side is neither a Constant
            nor a Tuple.
    """
    conditions = []
    # parse conditions
    if where_statement is not None:
        # dfs to get all binary operators in the where statement
        def _extract_comparison_conditions(node, **kwargs):
            # visitor called by query_traversal for every node in the tree
            if isinstance(node, BinaryOperation):
                # if the op is and, continue
                # TODO: need to handle the OR case
                if node.op.upper() == "AND":
                    return
                op = FilterOperator(node.op.upper())
                # unquote the left hand side
                left_hand = node.args[0].parts[-1].strip("`")
                if isinstance(node.args[1], Constant):
                    if left_hand == TableField.SEARCH_VECTOR.value:
                        # search vectors arrive as a stringified list, e.g. "[0.1, 0.2]"
                        right_hand = ast.literal_eval(node.args[1].value)
                    else:
                        right_hand = node.args[1].value
                elif isinstance(node.args[1], Tuple):
                    # Constant could be actually a list i.e. [1.2, 3.2]
                    right_hand = [item.value for item in node.args[1].items]
                else:
                    raise Exception(f"Unsupported right hand side: {node.args[1]}")
                conditions.append(FilterCondition(column=left_hand, op=op, value=right_hand))

        query_traversal(where_statement, _extract_comparison_conditions)

    else:
        conditions = None

    return conditions

134 

def _convert_metadata_filters(self, conditions, allowed_metadata_columns=None):
    """Rewrite non-schema filter columns as ``metadata.<column>`` lookups.

    Mutates ``conditions`` in place. When ``allowed_metadata_columns`` is
    given, filtering on a metadata column outside it raises ValueError;
    columns starting with "_" are treated as system columns and always pass.

    NOTE(review): the allow-list check compares ``condition.column.lower()``,
    so ``allowed_metadata_columns`` is assumed to hold lowercase names —
    confirm against callers.
    """
    if conditions is None:
        return
    # try to treat conditions that are not in TableField as metadata conditions
    for condition in conditions:
        if self._is_metadata_condition(condition):
            # check restriction
            if allowed_metadata_columns is not None:
                # system columns are underscored, skip them
                if condition.column.lower() not in allowed_metadata_columns and not condition.column.startswith(
                    "_"
                ):
                    raise ValueError(f"Column is not found: {condition.column}")

            # convert if required
            if not condition.column.startswith(TableField.METADATA.value):
                condition.column = TableField.METADATA.value + "." + condition.column

152 

def _is_columns_allowed(self, columns: List[str]) -> bool:
    """Return True when every requested column is declared in SCHEMA."""
    permitted = {entry["name"] for entry in self.SCHEMA}
    return all(column in permitted for column in columns)

159 

def _is_metadata_condition(self, condition: FilterCondition) -> bool:
    """Return True when the condition targets a metadata (non-schema) column."""
    schema_fields = {field.value for field in TableField}
    return condition.column not in schema_fields

165 

def _dispatch_create_table(self, query: CreateTable):
    """Route a CREATE TABLE statement to ``create_table``."""
    return self.create_table(
        query.name.parts[-1],
        if_not_exists=getattr(query, "if_not_exists", False),
    )

174 

def _dispatch_drop_table(self, query: DropTables):
    """Route a DROP TABLE statement to ``drop_table`` (first listed table only)."""
    return self.drop_table(
        query.tables[0].parts[-1],
        if_exists=getattr(query, "if_exists", False),
    )

183 

def _dispatch_insert(self, query: Insert):
    """Route an INSERT statement to ``do_upsert``.

    Parses the VALUES rows into content / id / embeddings / metadata columns.
    Embeddings are mandatory; either content or id must be present.

    Raises:
        Exception: on disallowed columns or missing required columns.
    """
    # parse key arguments
    table_name = query.table.parts[-1]
    columns = [column.name for column in query.columns]

    if not self._is_columns_allowed(columns):
        # message fix: add the missing space after the period
        raise Exception(
            f"Columns {columns} not allowed. Allowed columns are {[col['name'] for col in self.SCHEMA]}"
        )

    # get content column if it is present
    if TableField.CONTENT.value in columns:
        content_col_index = columns.index("content")
        content = [self._value_or_self(row[content_col_index]) for row in query.values]
    else:
        content = None

    # get id column if it is present
    ids = None
    if TableField.ID.value in columns:
        id_col_index = columns.index("id")
        ids = [self._value_or_self(row[id_col_index]) for row in query.values]
    elif content is None:
        # Bug fix: the original tested `TableField.CONTENT.value is None`,
        # which compares the literal string "content" to None and is always
        # False, so this required-column check was dead code.
        raise Exception("Content or id is required!")

    # get embeddings column if it is present
    if TableField.EMBEDDINGS.value in columns:
        embeddings_col_index = columns.index("embeddings")
        # embeddings arrive as stringified lists, e.g. "[0.1, 0.2]"
        embeddings = [ast.literal_eval(self._value_or_self(row[embeddings_col_index])) for row in query.values]
    else:
        raise Exception("Embeddings column is required!")

    if TableField.METADATA.value in columns:
        metadata_col_index = columns.index("metadata")
        # metadata arrives as a stringified dict
        metadata = [ast.literal_eval(self._value_or_self(row[metadata_col_index])) for row in query.values]
    else:
        metadata = None

    # create dataframe
    data = {
        TableField.CONTENT.value: content,
        TableField.EMBEDDINGS.value: embeddings,
        TableField.METADATA.value: metadata,
    }
    if ids is not None:
        data[TableField.ID.value] = ids

    return self.do_upsert(table_name, pd.DataFrame(data))

233 

def dispatch_update(self, query: Update, conditions: List[FilterCondition] = None):
    """
    Dispatch update query to the appropriate method.

    Builds a single-row DataFrame from the SET clause plus the WHERE
    equality conditions and upserts it. Only equality conditions are
    supported.

    Raises:
        NotImplementedError: for any non-equality WHERE condition.
        Exception: when embeddings or content are missing from the row.
    """
    table_name = query.table.parts[-1]

    row = {}
    for k, v in query.update_columns.items():
        k = k.lower()
        if isinstance(v, Constant):
            v = v.value
        if k == TableField.EMBEDDINGS.value and isinstance(v, str):
            # it could be embeddings in string
            try:
                v = ast.literal_eval(v)
            except Exception:
                # keep the raw string when it is not a parseable literal
                pass
        row[k] = v

    if conditions is None:
        where_statement = query.where
        conditions = self.extract_conditions(where_statement)

    for condition in conditions:
        if condition.op != FilterOperator.EQUAL:
            raise NotImplementedError

        # equality filters (e.g. id = ...) become column values of the row
        row[condition.column] = condition.value

    # checks
    if TableField.EMBEDDINGS.value not in row:
        raise Exception("Embeddings column is required!")

    if TableField.CONTENT.value not in row:
        raise Exception("Content is required!")

    # store
    df = pd.DataFrame([row])

    return self.do_upsert(table_name, df)

274 

def set_metadata_cur_time(self, df, col_name):
    """Stamp each row's metadata dict with the current local time under ``col_name``.

    Mutates the metadata dicts in place via ``apply`` side effects.
    NOTE(review): assumes the metadata column holds dicts (never None) —
    confirm that all callers guarantee this.
    """
    metadata_col = TableField.METADATA.value
    cur_date = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def set_time(meta):
        meta[col_name] = cur_date

    df[metadata_col].apply(set_time)

283 

def do_upsert(self, table_name, df):
    """Upsert data into table, handling document updates and deletions.

    Args:
        table_name (str): Name of the table
        df (pd.DataFrame): DataFrame containing the data to upsert

    The function handles three cases:
    1. New documents: Insert them
    2. Updated documents: Delete old chunks and insert new ones
    """
    id_col = TableField.ID.value
    metadata_col = TableField.METADATA.value
    content_col = TableField.CONTENT.value

    def gen_hash(v):
        # deterministic id from content, so re-inserting identical content
        # addresses the same stored row instead of creating a duplicate
        return hashlib.md5(str(v).encode()).hexdigest()

    if id_col not in df.columns:
        # generate for all
        df[id_col] = df[content_col].apply(gen_hash)
    else:
        # generate for empty
        for i in range(len(df)):
            if pd.isna(df.loc[i, id_col]):
                df.loc[i, id_col] = gen_hash(df.loc[i, content_col])

    # remove duplicated ids
    df = df.drop_duplicates([TableField.ID.value])

    # id is string TODO is it ok?
    df[id_col] = df[id_col].apply(str)

    # set updated_at
    self.set_metadata_cur_time(df, "_updated_at")

    # handlers that provide a native `upsert` take the fast path
    if hasattr(self, "upsert"):
        self.upsert(table_name, df)
        return

    # find existing ids
    df_existed = self.select(
        table_name,
        columns=[id_col, metadata_col],
        conditions=[FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))],
    )
    existed_ids = list(df_existed[id_col])

    # update existed
    df_update = df[df[id_col].isin(existed_ids)]
    df_insert = df[~df[id_col].isin(existed_ids)]

    if not df_update.empty:
        # get values of existed `created_at` and return them to metadata
        origin_id_col = "_original_doc_id"

        created_dates, ids = {}, {}
        for _, row in df_existed.iterrows():
            chunk_id = row[id_col]
            created_dates[chunk_id] = row[metadata_col].get("_created_at")
            ids[chunk_id] = row[metadata_col].get(origin_id_col)

        def keep_created_at(row):
            # preserve the stored chunk's _created_at and _original_doc_id;
            # mutates the shared metadata dict via apply side effects
            val = created_dates.get(row[id_col])
            if val:
                row[metadata_col]["_created_at"] = val
            # keep id column
            if origin_id_col not in row[metadata_col]:
                row[metadata_col][origin_id_col] = ids.get(row[id_col])
            return row

        df_update.apply(keep_created_at, axis=1)

        try:
            self.update(table_name, df_update, [id_col])
        except NotImplementedError:
            # not implemented? do it with delete and insert
            # NOTE(review): the delete filter uses all of `df`'s ids rather
            # than just df_update's; presumably harmless since brand-new ids
            # don't exist in the table yet — confirm.
            conditions = [FilterCondition(column=id_col, op=FilterOperator.IN, value=list(df[id_col]))]
            self.delete(table_name, conditions)
            self.insert(table_name, df_update)
    if not df_insert.empty:
        # set created_at
        self.set_metadata_cur_time(df_insert, "_created_at")

        self.insert(table_name, df_insert)

369 

def dispatch_delete(self, query: Delete, conditions: List[FilterCondition] = None):
    """
    Dispatch delete query to the appropriate method.
    """
    # table to delete from
    table_name = query.table.parts[-1]

    # fall back to the WHERE clause when no pre-parsed conditions were given
    if conditions is None:
        conditions = self.extract_conditions(query.where)
    self._convert_metadata_filters(conditions)

    # dispatch delete
    return self.delete(table_name, conditions=conditions)

383 

def dispatch_select(
    self,
    query: Select,
    conditions: Optional[List[FilterCondition]] = None,
    allowed_metadata_columns: List[str] = None,
    keyword_search_args: Optional[KeywordSearchArgs] = None,
):
    """
    Dispatches a select query to the appropriate method, handling both
    standard selections and keyword searches based on the provided arguments.

    Args:
        query (Select): SELECT statement to execute.
        conditions: pre-extracted filter conditions; parsed from the WHERE
            clause when not provided.
        allowed_metadata_columns: restricts which metadata columns may be
            filtered on (underscored system columns always pass).
        keyword_search_args: when given, dispatch to ``keyword_select``.

    Raises:
        Exception: on columns not present in SCHEMA.
        VectorHandlerException: when the underlying ``select`` fails.
    """
    # 1. Parse common query arguments
    table_name = query.from_table.parts[-1]

    # If targets are a star (*), select all schema columns
    if isinstance(query.targets[0], Star):
        columns = [col["name"] for col in self.SCHEMA]
    else:
        columns = [col.parts[-1] for col in query.targets]

    # 2. Validate columns
    if not self._is_columns_allowed(columns):
        allowed_cols = [col["name"] for col in self.SCHEMA]
        raise Exception(f"Columns {columns} not allowed. Allowed columns are {allowed_cols}")

    # 3. Extract and process conditions
    if conditions is None:
        where_statement = query.where
        conditions = self.extract_conditions(where_statement)
    else:
        # deep-copy: _convert_metadata_filters mutates conditions in place
        conditions = copy.deepcopy(conditions)
    self._convert_metadata_filters(conditions, allowed_metadata_columns=allowed_metadata_columns)

    # 4. Get offset and limit
    offset = query.offset.value if query.offset is not None else None
    limit = query.limit.value if query.limit is not None else None

    # 5. Conditionally dispatch to the correct select method
    if keyword_search_args:
        # It's a keyword search
        return self.keyword_select(
            table_name,
            columns=columns,
            conditions=conditions,
            offset=offset,
            limit=limit,
            keyword_search_args=keyword_search_args,
        )
    else:
        # It's a standard select
        try:
            return self.select(
                table_name,
                columns=columns,
                conditions=conditions,
                offset=offset,
                limit=limit,
            )

        except Exception as e:
            handler_engine = self.__class__.name
            # Fix: chain the original exception (`from e`) so the underlying
            # traceback is preserved for debugging instead of being discarded.
            raise VectorHandlerException(f"Error in {handler_engine} database: {e}") from e

447 

def _dispatch(self, query: ASTNode) -> HandlerResponse:
    """
    Parse and Dispatch query to the appropriate method.
    """
    dispatch_router = {
        CreateTable: self._dispatch_create_table,
        DropTables: self._dispatch_drop_table,
        Insert: self._dispatch_insert,
        Update: self.dispatch_update,
        Delete: self.dispatch_delete,
        Select: self.dispatch_select,
    }

    # guard clause: unknown statement kinds fail fast
    handler = dispatch_router.get(type(query))
    if handler is None:
        raise NotImplementedError(f"Query type {type(query)} not implemented.")

    resp = handler(query)
    if resp is None:
        return HandlerResponse(resp_type=RESPONSE_TYPE.OK)
    return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=resp)

469 

def query(self, query: ASTNode) -> HandlerResponse:
    """Execute an SQL statement supplied as an AST.

    Thin public entry point that forwards to ``_dispatch``. Accepts any
    supported statement kind (SELECT, INSERT, UPDATE, DELETE, CREATE TABLE,
    DROP TABLE).

    Args:
        query (ASTNode): parsed SQL statement.

    Returns:
        HandlerResponse
    """
    return self._dispatch(query)

482 

def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse:
    """Create table

    Args:
        table_name (str): table name
        if_not_exists (bool): if True, do nothing if table exists

    Returns:
        HandlerResponse

    Raises:
        NotImplementedError: subclasses must implement this.
    """
    raise NotImplementedError()

494 

def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse:
    """Drop table

    Args:
        table_name (str): table name
        if_exists (bool): if True, do nothing if table does not exist

    Returns:
        HandlerResponse

    Raises:
        NotImplementedError: subclasses must implement this.
    """
    raise NotImplementedError()

506 

def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse:
    """Insert data into table

    Args:
        table_name (str): table name
        data (pd.DataFrame): data to insert

    Returns:
        HandlerResponse

    Raises:
        NotImplementedError: subclasses must implement this.
    """
    raise NotImplementedError()

519 

def update(self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None):
    """Update data in table

    Args:
        table_name (str): table name
        data (pd.DataFrame): data to update
        key_columns (List[str]): key columns used to match the rows to update

    Returns:
        HandlerResponse

    Raises:
        NotImplementedError: optional; ``do_upsert`` falls back to
            delete-and-insert when a handler does not implement it.
    """
    raise NotImplementedError()

532 

def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> HandlerResponse:
    """Delete data from table

    Args:
        table_name (str): table name
        conditions (List[FilterCondition]): conditions selecting rows to delete

    Returns:
        HandlerResponse

    Raises:
        NotImplementedError: subclasses must implement this.
    """
    raise NotImplementedError()

544 

def select(
    self,
    table_name: str,
    columns: List[str] = None,
    conditions: List[FilterCondition] = None,
    offset: int = None,
    limit: int = None,
) -> pd.DataFrame:
    """Select data from table

    Args:
        table_name (str): table name
        columns (List[str]): columns to select
        conditions (List[FilterCondition]): conditions to select
        offset (int): number of rows to skip
        limit (int): maximum number of rows to return

    Returns:
        pd.DataFrame: selected rows

    Raises:
        NotImplementedError: subclasses must implement this.
    """
    raise NotImplementedError()

564 

def get_columns(self, table_name: str) -> HandlerResponse:
    """Return the fixed vector-table schema as a TABLE response.

    Every vector-store table shares the same columns, so `table_name`
    does not influence the result.
    """
    schema_df = pd.DataFrame(self.SCHEMA)
    schema_df.columns = ["COLUMN_NAME", "DATA_TYPE"]
    return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=schema_df)

573 

def hybrid_search(
    self,
    table_name: str,
    embeddings: List[float],
    query: str = None,
    metadata: Dict[str, str] = None,
    distance_function=DistanceFunction.COSINE_DISTANCE,
    **kwargs,
) -> pd.DataFrame:
    """
    Executes a hybrid search, combining semantic search and one or both of keyword/metadata search.

    For insight on the query construction, see: https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.

    Args:
        table_name(str): Name of underlying table containing content, embeddings, & metadata
        embeddings(List[float]): Embedding vector to perform semantic search against
        query(str): User query to convert into keywords for keyword search
        metadata(Dict[str, str]): Metadata filters to filter content rows against
        distance_function(DistanceFunction): Distance function used to compare embeddings vectors for semantic search

    Returns:
        df(pd.DataFrame): Hybrid search result, sorted by hybrid search rank

    Raises:
        NotImplementedError: handlers without hybrid-search support do not override this.
    """
    raise NotImplementedError(f"Hybrid search not supported for VectorStoreHandler {self.name}")

599 

def check_existing_ids(self, table_name: str, ids: List[str]) -> List[str]:
    """
    Check which IDs from the provided list already exist in the table.

    Args:
        table_name (str): Name of the table to check
        ids (List[str]): List of IDs to check for existence

    Returns:
        List[str]: List of IDs that already exist in the table
    """
    if not ids:
        return []

    try:
        # Query existing IDs
        df_existing = self.select(
            table_name,
            columns=[TableField.ID.value],
            conditions=[FilterCondition(column=TableField.ID.value, op=FilterOperator.IN, value=ids)],
        )
        return list(df_existing[TableField.ID.value]) if not df_existing.empty else []
    except Exception as e:
        # Best-effort: a failed lookup is treated as "nothing exists", but the
        # error is logged instead of silently swallowed so failures stay diagnosable.
        LOG.warning("check_existing_ids failed for table %s: %s", table_name, e)
        return []

625 

def create_index(self, *args, **kwargs):
    """
    Create an index on the specified table.

    Raises:
        NotImplementedError: handlers without index support do not override this.
    """
    raise NotImplementedError(f"create_index not supported for VectorStoreHandler {self.name}")