Coverage for mindsdb / integrations / handlers / chromadb_handler / chromadb_handler.py: 38%

257 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import ast 

3import shutil 

4import hashlib 

5from typing import Dict, List, Optional, Union 

6 

7import pandas as pd 

8import chromadb 

9from chromadb.api.shared_system_client import SharedSystemClient 

10 

11from mindsdb.integrations.handlers.chromadb_handler.settings import ChromaHandlerConfig 

12from mindsdb.integrations.libs.response import RESPONSE_TYPE 

13from mindsdb.integrations.libs.response import HandlerResponse 

14from mindsdb.integrations.libs.response import HandlerResponse as Response 

15from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

16from mindsdb.integrations.libs.vectordatabase_handler import ( 

17 FilterCondition, 

18 FilterOperator, 

19 TableField, 

20 VectorStoreHandler, 

21) 

22from mindsdb.utilities import log 

23 

# Module-level logger for this handler (MindsDB logging facade).
logger = log.getLogger(__name__)

25 

26 

class ChromaDBHandler(VectorStoreHandler):
    """This handler handles connection and execution of the ChromaDB statements."""

    # Integration name under which MindsDB registers this handler.
    name = "chromadb"

    def __init__(self, name: str, **kwargs):
        """Initialize the handler.

        Args:
            name: Name of this vector-store integration instance.
            **kwargs: Must contain ``handler_storage`` and ``connection_data``
                (the latter is validated by ``validate_connection_parameters``).
        """
        super().__init__(name)
        self.handler_storage = kwargs["handler_storage"]
        self._client = None
        self.persist_directory = None
        self.is_connected = False
        self._use_handler_storage = False

        # NOTE: this call may set self.persist_directory and
        # self._use_handler_storage as a side effect, so it must run
        # BEFORE _client_config is built below.
        config = self.validate_connection_parameters(name, **kwargs)

        self._client_config = {
            "chroma_server_host": config.host,
            "chroma_server_http_port": config.port,
            "persist_directory": self.persist_directory,
        }

        # Collection-level metadata applied on creation; configures the
        # HNSW distance metric used by ChromaDB.
        self.create_collection_metadata = {
            "hnsw:space": config.distance,
        }

51 

52 def validate_connection_parameters(self, name, **kwargs): 

53 """ 

54 Validate the connection parameters. 

55 """ 

56 

57 _config = kwargs.get("connection_data") 

58 _config["vector_store"] = name 

59 

60 config = ChromaHandlerConfig(**_config) 

61 

62 if config.persist_directory: 62 ↛ 70line 62 didn't jump to line 70 because the condition on line 62 was always true

63 if os.path.isabs(config.persist_directory): 63 ↛ 64line 63 didn't jump to line 64 because the condition on line 63 was never true

64 self.persist_directory = config.persist_directory 

65 else: 

66 # get full persistence directory from handler storage 

67 self.persist_directory = self.handler_storage.folder_get(config.persist_directory) 

68 self._use_handler_storage = True 

69 

70 return config 

71 

72 def _get_client(self): 

73 client_config = self._client_config 

74 if client_config is None: 74 ↛ 75line 74 didn't jump to line 75 because the condition on line 74 was never true

75 raise Exception("Client config is not set!") 

76 

77 # decide the client type to be used, either persistent or httpclient 

78 if client_config["persist_directory"] is not None: 78 ↛ 82line 78 didn't jump to line 82 because the condition on line 78 was always true

79 SharedSystemClient.clear_system_cache() 

80 return chromadb.PersistentClient(path=client_config["persist_directory"]) 

81 else: 

82 return chromadb.HttpClient( 

83 host=client_config["chroma_server_host"], 

84 port=client_config["chroma_server_http_port"], 

85 ) 

86 

87 def _sync(self): 

88 """Sync the database to disk if using persistent storage""" 

89 if self.persist_directory and self._use_handler_storage: 89 ↛ exitline 89 didn't return from function '_sync' because the condition on line 89 was always true

90 self.handler_storage.folder_sync(self.persist_directory) 

91 

92 def __del__(self): 

93 """Ensure proper cleanup when the handler is destroyed""" 

94 if self.is_connected: 

95 self._sync() 

96 self.disconnect() 

97 

98 def connect(self): 

99 """Connect to a ChromaDB database.""" 

100 if self.is_connected is True: 

101 return self._client 

102 

103 try: 

104 self._client = self._get_client() 

105 self.is_connected = True 

106 return self._client 

107 except Exception as e: 

108 self.is_connected = False 

109 raise Exception(f"Error connecting to ChromaDB client, {e}!") 

110 

111 def disconnect(self): 

112 """Close the database connection.""" 

113 if self.is_connected: 113 ↛ exitline 113 didn't return from function 'disconnect' because the condition on line 113 was always true

114 if hasattr(self._client, "close"): 114 ↛ 115line 114 didn't jump to line 115 because the condition on line 114 was never true

115 self._client.close() # Some ChromaDB clients have a close method 

116 self._client = None 

117 self.is_connected = False 

118 

119 def check_connection(self): 

120 """Check the connection to the ChromaDB database.""" 

121 response_code = StatusResponse(False) 

122 need_to_close = self.is_connected is False 

123 

124 try: 

125 self.connect() 

126 self._client.heartbeat() 

127 response_code.success = True 

128 except Exception as e: 

129 logger.error(f"Error connecting to ChromaDB , {e}!") 

130 response_code.error_message = str(e) 

131 finally: 

132 if response_code.success is True and need_to_close: 

133 self.disconnect() 

134 if response_code.success is False and self.is_connected is True: 

135 self.is_connected = False 

136 

137 return response_code 

138 

139 def _get_chromadb_operator(self, operator: FilterOperator) -> str: 

140 mapping = { 

141 FilterOperator.EQUAL: "$eq", 

142 FilterOperator.NOT_EQUAL: "$ne", 

143 FilterOperator.LESS_THAN: "$lt", 

144 FilterOperator.LESS_THAN_OR_EQUAL: "$lte", 

145 FilterOperator.GREATER_THAN: "$gt", 

146 FilterOperator.GREATER_THAN_OR_EQUAL: "$gte", 

147 FilterOperator.IN: "$in", 

148 FilterOperator.NOT_IN: "$nin", 

149 } 

150 

151 if operator not in mapping: 

152 raise Exception(f"Operator {operator} is not supported by ChromaDB!") 

153 

154 return mapping[operator] 

155 

156 def _translate_metadata_condition(self, conditions: List[FilterCondition]) -> Optional[dict]: 

157 """ 

158 Translate a list of FilterCondition objects a dict that can be used by ChromaDB. 

159 E.g., 

160 [ 

161 FilterCondition( 

162 column="metadata.created_at", 

163 op=FilterOperator.LESS_THAN, 

164 value="2020-01-01", 

165 ), 

166 FilterCondition( 

167 column="metadata.created_at", 

168 op=FilterOperator.GREATER_THAN, 

169 value="2019-01-01", 

170 ) 

171 ] 

172 --> 

173 { 

174 "$and": [ 

175 {"created_at": {"$lt": "2020-01-01"}}, 

176 {"created_at": {"$gt": "2019-01-01"}} 

177 ] 

178 } 

179 """ 

180 # we ignore all non-metadata conditions 

181 if conditions is None: 181 ↛ 182line 181 didn't jump to line 182 because the condition on line 181 was never true

182 return None 

183 metadata_conditions = [ 

184 condition for condition in conditions if condition.column.startswith(TableField.METADATA.value) 

185 ] 

186 if len(metadata_conditions) == 0: 186 ↛ 190line 186 didn't jump to line 190 because the condition on line 186 was always true

187 return None 

188 

189 # we translate each metadata condition into a dict 

190 chroma_db_conditions = [] 

191 for condition in metadata_conditions: 

192 metadata_key = condition.column.split(".")[-1] 

193 

194 chroma_db_conditions.append({metadata_key: {self._get_chromadb_operator(condition.op): condition.value}}) 

195 

196 # we combine all metadata conditions into a single dict 

197 metadata_condition = ( 

198 {"$and": chroma_db_conditions} if len(chroma_db_conditions) > 1 else chroma_db_conditions[0] 

199 ) 

200 return metadata_condition 

201 

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ) -> pd.DataFrame:
        """Read rows from a ChromaDB collection into a DataFrame.

        Supports metadata filters, id include/exclude filters, an
        embeddings-similarity filter (vector search), and a post-hoc
        distance filter. Returns a DataFrame with id/content/metadata/
        embeddings columns (plus distance for similarity queries).
        """
        # Force a fresh client so a persistent store re-reads state from disk.
        self.disconnect()
        self.connect()
        collection = self._client.get_collection(table_name)
        filters = self._translate_metadata_condition(conditions)

        include = ["metadatas", "documents", "embeddings"]

        # check if embedding vector filter is present
        vector_filter = (
            []
            if conditions is None
            else [condition for condition in conditions if condition.column == TableField.EMBEDDINGS.value]
        )

        # Only the first embeddings condition (if any) drives the search.
        if len(vector_filter) > 0:
            vector_filter = vector_filter[0]
        else:
            vector_filter = None
        ids_include = []
        ids_exclude = []

        # Collect id-based conditions; they are applied to the result
        # DataFrame after the ChromaDB query (see below).
        if conditions is not None:
            for condition in conditions:
                if condition.column != TableField.ID.value:
                    continue
                if condition.op == FilterOperator.EQUAL:
                    ids_include.append(condition.value)
                elif condition.op == FilterOperator.IN:
                    ids_include.extend(condition.value)
                elif condition.op == FilterOperator.NOT_EQUAL:
                    ids_exclude.append(condition.value)
                elif condition.op == FilterOperator.NOT_IN:
                    ids_exclude.extend(condition.value)

        if vector_filter is not None:
            # similarity search
            query_payload = {
                "where": filters,
                "query_embeddings": vector_filter.value if vector_filter is not None else None,
                "include": include + ["distances"],
            }

            if limit is not None:
                if len(ids_include) == 0 and len(ids_exclude) == 0:
                    query_payload["n_results"] = limit
                else:
                    # get more results if we have filters by id
                    query_payload["n_results"] = limit * 10

            # query() returns one list per query embedding; a single
            # embedding is sent, so take the first list of each field.
            result = collection.query(**query_payload)
            ids = result["ids"][0]
            documents = result["documents"][0]
            metadatas = result["metadatas"][0]
            distances = result["distances"][0]
            embeddings = result["embeddings"][0]

        else:
            # general get query
            result = collection.get(
                ids=ids_include or None,
                where=filters,
                limit=limit,
                offset=offset,
                include=include,
            )
            ids = result["ids"]
            documents = result["documents"]
            metadatas = result["metadatas"]
            embeddings = result["embeddings"]
            distances = None

        # project based on columns
        payload = {
            TableField.ID.value: ids,
            TableField.CONTENT.value: documents,
            TableField.METADATA.value: metadatas,
            TableField.EMBEDDINGS.value: list(embeddings),
        }

        if columns is not None:
            # Distance is handled separately below, so it is never projected here.
            payload = {column: payload[column] for column in columns if column != TableField.DISTANCE.value}

        # always include distance
        distance_filter = None
        distance_col = TableField.DISTANCE.value
        if distances is not None:
            payload[distance_col] = distances

        # Remember a distance condition (if any) to apply after the DataFrame
        # is built. NOTE(review): if a distance condition is given without a
        # similarity search, the distance column is absent and the filter
        # below would raise KeyError — presumably callers pair the two.
        if conditions is not None:
            for cond in conditions:
                if cond.column == distance_col:
                    distance_filter = cond
                    break

        df = pd.DataFrame(payload)
        # Apply id include/exclude filters client-side; the similarity path
        # over-fetched (limit * 10) to compensate, so re-apply the limit.
        if ids_exclude or ids_include:
            if ids_exclude:
                df = df[~df[TableField.ID.value].isin(ids_exclude)]
            if ids_include:
                df = df[df[TableField.ID.value].isin(ids_include)]
            if limit is not None:
                df = df[:limit]

        # Map SQL comparison operators onto pandas Series dunder methods to
        # filter rows on the distance column.
        if distance_filter is not None:
            op_map = {
                "<": "__lt__",
                "<=": "__le__",
                ">": "__gt__",
                ">=": "__ge__",
                "=": "__eq__",
            }
            op = op_map.get(distance_filter.op.value)
            if op:
                df = df[getattr(df[distance_col], op)(distance_filter.value)]
        return df

325 

326 def _dataframe_metadata_to_chroma_metadata(self, metadata: Union[Dict[str, str], str]) -> Optional[Dict[str, str]]: 

327 """Convert DataFrame metadata to ChromaDB compatible metadata format""" 

328 if pd.isna(metadata) or metadata is None: 

329 return None 

330 if isinstance(metadata, dict): 

331 if not metadata: 

332 # ChromaDB does not support empty metadata dicts, but it does support None. 

333 # Related: https://github.com/chroma-core/chroma/issues/791. 

334 return None 

335 # Filter out None values from the metadata dict 

336 return {k: v for k, v in metadata.items() if pd.notna(v) and v is not None} 

337 # Metadata is a string representation of a dictionary instead. 

338 try: 

339 parsed = ast.literal_eval(metadata) 

340 if isinstance(parsed, dict): 

341 # Filter out None values from the parsed dict 

342 return {k: v for k, v in parsed.items() if pd.notna(v) and v is not None} 

343 return None 

344 except (ValueError, SyntaxError): 

345 return None 

346 

347 def _process_document_ids(self, df: pd.DataFrame) -> pd.DataFrame: 

348 """ 

349 Process document IDs for ChromaDB insertion/update. 

350 Only generates IDs if none are provided, otherwise ensures IDs are strings. 

351 

352 Args: 

353 df (pd.DataFrame): Input DataFrame containing document data 

354 

355 Returns: 

356 pd.DataFrame: DataFrame with processed IDs 

357 """ 

358 df = df.copy() # Create a copy to avoid modifying the original 

359 

360 if TableField.ID.value not in df.columns: 

361 # No IDs provided - generate hash-based IDs from content 

362 df = df.drop_duplicates(subset=[TableField.CONTENT.value]) 

363 df[TableField.ID.value] = df[TableField.CONTENT.value].apply( 

364 lambda content: hashlib.sha256(content.encode()).hexdigest() 

365 ) 

366 else: 

367 # Convert IDs to strings and remove any duplicates 

368 df[TableField.ID.value] = df[TableField.ID.value].astype(str) 

369 df = df.drop_duplicates(subset=[TableField.ID.value], keep="last") 

370 

371 return df 

372 

373 def insert(self, collection_name: str, df: pd.DataFrame) -> Response: 

374 """ 

375 Insert/Upsert data into ChromaDB collection. 

376 If records with same IDs exist, they will be updated. 

377 """ 

378 self.connect() 

379 collection = self._client.get_or_create_collection(collection_name, metadata=self.create_collection_metadata) 

380 

381 # Convert metadata from string to dict if needed 

382 if TableField.METADATA.value in df.columns: 

383 df[TableField.METADATA.value] = df[TableField.METADATA.value].apply( 

384 self._dataframe_metadata_to_chroma_metadata 

385 ) 

386 # Drop rows where metadata conversion failed 

387 df = df.dropna(subset=[TableField.METADATA.value]) 

388 

389 # Convert embeddings from string to list if they are strings 

390 if TableField.EMBEDDINGS.value in df.columns and df[TableField.EMBEDDINGS.value].dtype == "object": 

391 df[TableField.EMBEDDINGS.value] = df[TableField.EMBEDDINGS.value].apply( 

392 lambda x: ast.literal_eval(x) if isinstance(x, str) else x 

393 ) 

394 

395 # Process document IDs 

396 df = self._process_document_ids(df) 

397 

398 # Extract data from DataFrame 

399 data_dict = df.to_dict(orient="list") 

400 

401 try: 

402 collection.upsert( 

403 ids=data_dict[TableField.ID.value], 

404 documents=data_dict[TableField.CONTENT.value], 

405 embeddings=data_dict.get(TableField.EMBEDDINGS.value, None), 

406 metadatas=data_dict.get(TableField.METADATA.value, None), 

407 ) 

408 self._sync() 

409 except Exception as e: 

410 logger.error(f"Error during upsert operation: {str(e)}") 

411 raise Exception(f"Failed to insert/update data: {str(e)}") 

412 return Response(RESPONSE_TYPE.OK, affected_rows=len(df)) 

413 

414 def upsert(self, table_name: str, data: pd.DataFrame): 

415 """ 

416 Alias for insert since insert handles upsert functionality 

417 """ 

418 return self.insert(table_name, data) 

419 

420 def update( 

421 self, 

422 table_name: str, 

423 data: pd.DataFrame, 

424 key_columns: List[str] = None, 

425 ): 

426 """ 

427 Update data in the ChromaDB database. 

428 """ 

429 self.connect() 

430 collection = self._client.get_collection(table_name) 

431 

432 # drop columns with all None values 

433 

434 data.dropna(axis=1, inplace=True) 

435 

436 data = data.to_dict(orient="list") 

437 

438 collection.update( 

439 ids=data[TableField.ID.value], 

440 documents=data.get(TableField.CONTENT.value), 

441 embeddings=data[TableField.EMBEDDINGS.value], 

442 metadatas=data.get(TableField.METADATA.value), 

443 ) 

444 self._sync() 

445 

446 def delete(self, table_name: str, conditions: List[FilterCondition] = None): 

447 self.connect() 

448 filters = self._translate_metadata_condition(conditions) 

449 # get id filters 

450 id_filters = [condition.value for condition in conditions if condition.column == TableField.ID.value] or None 

451 

452 if filters is None and id_filters is None: 

453 raise Exception("Delete query must have at least one condition!") 

454 collection = self._client.get_collection(table_name) 

455 collection.delete(ids=id_filters, where=filters) 

456 self._sync() 

457 

458 def create_table(self, table_name: str, if_not_exists=True): 

459 """ 

460 Create a collection with the given name in the ChromaDB database. 

461 """ 

462 self.connect() 

463 self._client.create_collection( 

464 table_name, get_or_create=if_not_exists, metadata=self.create_collection_metadata 

465 ) 

466 self._sync() 

467 

    def drop_table(self, table_name: str, if_exists=True):
        """
        Delete a collection from the ChromaDB database.

        Args:
            table_name: Name of the collection to drop.
            if_exists: When True, a missing collection is silently ignored.
        """
        self.connect()
        try:
            # NOTE: there is a bug in chromadb v0.6.3 - it delete only segments that loaded in memory,
            # so we delete them manually
            if self._client_config.get("persist_directory") is not None:
                collection = self._client.get_collection(table_name)
                # HACK: reaches into ChromaDB's private sysdb to drop every
                # segment record, then removes the segment directories from
                # disk — fragile across chromadb versions.
                segments = self._client._server._sysdb.get_segments(collection.id)
                for segment in segments:
                    self._client._server._sysdb.delete_segment(collection=collection.id, id=segment["id"])
                    shutil.rmtree(
                        os.path.join(self._client_config["persist_directory"], str(segment["id"])), ignore_errors=True
                    )

            self._client.delete_collection(table_name)
            self._sync()
        except ValueError:
            # ChromaDB raises ValueError when the collection does not exist.
            if if_exists:
                return
            else:
                raise Exception(f"Collection {table_name} does not exist!")

492 

493 def get_tables(self) -> HandlerResponse: 

494 """ 

495 Get the list of collections in the ChromaDB database. 

496 """ 

497 self.connect() 

498 collections = self._client.list_collections() 

499 collections_name = pd.DataFrame( 

500 columns=["table_name"], 

501 data=collections, 

502 ) 

503 return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name) 

504 

505 def get_columns(self, table_name: str) -> HandlerResponse: 

506 # check if collection exists 

507 self.connect() 

508 try: 

509 _ = self._client.get_collection(table_name) 

510 except ValueError: 

511 return Response( 

512 resp_type=RESPONSE_TYPE.ERROR, 

513 error_message=f"Table {table_name} does not exist!", 

514 ) 

515 return super().get_columns(table_name)