Coverage for mindsdb / integrations / handlers / duckdb_faiss_handler / duckdb_faiss_handler.py: 0%

254 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2from typing import List 

3 

4import pandas as pd 

5import orjson 

6import duckdb 

7from mindsdb_sql_parser.ast import ( 

8 Select, 

9 Delete, 

10 Identifier, 

11 BinaryOperation, 

12 Constant, 

13 NullConstant, 

14 Star, 

15 Tuple as AstTuple, 

16 Function, 

17 TypeCast, 

18) 

19 

20from mindsdb.integrations.libs.response import ( 

21 RESPONSE_TYPE, 

22 HandlerResponse as Response, 

23 HandlerStatusResponse as StatusResponse, 

24) 

25from mindsdb.integrations.libs.vectordatabase_handler import ( 

26 FilterCondition, 

27 VectorStoreHandler, 

28 FilterOperator, 

29) 

30from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase 

31from mindsdb.integrations.utilities.sql_utils import KeywordSearchArgs 

32 

33from mindsdb.utilities import log 

34from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

35 

36from .faiss_index import FaissIndex 

37 

38logger = log.getLogger(__name__) 

39 

40 

class DuckDBFaissHandler(VectorStoreHandler, KeywordSearchBase):
    """This handler handles connection and execution of DuckDB with Faiss vector indexing."""

    # Integration name used to register this handler inside MindsDB.
    name = "duckdb_faiss"

    def __init__(self, name: str, **kwargs):
        """Set up storage paths, open DuckDB/Faiss, and detect an existing FTS index.

        Args:
            name: handler instance name (forwarded to VectorStoreHandler).
            **kwargs: expects `connection_data` (dict, may contain
                `persist_directory`) and `handler_storage`.

        Raises:
            ValueError: if an explicit `persist_directory` does not exist.
        """
        super().__init__(name=name)
        # Only a single live instance may use the underlying files at a time.
        self.single_instance = True
        self.usage_lock = False

        # Extract configuration
        self.connection_data = kwargs.get("connection_data", {})
        self.handler_storage = kwargs.get("handler_storage")
        # Queries are rendered through the postgres dialect, which DuckDB accepts.
        self.renderer = SqlalchemyRender("postgres")

        # Storage paths
        self._use_handler_storage = False
        self.persist_directory = self.connection_data.get("persist_directory")
        if self.persist_directory:
            if not os.path.exists(self.persist_directory):
                raise ValueError(f"Persist directory {self.persist_directory} does not exist")
        else:
            # Use default handler storage
            self.persist_directory = self.handler_storage.folder_get("data")
            self._use_handler_storage = True

        # DuckDB connection
        self.connection = None
        self.is_connected = False

        # Initialize storage paths
        self.duckdb_path = os.path.join(self.persist_directory, "duckdb.db")
        self.faiss_index_path = os.path.join(self.persist_directory, "faiss_index")
        self.connect()

        # check keyword index
        self.is_kw_index_enabled = False
        with self.connection.cursor() as cur:
            # An existing 'fts_main_meta_data' schema means the FTS keyword
            # index was built in a previous session.
            df = cur.execute(
                "SELECT * FROM information_schema.schemata WHERE schema_name = 'fts_main_meta_data'"
            ).fetchdf()
            if len(df) > 0:
                self.is_kw_index_enabled = True

85 

86 def connect(self) -> duckdb.DuckDBPyConnection: 

87 """Connect to DuckDB database.""" 

88 if self.is_connected: 

89 return self.connection 

90 

91 try: 

92 self.connection = duckdb.connect(self.duckdb_path) 

93 self.faiss_index = FaissIndex(self.faiss_index_path, self.connection_data) 

94 self.is_connected = True 

95 

96 logger.info("Connected to DuckDB database") 

97 return self.connection 

98 

99 except Exception as e: 

100 logger.error(f"Error connecting to DuckDB: {e}") 

101 raise 

102 

103 def disconnect(self): 

104 """Close DuckDB connection.""" 

105 if self.is_connected and self.connection: 

106 self.connection.close() 

107 self.faiss_index.close() 

108 self.is_connected = False 

109 

110 def create_table(self, table_name: str, if_not_exists=True): 

111 with self.connection.cursor() as cur: 

112 cur.execute("CREATE SEQUENCE IF NOT EXISTS faiss_id_sequence START 1") 

113 

114 cur.execute(""" 

115 CREATE TABLE IF NOT EXISTS meta_data ( 

116 faiss_id INTEGER PRIMARY KEY DEFAULT nextval('faiss_id_sequence'), -- id in FAISS index  

117 id TEXT NOT NULL, -- chunk id  

118 content TEXT, 

119 metadata JSON 

120 ) 

121 """) 

122 

123 def drop_table(self, table_name: str, if_exists=True): 

124 """Drop table from both DuckDB and Faiss.""" 

125 with self.connection.cursor() as cur: 

126 drop_sql = f"DROP TABLE {'IF EXISTS' if if_exists else ''} meta_data" 

127 cur.execute(drop_sql) 

128 

129 if self.faiss_index: 

130 self.faiss_index.drop() 

131 

    def insert(self, table_name: str, data: pd.DataFrame):
        """Insert data into both DuckDB and Faiss.

        Expects `data` to carry `id`, `content`, `metadata` and `embeddings`
        columns. `table_name` is ignored: all rows go to meta_data.
        """

        if self.is_kw_index_enabled:
            # drop index, it will be created before a first keyword search
            self.drop_kw_index()

        with self.connection.cursor() as cur:
            # NOTE: `data` inside the SQL is resolved by DuckDB's replacement
            # scan of the local pandas DataFrame; RETURNING yields the
            # generated faiss_id for every inserted row.
            df_ids = cur.execute("""
                insert into meta_data (id, content, metadata) (
                    select id, content, metadata from data
                )
                RETURNING faiss_id, id
            """).fetchdf()

            # Attach the generated faiss ids back to the input rows by chunk id.
            data = data.merge(df_ids, on="id")

            vectors = data["embeddings"]
            ids = data["faiss_id"]

            # Vectors are indexed in Faiss under the DuckDB-generated ids.
            self.faiss_index.insert(list(vectors), list(ids))
        self._sync()

154 

155 # def upsert(self, table_name: str, data: pd.DataFrame): 

156 # # delete by ids and insert 

157 # ids = list(data['id']) 

158 # self.delete(table_name, [FilterCondition(column='id', op=FilterOperator.IN, value=ids)]) 

159 # self.insert(table_name, data) 

160 

161 def select( 

162 self, 

163 table_name: str, 

164 columns: List[str] = None, 

165 conditions: List[FilterCondition] = None, 

166 offset: int = None, 

167 limit: int = None, 

168 ) -> pd.DataFrame: 

169 """Select data with hybrid search logic.""" 

170 

171 vector_filter = None 

172 meta_filters = [] 

173 if conditions is None: 

174 conditions = [] 

175 for condition in conditions: 

176 if condition.column == "embeddings": 

177 vector_filter = condition 

178 else: 

179 meta_filters.append(condition) 

180 

181 if vector_filter is None: 

182 # If only metadata in filter: 

183 # query duckdb only 

184 return self._select_from_metadata(meta_filters=meta_filters, limit=limit).drop("faiss_id", axis=1) 

185 

186 # vector_filter is not None 

187 if not meta_filters: 

188 # If only content in filter: query faiss and attach to metadata 

189 return self._select_with_vector(vector_filter=vector_filter, limit=limit) 

190 

191 """ 

192 If metadata + content: 

193 Query faiss, use limit = 1000 

194 Query duckdb with `id in (...)`  

195 If count of results is less than input LIMIT value 

196 Repeat the search with increased limit value 

197 Limit value for step = 1000 * 5^i (1000, 2000, 25000, 125000 …) 

198 """ 

199 

200 df = pd.DataFrame() 

201 

202 total_size = self.get_total_size() 

203 

204 for i in range(10): 

205 batch_size = 1000 * 5**i 

206 

207 # TODO implement reverse search: 

208 # if batch_size > 25% of db: search metadata first and then in faiss by list of ids 

209 

210 df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=batch_size) 

211 if batch_size >= total_size or len(df) >= limit: 

212 break 

213 

214 return df[:limit] 

215 

216 def create_kw_index(self): 

217 with self.connection.cursor() as cur: 

218 cur.execute("PRAGMA create_fts_index('meta_data', 'id', 'content')") 

219 self.is_kw_index_enabled = True 

220 

221 def drop_kw_index(self): 

222 with self.connection.cursor() as cur: 

223 cur.execute("pragma drop_fts_index('meta_data')") 

224 self.is_kw_index_enabled = False 

225 

    def keyword_select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
        keyword_search_args: KeywordSearchArgs = None,
    ) -> pd.DataFrame:
        """BM25 keyword search over meta_data via DuckDB's FTS extension.

        Builds the FTS index lazily on first use. `columns`, `offset` and
        `limit` are currently ignored. Returns matching rows plus a
        `distance` column computed as 1 - bm25_score (smaller = better).
        """
        if not self.is_kw_index_enabled:
            # keyword search is used for first time: create index
            self.create_kw_index()

        with self.connection.cursor() as cur:
            where_clause = self._translate_filters(conditions)

            # fts_main_meta_data.match_bm25(id, <query>, fields := <column>)
            score = Function(
                namespace="fts_main_meta_data",
                op="match_bm25",
                args=[
                    Identifier("id"),
                    Constant(keyword_search_args.query),
                    BinaryOperation(op=":=", args=[Identifier("fields"), Constant(keyword_search_args.column)]),
                ],
            )

            # match_bm25 returns NULL for non-matching rows: filter those out.
            no_emtpy_score = BinaryOperation(op="is not", args=[score, NullConstant()])
            if where_clause:
                where_clause = BinaryOperation(op="and", args=[where_clause, no_emtpy_score])
            else:
                where_clause = no_emtpy_score

            query = Select(
                targets=[Star(), BinaryOperation(op="-", args=[Constant(1), score], alias=Identifier("distance"))],
                from_table=Identifier("meta_data"),
                where=where_clause,
            )

            sql = self.renderer.get_string(query, with_failback=True)
            cur.execute(sql)
            df = cur.fetchdf()
            # metadata is stored as JSON text: decode to dicts for callers.
            df["metadata"] = df["metadata"].apply(orjson.loads)
            return df

269 

270 def get_total_size(self): 

271 with self.connection.cursor() as cur: 

272 cur.execute("select count(1) size from meta_data") 

273 df = cur.fetchdf() 

274 return df["size"].iloc[0] 

275 

276 def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, limit=None) -> pd.DataFrame: 

277 embedding = vector_filter.value 

278 if isinstance(embedding, str): 

279 embedding = orjson.loads(embedding) 

280 

281 distances, faiss_ids = self.faiss_index.search(embedding, limit or 100) 

282 

283 # Fetch full data from DuckDB 

284 if len(faiss_ids) > 0: 

285 # ids = [str(idx) for idx in faiss_ids] 

286 meta_df = self._select_from_metadata(faiss_ids=faiss_ids, meta_filters=meta_filters) 

287 vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances}) 

288 return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance") 

289 

290 return pd.DataFrame([], columns=["id", "content", "metadata", "distance"]) 

291 

292 def _select_from_metadata(self, faiss_ids=None, meta_filters=None, limit=None): 

293 query = Select( 

294 targets=[Star()], 

295 from_table=Identifier("meta_data"), 

296 ) 

297 

298 where_clause = self._translate_filters(meta_filters) 

299 

300 if faiss_ids: 

301 # TODO what if ids list is too long - split search into batches 

302 in_filter = BinaryOperation( 

303 op="IN", args=[Identifier("faiss_id"), AstTuple([Constant(i) for i in faiss_ids])] 

304 ) 

305 # split into chunks 

306 chunk_size = 10000 

307 if len(faiss_ids) > chunk_size: 

308 dfs = [] 

309 chunk = 0 

310 total = 0 

311 while chunk * chunk_size < len(faiss_ids): 

312 # create results with partition 

313 ids = faiss_ids[chunk * chunk_size : (chunk + 1) * chunk_size] 

314 chunk += 1 

315 df = self._select_from_metadata(faiss_ids=ids, meta_filters=meta_filters, limit=limit) 

316 total += len(df) 

317 if limit is not None and limit <= total: 

318 # cut the extra from the end 

319 df = df[: -(total - limit)] 

320 dfs.append(df) 

321 break 

322 if len(df) > 0: 

323 dfs.append(df) 

324 if len(dfs) == 0: 

325 return pd.DataFrame([], columns=["faiss_id", "id", "content", "metadata"]) 

326 return pd.concat(dfs) 

327 

328 if where_clause is None: 

329 where_clause = in_filter 

330 else: 

331 where_clause = BinaryOperation(op="AND", args=[where_clause, in_filter]) 

332 

333 if limit is not None: 

334 query.limit = Constant(limit) 

335 

336 query.where = where_clause 

337 

338 with self.connection.cursor() as cur: 

339 sql = self.renderer.get_string(query, with_failback=True) 

340 cur.execute(sql) 

341 df = cur.fetchdf() 

342 df["metadata"] = df["metadata"].apply(orjson.loads) 

343 return df 

344 

    def _translate_filters(self, meta_filters):
        """Convert a list of FilterCondition into one AST where-clause.

        Dotted columns like 'metadata.a.b' become JSON extraction
        (metadata->'a'->>'b'). NOTE: mutates `item.value` in place for
        `_original_doc_id` string coercion. Returns None when no filters
        are given.
        """
        if not meta_filters:
            return None

        where_clause = None
        for item in meta_filters:
            parts = item.column.split(".")
            key = Identifier(parts[0])

            # converts 'col.el1.el2' to col->'el1'->>'el2'
            if len(parts) > 1:
                # intermediate elements
                for el in parts[1:-1]:
                    key = BinaryOperation(op="->", args=[key, Constant(el)])

                # last element
                key = BinaryOperation(op="->>", args=[key, Constant(parts[-1])])

            is_orig_id = item.column == "metadata._original_doc_id"

            type_cast = None
            value = item.value

            if isinstance(value, list) and len(value) > 0 and item.op in (FilterOperator.IN, FilterOperator.NOT_IN):
                if is_orig_id:
                    # convert to str
                    item.value = [str(i) for i in value]
                # Inspect the first element to decide on a type cast below.
                value = item.value[0]
            elif is_orig_id:
                if not isinstance(value, str):
                    value = item.value = str(item.value)

            # JSON '->>' extraction yields text, so cast the key when the
            # comparison value is numeric.
            if isinstance(value, int):
                type_cast = "int"
            elif isinstance(value, float):
                type_cast = "float"

            if type_cast is not None:
                key = TypeCast(type_cast, key)

            if item.op in (FilterOperator.NOT_IN, FilterOperator.IN):
                values = [Constant(i) for i in item.value]
                value = AstTuple(values)
            else:
                value = Constant(item.value)

            condition = BinaryOperation(op=item.op.value, args=[key, value])

            # AND all conditions together.
            if where_clause is None:
                where_clause = condition
            else:
                where_clause = BinaryOperation(op="AND", args=[where_clause, condition])
        return where_clause

398 

    def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> Response:
        """Delete data from both DuckDB and Faiss.

        WARNING: with no/empty `conditions` the where-clause is None and
        every row (and every Faiss vector) is removed.
        """

        with self.connection.cursor() as cur:
            where_clause = self._translate_filters(conditions)

            # First collect the Faiss ids of the rows about to be removed.
            query = Select(targets=[Identifier("faiss_id")], from_table=Identifier("meta_data"), where=where_clause)
            cur.execute(self.renderer.get_string(query, with_failback=True))
            df = cur.fetchdf()
            ids = list(df["faiss_id"])

            # Remove the vectors from the Faiss index ...
            self.faiss_index.delete_ids(ids)

            # ... then the corresponding metadata rows from DuckDB.
            query = Delete(table=Identifier("meta_data"), where=where_clause)
            cur.execute(self.renderer.get_string(query, with_failback=True))

        self._sync()

416 

417 def get_dimension(self, table_name: str) -> int: 

418 if self.faiss_index: 

419 return self.faiss_index.dim 

420 

    def _sync(self):
        """Sync the database to disk if using persistent storage"""
        # Persist the in-memory Faiss index to its file ...
        self.faiss_index.dump()
        if self._use_handler_storage:
            # ... and push the data folder back into MindsDB handler storage.
            self.handler_storage.folder_sync(self.persist_directory)

426 

427 def get_tables(self) -> Response: 

428 """Get list of tables.""" 

429 with self.connection.cursor() as cur: 

430 df = cur.execute("show tables").fetchdf() 

431 df = df.rename(columns={"name": "table_name"}) 

432 

433 return Response(RESPONSE_TYPE.TABLE, data_frame=df) 

434 

435 def check_connection(self) -> Response: 

436 """Check the connection to the database.""" 

437 try: 

438 if not self.is_connected: 

439 self.connect() 

440 return StatusResponse(RESPONSE_TYPE.OK) 

441 except Exception as e: 

442 logger.error(f"Connection check failed: {e}") 

443 return StatusResponse(RESPONSE_TYPE.ERROR, error_message=str(e)) 

444 

445 def native_query(self, query: str) -> Response: 

446 """Execute a native SQL query.""" 

447 try: 

448 with self.connection.cursor() as cur: 

449 cur.execute(query) 

450 result = cur.fetchdf() 

451 return Response(RESPONSE_TYPE.TABLE, data_frame=result) 

452 except Exception as e: 

453 logger.error(f"Error executing native query: {e}") 

454 return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) 

455 

456 def __del__(self): 

457 """Cleanup on deletion.""" 

458 if self.is_connected: 

459 self._sync() 

460 self.disconnect()