Coverage for mindsdb / integrations / handlers / duckdb_faiss_handler / duckdb_faiss_handler.py: 0%
254 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2from typing import List
4import pandas as pd
5import orjson
6import duckdb
7from mindsdb_sql_parser.ast import (
8 Select,
9 Delete,
10 Identifier,
11 BinaryOperation,
12 Constant,
13 NullConstant,
14 Star,
15 Tuple as AstTuple,
16 Function,
17 TypeCast,
18)
20from mindsdb.integrations.libs.response import (
21 RESPONSE_TYPE,
22 HandlerResponse as Response,
23 HandlerStatusResponse as StatusResponse,
24)
25from mindsdb.integrations.libs.vectordatabase_handler import (
26 FilterCondition,
27 VectorStoreHandler,
28 FilterOperator,
29)
30from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase
31from mindsdb.integrations.utilities.sql_utils import KeywordSearchArgs
33from mindsdb.utilities import log
34from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender
36from .faiss_index import FaissIndex
# Module-level logger for this handler.
logger = log.getLogger(__name__)
class DuckDBFaissHandler(VectorStoreHandler, KeywordSearchBase):
    """This handler handles connection and execution of DuckDB with Faiss vector indexing."""

    name = "duckdb_faiss"

    def __init__(self, name: str, **kwargs):
        """Initialize the handler and open the DuckDB/Faiss storage.

        Args:
            name: Integration name, forwarded to the base handler.
            **kwargs: Expected keys:
                connection_data (dict): handler configuration; may contain
                    'persist_directory' pointing at an existing directory.
                handler_storage: MindsDB storage object, used when no
                    persist_directory is configured.

        Raises:
            ValueError: If a configured persist_directory does not exist.
        """
        super().__init__(name=name)
        self.single_instance = True
        self.usage_lock = False

        # Extract configuration
        self.connection_data = kwargs.get("connection_data", {})
        self.handler_storage = kwargs.get("handler_storage")
        # DuckDB accepts the postgres dialect used by this SQL renderer.
        self.renderer = SqlalchemyRender("postgres")

        # Storage paths
        self._use_handler_storage = False
        self.persist_directory = self.connection_data.get("persist_directory")
        if self.persist_directory:
            if not os.path.exists(self.persist_directory):
                raise ValueError(f"Persist directory {self.persist_directory} does not exist")
        else:
            # Use default handler storage
            self.persist_directory = self.handler_storage.folder_get("data")
            self._use_handler_storage = True

        # DuckDB connection
        self.connection = None
        self.is_connected = False

        # Initialize storage paths
        self.duckdb_path = os.path.join(self.persist_directory, "duckdb.db")
        self.faiss_index_path = os.path.join(self.persist_directory, "faiss_index")
        self.connect()

        # check keyword index: DuckDB's FTS extension creates the schema
        # 'fts_main_meta_data' when an index exists on the meta_data table.
        self.is_kw_index_enabled = False
        with self.connection.cursor() as cur:
            # check index exists
            df = cur.execute(
                "SELECT * FROM information_schema.schemata WHERE schema_name = 'fts_main_meta_data'"
            ).fetchdf()
            if len(df) > 0:
                self.is_kw_index_enabled = True
86 def connect(self) -> duckdb.DuckDBPyConnection:
87 """Connect to DuckDB database."""
88 if self.is_connected:
89 return self.connection
91 try:
92 self.connection = duckdb.connect(self.duckdb_path)
93 self.faiss_index = FaissIndex(self.faiss_index_path, self.connection_data)
94 self.is_connected = True
96 logger.info("Connected to DuckDB database")
97 return self.connection
99 except Exception as e:
100 logger.error(f"Error connecting to DuckDB: {e}")
101 raise
103 def disconnect(self):
104 """Close DuckDB connection."""
105 if self.is_connected and self.connection:
106 self.connection.close()
107 self.faiss_index.close()
108 self.is_connected = False
    def create_table(self, table_name: str, if_not_exists=True):
        """Create the backing table and its id sequence.

        NOTE(review): table_name and if_not_exists are ignored — the handler
        always uses a single fixed 'meta_data' table with IF NOT EXISTS.
        """
        with self.connection.cursor() as cur:
            cur.execute("CREATE SEQUENCE IF NOT EXISTS faiss_id_sequence START 1")

            cur.execute("""
                CREATE TABLE IF NOT EXISTS meta_data (
                    faiss_id INTEGER PRIMARY KEY DEFAULT nextval('faiss_id_sequence'), -- id in FAISS index
                    id TEXT NOT NULL, -- chunk id
                    content TEXT,
                    metadata JSON
                )
            """)
123 def drop_table(self, table_name: str, if_exists=True):
124 """Drop table from both DuckDB and Faiss."""
125 with self.connection.cursor() as cur:
126 drop_sql = f"DROP TABLE {'IF EXISTS' if if_exists else ''} meta_data"
127 cur.execute(drop_sql)
129 if self.faiss_index:
130 self.faiss_index.drop()
    def insert(self, table_name: str, data: pd.DataFrame):
        """Insert data into both DuckDB and Faiss.

        Expects columns 'id', 'content', 'metadata' and 'embeddings' in *data*.
        """
        if self.is_kw_index_enabled:
            # drop index, it will be created before a first keyword search
            self.drop_kw_index()

        with self.connection.cursor() as cur:
            # NOTE: the table name 'data' inside this SQL is resolved by
            # DuckDB's replacement scan against the local DataFrame variable
            # `data` — do not rename that variable.
            df_ids = cur.execute("""
                insert into meta_data (id, content, metadata) (
                    select id, content, metadata from data
                )
                RETURNING faiss_id, id
            """).fetchdf()

        # Map the generated faiss ids back onto the input rows by chunk id.
        data = data.merge(df_ids, on="id")

        vectors = data["embeddings"]
        ids = data["faiss_id"]

        self.faiss_index.insert(list(vectors), list(ids))
        self._sync()
155 # def upsert(self, table_name: str, data: pd.DataFrame):
156 # # delete by ids and insert
157 # ids = list(data['id'])
158 # self.delete(table_name, [FilterCondition(column='id', op=FilterOperator.IN, value=ids)])
159 # self.insert(table_name, data)
161 def select(
162 self,
163 table_name: str,
164 columns: List[str] = None,
165 conditions: List[FilterCondition] = None,
166 offset: int = None,
167 limit: int = None,
168 ) -> pd.DataFrame:
169 """Select data with hybrid search logic."""
171 vector_filter = None
172 meta_filters = []
173 if conditions is None:
174 conditions = []
175 for condition in conditions:
176 if condition.column == "embeddings":
177 vector_filter = condition
178 else:
179 meta_filters.append(condition)
181 if vector_filter is None:
182 # If only metadata in filter:
183 # query duckdb only
184 return self._select_from_metadata(meta_filters=meta_filters, limit=limit).drop("faiss_id", axis=1)
186 # vector_filter is not None
187 if not meta_filters:
188 # If only content in filter: query faiss and attach to metadata
189 return self._select_with_vector(vector_filter=vector_filter, limit=limit)
191 """
192 If metadata + content:
193 Query faiss, use limit = 1000
194 Query duckdb with `id in (...)`
195 If count of results is less than input LIMIT value
196 Repeat the search with increased limit value
197 Limit value for step = 1000 * 5^i (1000, 2000, 25000, 125000 …)
198 """
200 df = pd.DataFrame()
202 total_size = self.get_total_size()
204 for i in range(10):
205 batch_size = 1000 * 5**i
207 # TODO implement reverse search:
208 # if batch_size > 25% of db: search metadata first and then in faiss by list of ids
210 df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=batch_size)
211 if batch_size >= total_size or len(df) >= limit:
212 break
214 return df[:limit]
    def create_kw_index(self):
        """Create a DuckDB FTS (BM25) index over meta_data for keyword search."""
        with self.connection.cursor() as cur:
            cur.execute("PRAGMA create_fts_index('meta_data', 'id', 'content')")
            self.is_kw_index_enabled = True
    def drop_kw_index(self):
        """Drop the FTS index; it is re-created lazily on the next keyword search."""
        with self.connection.cursor() as cur:
            cur.execute("pragma drop_fts_index('meta_data')")
            self.is_kw_index_enabled = False
226 def keyword_select(
227 self,
228 table_name: str,
229 columns: List[str] = None,
230 conditions: List[FilterCondition] = None,
231 offset: int = None,
232 limit: int = None,
233 keyword_search_args: KeywordSearchArgs = None,
234 ) -> pd.DataFrame:
235 if not self.is_kw_index_enabled:
236 # keyword search is used for first time: create index
237 self.create_kw_index()
239 with self.connection.cursor() as cur:
240 where_clause = self._translate_filters(conditions)
242 score = Function(
243 namespace="fts_main_meta_data",
244 op="match_bm25",
245 args=[
246 Identifier("id"),
247 Constant(keyword_search_args.query),
248 BinaryOperation(op=":=", args=[Identifier("fields"), Constant(keyword_search_args.column)]),
249 ],
250 )
252 no_emtpy_score = BinaryOperation(op="is not", args=[score, NullConstant()])
253 if where_clause:
254 where_clause = BinaryOperation(op="and", args=[where_clause, no_emtpy_score])
255 else:
256 where_clause = no_emtpy_score
258 query = Select(
259 targets=[Star(), BinaryOperation(op="-", args=[Constant(1), score], alias=Identifier("distance"))],
260 from_table=Identifier("meta_data"),
261 where=where_clause,
262 )
264 sql = self.renderer.get_string(query, with_failback=True)
265 cur.execute(sql)
266 df = cur.fetchdf()
267 df["metadata"] = df["metadata"].apply(orjson.loads)
268 return df
270 def get_total_size(self):
271 with self.connection.cursor() as cur:
272 cur.execute("select count(1) size from meta_data")
273 df = cur.fetchdf()
274 return df["size"].iloc[0]
276 def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, limit=None) -> pd.DataFrame:
277 embedding = vector_filter.value
278 if isinstance(embedding, str):
279 embedding = orjson.loads(embedding)
281 distances, faiss_ids = self.faiss_index.search(embedding, limit or 100)
283 # Fetch full data from DuckDB
284 if len(faiss_ids) > 0:
285 # ids = [str(idx) for idx in faiss_ids]
286 meta_df = self._select_from_metadata(faiss_ids=faiss_ids, meta_filters=meta_filters)
287 vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances})
288 return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance")
290 return pd.DataFrame([], columns=["id", "content", "metadata", "distance"])
292 def _select_from_metadata(self, faiss_ids=None, meta_filters=None, limit=None):
293 query = Select(
294 targets=[Star()],
295 from_table=Identifier("meta_data"),
296 )
298 where_clause = self._translate_filters(meta_filters)
300 if faiss_ids:
301 # TODO what if ids list is too long - split search into batches
302 in_filter = BinaryOperation(
303 op="IN", args=[Identifier("faiss_id"), AstTuple([Constant(i) for i in faiss_ids])]
304 )
305 # split into chunks
306 chunk_size = 10000
307 if len(faiss_ids) > chunk_size:
308 dfs = []
309 chunk = 0
310 total = 0
311 while chunk * chunk_size < len(faiss_ids):
312 # create results with partition
313 ids = faiss_ids[chunk * chunk_size : (chunk + 1) * chunk_size]
314 chunk += 1
315 df = self._select_from_metadata(faiss_ids=ids, meta_filters=meta_filters, limit=limit)
316 total += len(df)
317 if limit is not None and limit <= total:
318 # cut the extra from the end
319 df = df[: -(total - limit)]
320 dfs.append(df)
321 break
322 if len(df) > 0:
323 dfs.append(df)
324 if len(dfs) == 0:
325 return pd.DataFrame([], columns=["faiss_id", "id", "content", "metadata"])
326 return pd.concat(dfs)
328 if where_clause is None:
329 where_clause = in_filter
330 else:
331 where_clause = BinaryOperation(op="AND", args=[where_clause, in_filter])
333 if limit is not None:
334 query.limit = Constant(limit)
336 query.where = where_clause
338 with self.connection.cursor() as cur:
339 sql = self.renderer.get_string(query, with_failback=True)
340 cur.execute(sql)
341 df = cur.fetchdf()
342 df["metadata"] = df["metadata"].apply(orjson.loads)
343 return df
    def _translate_filters(self, meta_filters):
        """Translate FilterCondition objects into a single AST where-clause.

        Returns None when there are no filters, otherwise a BinaryOperation
        tree joining all conditions with AND.

        NOTE(review): mutates item.value in place for the
        'metadata._original_doc_id' column (values are coerced to str,
        presumably because that id is stored as text in the JSON metadata —
        confirm against the writer side).
        """
        if not meta_filters:
            return None

        where_clause = None
        for item in meta_filters:
            parts = item.column.split(".")
            key = Identifier(parts[0])

            # converts 'col.el1.el2' to col->'el1'->>'el2'
            if len(parts) > 1:
                # intermediate elements
                for el in parts[1:-1]:
                    key = BinaryOperation(op="->", args=[key, Constant(el)])

                # last element
                key = BinaryOperation(op="->>", args=[key, Constant(parts[-1])])

            is_orig_id = item.column == "metadata._original_doc_id"

            type_cast = None
            value = item.value

            if isinstance(value, list) and len(value) > 0 and item.op in (FilterOperator.IN, FilterOperator.NOT_IN):
                if is_orig_id:
                    # convert to str
                    item.value = [str(i) for i in value]
                # Use the first element only to decide the type cast below.
                value = item.value[0]
            elif is_orig_id:
                if not isinstance(value, str):
                    value = item.value = str(item.value)

            # ->> extracts JSON values as text; cast so numeric comparisons work.
            if isinstance(value, int):
                type_cast = "int"
            elif isinstance(value, float):
                type_cast = "float"

            if type_cast is not None:
                key = TypeCast(type_cast, key)

            if item.op in (FilterOperator.NOT_IN, FilterOperator.IN):
                values = [Constant(i) for i in item.value]
                value = AstTuple(values)
            else:
                value = Constant(item.value)

            condition = BinaryOperation(op=item.op.value, args=[key, value])

            if where_clause is None:
                where_clause = condition
            else:
                where_clause = BinaryOperation(op="AND", args=[where_clause, condition])
        return where_clause
399 def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> Response:
400 """Delete data from both DuckDB and Faiss."""
402 with self.connection.cursor() as cur:
403 where_clause = self._translate_filters(conditions)
405 query = Select(targets=[Identifier("faiss_id")], from_table=Identifier("meta_data"), where=where_clause)
406 cur.execute(self.renderer.get_string(query, with_failback=True))
407 df = cur.fetchdf()
408 ids = list(df["faiss_id"])
410 self.faiss_index.delete_ids(ids)
412 query = Delete(table=Identifier("meta_data"), where=where_clause)
413 cur.execute(self.renderer.get_string(query, with_failback=True))
415 self._sync()
417 def get_dimension(self, table_name: str) -> int:
418 if self.faiss_index:
419 return self.faiss_index.dim
    def _sync(self):
        """Sync the database to disk if using persistent storage"""
        # Always persist the Faiss index; the DuckDB file is written by duckdb itself.
        self.faiss_index.dump()
        if self._use_handler_storage:
            # Push the local folder back into MindsDB's handler storage.
            self.handler_storage.folder_sync(self.persist_directory)
427 def get_tables(self) -> Response:
428 """Get list of tables."""
429 with self.connection.cursor() as cur:
430 df = cur.execute("show tables").fetchdf()
431 df = df.rename(columns={"name": "table_name"})
433 return Response(RESPONSE_TYPE.TABLE, data_frame=df)
435 def check_connection(self) -> Response:
436 """Check the connection to the database."""
437 try:
438 if not self.is_connected:
439 self.connect()
440 return StatusResponse(RESPONSE_TYPE.OK)
441 except Exception as e:
442 logger.error(f"Connection check failed: {e}")
443 return StatusResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
445 def native_query(self, query: str) -> Response:
446 """Execute a native SQL query."""
447 try:
448 with self.connection.cursor() as cur:
449 cur.execute(query)
450 result = cur.fetchdf()
451 return Response(RESPONSE_TYPE.TABLE, data_frame=result)
452 except Exception as e:
453 logger.error(f"Error executing native query: {e}")
454 return Response(RESPONSE_TYPE.ERROR, error_message=str(e))
456 def __del__(self):
457 """Cleanup on deletion."""
458 if self.is_connected:
459 self._sync()
460 self.disconnect()