Coverage for mindsdb / integrations / handlers / pgvector_handler / pgvector_handler.py: 0%

304 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import json 

3from typing import Dict, List, Literal, Tuple 

4from urllib.parse import urlparse 

5 

6import pandas as pd 

7import psycopg 

8from mindsdb_sql_parser.ast import ( 

9 Parameter, 

10 Identifier, 

11 BinaryOperation, 

12 Tuple as AstTuple, 

13 Constant, 

14 Select, 

15 OrderBy, 

16 TypeCast, 

17 Delete, 

18 Update, 

19 Function, 

20 DropTables, 

21) 

22from mindsdb_sql_parser.ast.base import ASTNode 

23from pgvector.psycopg import register_vector 

24 

25from mindsdb.integrations.handlers.postgres_handler.postgres_handler import ( 

26 PostgresHandler, 

27) 

28from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse as Response 

29from mindsdb.integrations.libs.vectordatabase_handler import ( 

30 FilterCondition, 

31 VectorStoreHandler, 

32 DistanceFunction, 

33 TableField, 

34 FilterOperator, 

35) 

36from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase 

37from mindsdb.integrations.utilities.sql_utils import KeywordSearchArgs 

38from mindsdb.utilities import log 

39from mindsdb.utilities.profiler import profiler 

40from mindsdb.utilities.context import context as ctx 

41 

42logger = log.getLogger(__name__) 

43 

44 

# todo Issue #7316 add support for different indexes and search algorithms e.g. cosine similarity or L2 norm
class PgVectorHandler(PostgresHandler, VectorStoreHandler, KeywordSearchBase):
    """This handler handles connection and execution of the PostgreSQL with pgvector extension statements."""

    # Integration name this handler is registered under inside MindsDB.
    name = "pgvector"

50 

51 def __init__(self, name: str, **kwargs): 

52 super().__init__(name=name, **kwargs) 

53 self._is_shared_db = False 

54 self._is_vector_registered = False 

55 # we get these from the connection args on PostgresHandler parent 

56 self._is_sparse = self.connection_args.get("is_sparse", False) 

57 self._vector_size = self.connection_args.get("vector_size", None) 

58 

59 if self._is_sparse: 

60 if not self._vector_size: 

61 raise ValueError("vector_size is required when is_sparse=True") 

62 

63 # Use inner product for sparse vectors 

64 distance_op = "<#>" 

65 

66 else: 

67 distance_op = "<=>" 

68 if "distance" in self.connection_args: 

69 distance_ops = { 

70 "l1": "<+>", 

71 "l2": "<->", 

72 "ip": "<#>", # inner product 

73 "cosine": "<=>", 

74 "hamming": "<~>", 

75 "jaccard": "<%>", 

76 } 

77 

78 distance_op = distance_ops.get(self.connection_args["distance"]) 

79 if distance_op is None: 

80 raise ValueError(f"Wrong distance type. Allowed options are {list(distance_ops.keys())}") 

81 

82 self.distance_op = distance_op 

83 self.connect() 

84 

85 def get_metric_type(self) -> str: 

86 """ 

87 Get the metric type from the distance ops 

88 

89 """ 

90 distance_ops_to_metric_type_map = { 

91 "<->": "vector_l2_ops", 

92 "<#>": "vector_ip_ops", 

93 "<=>": "vector_cosine_ops", 

94 "<+>": "vector_l1_ops", 

95 "<~>": "bit_hamming_ops", 

96 "<%>": "bit_jaccard_ops", 

97 } 

98 return distance_ops_to_metric_type_map.get(self.distance_op, "vector_cosine_ops") 

99 

100 def _make_connection_args(self): 

101 cloud_pgvector_url = os.environ.get("KB_PGVECTOR_URL") 

102 # if no connection args and shared pg vector defined - use it 

103 if len(self.connection_args) == 0 and cloud_pgvector_url is not None: 

104 result = urlparse(cloud_pgvector_url) 

105 self.connection_args = { 

106 "host": result.hostname, 

107 "port": result.port, 

108 "user": result.username, 

109 "password": result.password, 

110 "database": result.path[1:], 

111 } 

112 self._is_shared_db = True 

113 return super()._make_connection_args() 

114 

115 def get_tables(self) -> Response: 

116 # Hide list of tables from all users 

117 if self._is_shared_db: 

118 return Response(RESPONSE_TYPE.OK) 

119 return super().get_tables() 

120 

121 def query(self, query: ASTNode) -> Response: 

122 # Option to drop table of shared pgvector connection 

123 if isinstance(query, DropTables): 

124 query.tables = [self._check_table(table.parts[-1]) for table in query.tables] 

125 query_str, params = self.renderer.get_exec_params(query, with_failback=True) 

126 return self.native_query(query_str, params, no_restrict=True) 

127 return super().query(query) 

128 

129 def native_query(self, query, params=None, no_restrict=False) -> Response: 

130 """ 

131 Altered `native_query` method of postgres handler. 

132 Restrict usage of native query from executor with shared pg vector connection 

133 Exceptions: if it is used by pgvector itself (with no_restrict = True) 

134 """ 

135 # Prevent execute native queries 

136 if self._is_shared_db and not no_restrict: 

137 return Response(RESPONSE_TYPE.OK) 

138 return super().native_query(query, params=params) 

139 

140 def raw_query(self, query, params=None) -> Response: 

141 resp = super().native_query(query, params) 

142 if resp.resp_type == RESPONSE_TYPE.ERROR: 

143 raise RuntimeError(resp.error_message) 

144 if resp.resp_type == RESPONSE_TYPE.TABLE: 

145 return resp.data_frame 

146 

147 @profiler.profile() 

148 def connect(self) -> psycopg.connection: 

149 """ 

150 Handles the connection to a PostgreSQL database instance. 

151 """ 

152 self.connection = super().connect() 

153 if self._is_vector_registered: 

154 return self.connection 

155 

156 with self.connection.cursor() as cur: 

157 try: 

158 # load pg_vector extension 

159 cur.execute("CREATE EXTENSION IF NOT EXISTS vector") 

160 logger.info("pg_vector extension loaded") 

161 

162 except psycopg.Error as e: 

163 self.connection.rollback() 

164 logger.error(f"Error loading pg_vector extension, ensure you have installed it before running, {e}!") 

165 raise 

166 

167 # register vector type with psycopg2 connection 

168 register_vector(self.connection) 

169 self._is_vector_registered = True 

170 

171 return self.connection 

172 

173 def add_full_text_index(self, table_name: str, column_name: str) -> Response: 

174 """ 

175 Add a full text index to the specified column of the table. 

176 Args: 

177 table_name (str): Name of the table to add the index to. 

178 column_name (str): Name of the column to add the index to. 

179 Returns: 

180 Response: Response object indicating success or failure. 

181 """ 

182 table_name = self._check_table(table_name) 

183 query = f"CREATE INDEX IF NOT EXISTS {table_name}_{column_name}_fts_idx ON {table_name} USING gin(to_tsvector('english', {column_name}))" 

184 self.raw_query(query) 

185 return Response(RESPONSE_TYPE.OK) 

186 

    @staticmethod
    def _translate_conditions(conditions: List[FilterCondition]) -> Tuple[List[dict], dict]:
        """
        Translate filter conditions to a dictionary

        Returns a tuple ``(filter_conditions, embedding_condition)``: each item
        is a dict with 'name' (an AST node for the column/JSON path), 'op' and
        'value'. The condition on the 'embeddings' column, if any, is returned
        separately (None when absent).
        """

        if conditions is None:
            conditions = []

        filter_conditions = []
        embedding_condition = None

        for condition in conditions:
            is_embedding = condition.column == "embeddings"

            parts = condition.column.split(".")
            key = Identifier(parts[0])

            # converts 'col.el1.el2' to col->'el1'->>'el2'
            if len(parts) > 1:
                # intermediate elements
                for el in parts[1:-1]:
                    key = BinaryOperation(op="->", args=[key, Constant(el)])

                # last element
                key = BinaryOperation(op="->>", args=[key, Constant(parts[-1])])

            type_cast = None
            value = condition.value
            # For IN/NOT IN, inspect the first element to decide the cast type.
            if (
                isinstance(value, list)
                and len(value) > 0
                and condition.op in (FilterOperator.IN, FilterOperator.NOT_IN)
            ):
                value = condition.value[0]

            # JSON '->>' extraction yields text, so numeric comparisons need a cast.
            if isinstance(value, int):
                type_cast = "int"
            elif isinstance(value, float):
                type_cast = "float"
            if type_cast is not None:
                key = TypeCast(type_cast, key)

            item = {
                "name": key,
                "op": condition.op.value,
                "value": condition.value,
            }
            if is_embedding:
                embedding_condition = item
            else:
                filter_conditions.append(item)

        return filter_conditions, embedding_condition

241 

242 @staticmethod 

243 def _construct_where_clause(filter_conditions=None): 

244 """ 

245 Construct where clauses from filter conditions 

246 """ 

247 

248 where_clause = None 

249 

250 for item in filter_conditions: 

251 key = item["name"] 

252 

253 if item["op"].lower() in ("in", "not in"): 

254 values = [Constant(i) for i in item["value"]] 

255 value = AstTuple(values) 

256 else: 

257 value = Constant(item["value"]) 

258 condition = BinaryOperation(op=item["op"], args=[key, value]) 

259 

260 if where_clause is None: 

261 where_clause = condition 

262 else: 

263 where_clause = BinaryOperation(op="AND", args=[where_clause, condition]) 

264 return where_clause 

265 

266 @staticmethod 

267 def _construct_full_after_from_clause( 

268 where_clause: str, 

269 offset_clause: str, 

270 limit_clause: str, 

271 ) -> str: 

272 return f"{where_clause} {offset_clause} {limit_clause}" 

273 

    def _build_keyword_bm25_query(
        self,
        table_name: str,
        keyword_search_args: KeywordSearchArgs,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        offset: int = None,
    ):
        """Build a SELECT AST for full-text keyword search.

        Matches rows where ``keyword_search_args.column`` satisfies the
        websearch-style tsquery and scores them with ``ts_rank_cd``, exposed
        as a `distance` target column for API symmetry with vector search.
        """
        if columns is None:
            columns = ["id", "content", "metadata"]

        filter_conditions, _ = self._translate_conditions(conditions)
        where_clause = self._construct_where_clause(filter_conditions)

        if keyword_search_args:
            # to_tsvector(column) @@ websearch_to_tsquery(query)
            keyword_query_condition = BinaryOperation(
                op="@@",
                args=[
                    Function("to_tsvector", args=[Constant("english"), Identifier(keyword_search_args.column)]),
                    Function("websearch_to_tsquery", args=[Constant("english"), Constant(keyword_search_args.query)]),
                ],
            )

            # AND the keyword match into any metadata filters.
            if where_clause:
                where_clause = BinaryOperation(op="AND", args=[where_clause, keyword_query_condition])
            else:
                where_clause = keyword_query_condition

        # Relevance score for ordering/consumption downstream.
        distance = Function(
            "ts_rank_cd",
            args=[
                Function("to_tsvector", args=[Constant("english"), Identifier(keyword_search_args.column)]),
                Function("websearch_to_tsquery", args=[Constant("english"), Constant(keyword_search_args.query)]),
            ],
            alias=Identifier("distance"),
        )

        targets = [Identifier(col) for col in columns]
        targets.append(distance)

        limit_clause = Constant(limit) if limit else None
        offset_clause = Constant(offset) if offset else None

        return Select(
            targets=targets,
            from_table=Identifier(table_name),
            where=where_clause,
            limit=limit_clause,
            offset=offset_clause,
        )

325 

    def _build_select_query(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        offset: int = None,
    ) -> Select:
        """
        given inputs, build string query

        Builds a SELECT AST; when an embeddings condition is present, a
        computed `distance` column (using the configured distance operator)
        is added and used for nearest-first ordering.
        """
        limit_clause = Constant(limit) if limit else None
        offset_clause = Constant(offset) if offset else None

        # translate filter conditions to dictionary
        filter_conditions, embedding_search = self._translate_conditions(conditions)

        # given filter conditions, construct where clause
        where_clause = self._construct_where_clause(filter_conditions)

        # Handle distance column specially since it's calculated, not stored
        modified_columns = []
        has_distance = False
        if columns is not None:
            for col in columns:
                if col == TableField.DISTANCE.value:
                    has_distance = True
                else:
                    modified_columns.append(col)
        else:
            modified_columns = ["id", "content", "embeddings", "metadata"]
            has_distance = True

        targets = [Identifier(col) for col in modified_columns]

        query = Select(
            targets=targets,
            from_table=Identifier(table_name),
            where=where_clause,
            limit=limit_clause,
            offset=offset_clause,
        )

        if embedding_search:
            search_vector = embedding_search["value"]

            if self._is_sparse:
                # Convert dict to sparse vector if needed
                if isinstance(search_vector, dict):
                    from pgvector.utils import SparseVector

                    embedding = SparseVector(search_vector, self._vector_size)
                    search_vector = embedding.to_text()
            else:
                # Convert list to vector string if needed
                if isinstance(search_vector, list):
                    search_vector = f"[{','.join(str(x) for x in search_vector)}]"

            # embeddings <op> 'vector' AS distance
            vector_op = BinaryOperation(
                op=self.distance_op,
                args=[Identifier("embeddings"), Constant(search_vector)],
                alias=Identifier("distance"),
            )
            # Calculate distance as part of the query if needed
            if has_distance:
                query.targets.append(vector_op)

            # Nearest-first ordering by the chosen distance operator.
            query.order_by = [OrderBy(vector_op, direction="ASC")]

        return query

396 

397 def _check_table(self, table_name: str): 

398 # Apply namespace for a user 

399 if self._is_shared_db: 

400 company_id = ctx.company_id or "x" 

401 return f"t_{company_id}_{table_name}" 

402 return table_name 

403 

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ) -> pd.DataFrame:
        """
        Retrieve the data from the SQL statement with eliminated rows that dont satisfy the WHERE condition
        """
        table_name = self._check_table(table_name)

        if columns is None:
            columns = ["id", "content", "embeddings", "metadata"]

        query = self._build_select_query(table_name, columns, conditions, limit, offset)
        query_str = self.renderer.get_string(query, with_failback=True)
        result = self.raw_query(query_str)

        # convert pgvector values to plain Python lists so they can be parsed by mindsdb
        if "embeddings" in columns:
            result["embeddings"] = result["embeddings"].apply(list)

        return result

429 

430 def keyword_select( 

431 self, 

432 table_name: str, 

433 columns: List[str] = None, 

434 conditions: List[FilterCondition] = None, 

435 offset: int = None, 

436 limit: int = None, 

437 keyword_search_args: KeywordSearchArgs = None, 

438 ) -> pd.DataFrame: 

439 table_name = self._check_table(table_name) 

440 

441 if columns is None: 

442 columns = ["id", "content", "embeddings", "metadata"] 

443 

444 query = self._build_keyword_bm25_query(table_name, keyword_search_args, columns, conditions, limit, offset) 

445 query_str = self.renderer.get_string(query, with_failback=True) 

446 result = self.raw_query(query_str) 

447 

448 # ensure embeddings are returned as string so they can be parsed by mindsdb 

449 if "embeddings" in columns: 

450 result["embeddings"] = result["embeddings"].astype(str) 

451 

452 return result 

453 

    def hybrid_search(
        self,
        table_name: str,
        embeddings: List[float],
        query: str = None,
        metadata: Dict[str, str] = None,
        distance_function=DistanceFunction.COSINE_DISTANCE,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Executes a hybrid search, combining semantic search and one or both of keyword/metadata search.

        For insight on the query construction, see: https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.

        Args:
            table_name(str): Name of underlying table containing content, embeddings, & metadata
            embeddings(List[float]): Embedding vector to perform semantic search against
            query(str): User query to convert into keywords for keyword search
            metadata(Dict[str, str]): Metadata filters to filter content rows against
            distance_function(DistanceFunction): Distance function used to compare embeddings vectors for semantic search

        Kwargs:
            id_column_name(str): Name of ID column in underlying table
            content_column_name(str): Name of column containing document content in underlying table
            embeddings_column_name(str): Name of column containing embeddings vectors in underlying table
            metadata_column_name(str): Name of column containing metadata key-value pairs in underlying table

        Returns:
            df(pd.DataFrame): Hybrid search result, sorted by hybrid search rank
        """
        if query is None and metadata is None:
            raise ValueError(
                "Must provide at least one of: query for keyword search, or metadata filters. For only embeddings search, use normal search instead."
            )

        id_column_name = kwargs.get("id_column_name", "id")
        content_column_name = kwargs.get("content_column_name", "content")
        embeddings_column_name = kwargs.get("embeddings_column_name", "embeddings")
        metadata_column_name = kwargs.get("metadata_column_name", "metadata")
        # Filter by given metadata for semantic search & full text search CTEs, if present.
        # NOTE(review): metadata keys/values and `query` are interpolated directly
        # into the SQL string below — this assumes trusted input; confirm upstream
        # sanitization before exposing to untrusted callers.
        where_clause = " WHERE "
        if metadata is None:
            where_clause = ""
            metadata = {}
        for i, (k, v) in enumerate(metadata.items()):
            where_clause += f"{metadata_column_name}->>'{k}' = '{v}'"
            if i < len(metadata.items()) - 1:
                where_clause += " AND "

        # See https://docs.pgvecto.rs/use-case/hybrid-search.html#advanced-search-merge-the-results-of-full-text-search-and-vector-search.
        #
        # We can break down the below query as follows:
        #
        # Start with a CTE (Common Table Expression) called semantic_search (https://www.postgresql.org/docs/current/queries-with.html).
        # This expression calculates rank by the defined distance function, which measures the distance between the
        # embeddings column and the given embeddings vector. Results are ordered by this rank.
        #
        # Next, define another CTE called full_text_search if we are doing keyword search.
        # This calculates rank using the built-in ts_rank function (https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-RANKING).
        # We convert the content column to a ts_vector and match rows for the given tsquery in the content column. Results are ordered by this ts_rank.
        #
        # For both of these CTEs, we filter by any given metadata fields.
        #
        # See https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-PARSING-DOCUMENTS for to_tsvector
        # See https://www.postgresql.org/docs/current/functions-textsearch.html#FUNCTIONS-TEXTSEARCH for tsquery syntax
        #
        # Finally, we use a FULL OUTER JOIN to SELECT from both CTEs defined above.
        # The COALESCE function is used to handle cases where one CTE has null values.
        #
        # Or, if we are only doing metadata search, we leave out the JOIN & full text search CTEs.
        #
        # We calculate the final "hybrid" rank by summing the reciprocals of the ranks from each individual CTE.
        semantic_search_cte = f"""WITH semantic_search AS (
            SELECT {id_column_name}, {content_column_name}, {embeddings_column_name},
            RANK () OVER (ORDER BY {embeddings_column_name} {distance_function.value} '{str(embeddings)}') AS rank
            FROM {table_name}{where_clause}
            ORDER BY {embeddings_column_name} {distance_function.value} '{str(embeddings)}'::vector
        )"""

        full_text_search_cte = ""
        if query is not None:
            ts_vector_clause = (
                f"WHERE to_tsvector('english', {content_column_name}) @@ plainto_tsquery('english', '{query}')"
            )
            if metadata:
                # The metadata filter already opened the WHERE clause, so the
                # keyword condition is appended with AND instead.
                ts_vector_clause = (
                    f"AND to_tsvector('english', {content_column_name}) @@ plainto_tsquery('english', '{query}')"
                )
            full_text_search_cte = f""",
        full_text_search AS (
            SELECT {id_column_name}, {content_column_name}, {embeddings_column_name},
            RANK () OVER (ORDER BY ts_rank(to_tsvector('english', {content_column_name}), plainto_tsquery('english', '{query}')) DESC) AS rank
            FROM {table_name}{where_clause}
            {ts_vector_clause}
            ORDER BY ts_rank(to_tsvector('english', {content_column_name}), plainto_tsquery('english', '{query}')) DESC
        )"""

        hybrid_select = """
        SELECT * FROM semantic_search"""
        if query is not None:
            hybrid_select = f"""
        SELECT
            COALESCE(semantic_search.{id_column_name}, full_text_search.{id_column_name}) AS id,
            COALESCE(semantic_search.{content_column_name}, full_text_search.{content_column_name}) AS content,
            COALESCE(semantic_search.{embeddings_column_name}, full_text_search.{embeddings_column_name}) AS embeddings,
            COALESCE(1.0 / (1 + semantic_search.rank), 0.0) + COALESCE(1.0 / (1 + full_text_search.rank), 0.0) AS rank
        FROM semantic_search FULL OUTER JOIN full_text_search USING ({id_column_name}) ORDER BY rank DESC;
        """

        full_search_query = f"{semantic_search_cte}{full_text_search_cte}{hybrid_select}"
        return self.raw_query(full_search_query)

565 

    def create_table(self, table_name: str):
        """Create a table with a vector column.

        Schema: id TEXT primary key, embeddings vector/sparsevec, content TEXT,
        metadata JSONB. Idempotent via CREATE TABLE IF NOT EXISTS.
        """

        table_name = self._check_table(table_name)

        with self.connection.cursor() as cur:
            # For sparse vectors, use sparsevec type
            vector_column_type = "sparsevec" if self._is_sparse else "vector"

            # Vector size is required for sparse vectors, optional for dense
            if self._is_sparse and not self._vector_size:
                raise ValueError("vector_size is required for sparse vectors")

            # Add vector size specification only if provided
            size_spec = f"({self._vector_size})" if self._vector_size is not None else "()"
            if vector_column_type == "vector":
                # NOTE(review): dense columns are created without a fixed
                # dimension even when vector_size is set; the dimension appears
                # to be pinned later by create_index's ALTER TABLE — confirm
                # this is intended.
                size_spec = ""

            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {table_name} (
                    id TEXT PRIMARY KEY,
                    embeddings {vector_column_type}{size_spec},
                    content TEXT,
                    metadata JSONB
                )
            """)
            self.connection.commit()

593 

594 def insert(self, table_name: str, data: pd.DataFrame): 

595 """ 

596 Insert data into the pgvector table database. 

597 """ 

598 table_name = self._check_table(table_name) 

599 

600 if "metadata" in data.columns: 

601 data["metadata"] = data["metadata"].apply(json.dumps) 

602 

603 resp = super().insert(table_name, data) 

604 if resp.resp_type == RESPONSE_TYPE.ERROR: 

605 raise RuntimeError(resp.error_message) 

606 if resp.resp_type == RESPONSE_TYPE.TABLE: 

607 return resp.data_frame 

608 

609 def update(self, table_name: str, data: pd.DataFrame, key_columns: List[str] = None): 

610 """ 

611 Udate data into the pgvector table database. 

612 """ 

613 table_name = self._check_table(table_name) 

614 

615 where = None 

616 update_columns = {} 

617 

618 for col in data.columns: 

619 value = Parameter("%s") 

620 

621 if col in key_columns: 

622 cond = BinaryOperation(op="=", args=[Identifier(col), value]) 

623 if where is None: 

624 where = cond 

625 else: 

626 where = BinaryOperation(op="AND", args=[where, cond]) 

627 else: 

628 update_columns[col] = value 

629 

630 query = Update(table=Identifier(table_name), update_columns=update_columns, where=where) 

631 

632 if TableField.METADATA.value in data.columns: 

633 

634 def fnc(v): 

635 if isinstance(v, dict): 

636 return json.dumps(v) 

637 

638 data[TableField.METADATA.value] = data[TableField.METADATA.value].apply(fnc) 

639 

640 data = data.astype({TableField.METADATA.value: str}) 

641 

642 transposed_data = [] 

643 for _, record in data.iterrows(): 

644 row = [record[col] for col in update_columns.keys()] 

645 for key_column in key_columns: 

646 row.append(record[key_column]) 

647 transposed_data.append(row) 

648 

649 query_str = self.renderer.get_string(query) 

650 self.raw_query(query_str, transposed_data) 

651 

652 def delete(self, table_name: str, conditions: List[FilterCondition] = None): 

653 table_name = self._check_table(table_name) 

654 

655 filter_conditions, _ = self._translate_conditions(conditions) 

656 where_clause = self._construct_where_clause(filter_conditions) 

657 

658 query = Delete(table=Identifier(table_name), where=where_clause) 

659 query_str = self.renderer.get_string(query, with_failback=True) 

660 self.raw_query(query_str) 

661 

662 def drop_table(self, table_name: str, if_exists=True): 

663 """ 

664 Run a drop table query on the pgvector database. 

665 """ 

666 table_name = self._check_table(table_name) 

667 self.raw_query(f"DROP TABLE IF EXISTS {table_name}") 

668 

669 def create_index( 

670 self, 

671 table_name: str, 

672 column_name: str = "embeddings", 

673 index_type: Literal["ivfflat", "hnsw"] = "hnsw", 

674 metric_type: str = None, 

675 ): 

676 """ 

677 Create an index on the pgvector table. 

678 Args: 

679 table_name (str): Name of the table to create the index on. 

680 column_name (str): Name of the column to create the index on. 

681 index_type (str): Type of the index to create. Supported types are 'ivfflat' and 'hnsw'. 

682 metric_type (str): Metric type for the index. Supported types are 'vector_l2_ops', 'vector_ip_ops', and 'vector_cosine_ops'. 

683 """ 

684 if metric_type is None: 

685 metric_type = self.get_metric_type() 

686 # Check if the index type is supported 

687 if index_type not in ["ivfflat", "hnsw"]: 

688 raise ValueError("Invalid index type. Supported types are 'ivfflat' and 'hnsw'.") 

689 table_name = self._check_table(table_name) 

690 # first we make sure embedding dimension is set 

691 embedding_dim_size_df = self.raw_query(f"SELECT vector_dims({column_name}) FROM {table_name} LIMIT 1") 

692 # check if answer is empty 

693 if embedding_dim_size_df.empty: 

694 raise ValueError("Could not determine embedding dimension size. Make sure that knowledge base isn't empty") 

695 try: 

696 embedding_dim = int(embedding_dim_size_df.iloc[0, 0]) 

697 # alter table to add dimension 

698 self.raw_query(f"ALTER TABLE {table_name} ALTER COLUMN {column_name} TYPE vector({embedding_dim})") 

699 except Exception: 

700 raise ValueError("Could not determine embedding dimension size. Make sure that knowledge base isn't empty") 

701 

702 # Create the index 

703 self.raw_query(f"CREATE INDEX ON {table_name} USING {index_type} ({column_name} {metric_type})")