Coverage for mindsdb / integrations / handlers / chromadb_handler / chromadb_handler.py: 38%
257 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import ast
3import shutil
4import hashlib
5from typing import Dict, List, Optional, Union
7import pandas as pd
8import chromadb
9from chromadb.api.shared_system_client import SharedSystemClient
11from mindsdb.integrations.handlers.chromadb_handler.settings import ChromaHandlerConfig
12from mindsdb.integrations.libs.response import RESPONSE_TYPE
13from mindsdb.integrations.libs.response import HandlerResponse
14from mindsdb.integrations.libs.response import HandlerResponse as Response
15from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse
16from mindsdb.integrations.libs.vectordatabase_handler import (
17 FilterCondition,
18 FilterOperator,
19 TableField,
20 VectorStoreHandler,
21)
22from mindsdb.utilities import log
24logger = log.getLogger(__name__)
class ChromaDBHandler(VectorStoreHandler):
    """This handler handles connection and execution of the ChromaDB statements."""

    # handler integration name used by MindsDB to identify this vector store
    name = "chromadb"

    def __init__(self, name: str, **kwargs):
        """Initialize the handler.

        Args:
            name (str): name of this handler instance; also recorded in the
                connection config as the vector store name.
            **kwargs: must contain "handler_storage" and "connection_data"
                (the latter is consumed by validate_connection_parameters).
        """
        super().__init__(name)
        self.handler_storage = kwargs["handler_storage"]
        self._client = None
        self.persist_directory = None
        self.is_connected = False
        self._use_handler_storage = False

        # NOTE: this call may set self.persist_directory and
        # self._use_handler_storage as side effects, so it must run BEFORE
        # _client_config is assembled below.
        config = self.validate_connection_parameters(name, **kwargs)

        # persist_directory is None for remote (HTTP) connections; when set,
        # _get_client() builds a PersistentClient instead of an HttpClient.
        self._client_config = {
            "chroma_server_host": config.host,
            "chroma_server_http_port": config.port,
            "persist_directory": self.persist_directory,
        }

        # distance metric applied to collections created by this handler
        self.create_collection_metadata = {
            "hnsw:space": config.distance,
        }
52 def validate_connection_parameters(self, name, **kwargs):
53 """
54 Validate the connection parameters.
55 """
57 _config = kwargs.get("connection_data")
58 _config["vector_store"] = name
60 config = ChromaHandlerConfig(**_config)
62 if config.persist_directory: 62 ↛ 70line 62 didn't jump to line 70 because the condition on line 62 was always true
63 if os.path.isabs(config.persist_directory): 63 ↛ 64line 63 didn't jump to line 64 because the condition on line 63 was never true
64 self.persist_directory = config.persist_directory
65 else:
66 # get full persistence directory from handler storage
67 self.persist_directory = self.handler_storage.folder_get(config.persist_directory)
68 self._use_handler_storage = True
70 return config
72 def _get_client(self):
73 client_config = self._client_config
74 if client_config is None: 74 ↛ 75line 74 didn't jump to line 75 because the condition on line 74 was never true
75 raise Exception("Client config is not set!")
77 # decide the client type to be used, either persistent or httpclient
78 if client_config["persist_directory"] is not None: 78 ↛ 82line 78 didn't jump to line 82 because the condition on line 78 was always true
79 SharedSystemClient.clear_system_cache()
80 return chromadb.PersistentClient(path=client_config["persist_directory"])
81 else:
82 return chromadb.HttpClient(
83 host=client_config["chroma_server_host"],
84 port=client_config["chroma_server_http_port"],
85 )
87 def _sync(self):
88 """Sync the database to disk if using persistent storage"""
89 if self.persist_directory and self._use_handler_storage: 89 ↛ exitline 89 didn't return from function '_sync' because the condition on line 89 was always true
90 self.handler_storage.folder_sync(self.persist_directory)
92 def __del__(self):
93 """Ensure proper cleanup when the handler is destroyed"""
94 if self.is_connected:
95 self._sync()
96 self.disconnect()
98 def connect(self):
99 """Connect to a ChromaDB database."""
100 if self.is_connected is True:
101 return self._client
103 try:
104 self._client = self._get_client()
105 self.is_connected = True
106 return self._client
107 except Exception as e:
108 self.is_connected = False
109 raise Exception(f"Error connecting to ChromaDB client, {e}!")
111 def disconnect(self):
112 """Close the database connection."""
113 if self.is_connected: 113 ↛ exitline 113 didn't return from function 'disconnect' because the condition on line 113 was always true
114 if hasattr(self._client, "close"): 114 ↛ 115line 114 didn't jump to line 115 because the condition on line 114 was never true
115 self._client.close() # Some ChromaDB clients have a close method
116 self._client = None
117 self.is_connected = False
119 def check_connection(self):
120 """Check the connection to the ChromaDB database."""
121 response_code = StatusResponse(False)
122 need_to_close = self.is_connected is False
124 try:
125 self.connect()
126 self._client.heartbeat()
127 response_code.success = True
128 except Exception as e:
129 logger.error(f"Error connecting to ChromaDB , {e}!")
130 response_code.error_message = str(e)
131 finally:
132 if response_code.success is True and need_to_close:
133 self.disconnect()
134 if response_code.success is False and self.is_connected is True:
135 self.is_connected = False
137 return response_code
139 def _get_chromadb_operator(self, operator: FilterOperator) -> str:
140 mapping = {
141 FilterOperator.EQUAL: "$eq",
142 FilterOperator.NOT_EQUAL: "$ne",
143 FilterOperator.LESS_THAN: "$lt",
144 FilterOperator.LESS_THAN_OR_EQUAL: "$lte",
145 FilterOperator.GREATER_THAN: "$gt",
146 FilterOperator.GREATER_THAN_OR_EQUAL: "$gte",
147 FilterOperator.IN: "$in",
148 FilterOperator.NOT_IN: "$nin",
149 }
151 if operator not in mapping:
152 raise Exception(f"Operator {operator} is not supported by ChromaDB!")
154 return mapping[operator]
156 def _translate_metadata_condition(self, conditions: List[FilterCondition]) -> Optional[dict]:
157 """
158 Translate a list of FilterCondition objects a dict that can be used by ChromaDB.
159 E.g.,
160 [
161 FilterCondition(
162 column="metadata.created_at",
163 op=FilterOperator.LESS_THAN,
164 value="2020-01-01",
165 ),
166 FilterCondition(
167 column="metadata.created_at",
168 op=FilterOperator.GREATER_THAN,
169 value="2019-01-01",
170 )
171 ]
172 -->
173 {
174 "$and": [
175 {"created_at": {"$lt": "2020-01-01"}},
176 {"created_at": {"$gt": "2019-01-01"}}
177 ]
178 }
179 """
180 # we ignore all non-metadata conditions
181 if conditions is None: 181 ↛ 182line 181 didn't jump to line 182 because the condition on line 181 was never true
182 return None
183 metadata_conditions = [
184 condition for condition in conditions if condition.column.startswith(TableField.METADATA.value)
185 ]
186 if len(metadata_conditions) == 0: 186 ↛ 190line 186 didn't jump to line 190 because the condition on line 186 was always true
187 return None
189 # we translate each metadata condition into a dict
190 chroma_db_conditions = []
191 for condition in metadata_conditions:
192 metadata_key = condition.column.split(".")[-1]
194 chroma_db_conditions.append({metadata_key: {self._get_chromadb_operator(condition.op): condition.value}})
196 # we combine all metadata conditions into a single dict
197 metadata_condition = (
198 {"$and": chroma_db_conditions} if len(chroma_db_conditions) > 1 else chroma_db_conditions[0]
199 )
200 return metadata_condition
    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ) -> pd.DataFrame:
        """Read rows from a collection, via similarity search when an
        embeddings condition is present, otherwise via a plain get.

        Args:
            table_name: name of the ChromaDB collection.
            columns: optional projection of output columns.
            conditions: metadata / id / embeddings / distance filters.
            offset: offset for the plain-get path (ignored for similarity search).
            limit: maximum number of rows returned.

        Returns:
            pd.DataFrame with id/content/metadata/embeddings (and distance
            when a similarity search was run).
        """
        # NOTE(review): reconnect on every select — presumably to force a
        # fresh client that sees on-disk changes from other writers; confirm.
        self.disconnect()
        self.connect()
        collection = self._client.get_collection(table_name)
        filters = self._translate_metadata_condition(conditions)

        include = ["metadatas", "documents", "embeddings"]

        # check if embedding vector filter is present
        vector_filter = (
            []
            if conditions is None
            else [condition for condition in conditions if condition.column == TableField.EMBEDDINGS.value]
        )

        # only the first embeddings condition is used; extras are ignored
        if len(vector_filter) > 0:
            vector_filter = vector_filter[0]
        else:
            vector_filter = None
        ids_include = []
        ids_exclude = []

        # collect id constraints; EQUAL/IN become an include list,
        # NOT_EQUAL/NOT_IN an exclude list (applied post-query below)
        if conditions is not None:
            for condition in conditions:
                if condition.column != TableField.ID.value:
                    continue
                if condition.op == FilterOperator.EQUAL:
                    ids_include.append(condition.value)
                elif condition.op == FilterOperator.IN:
                    ids_include.extend(condition.value)
                elif condition.op == FilterOperator.NOT_EQUAL:
                    ids_exclude.append(condition.value)
                elif condition.op == FilterOperator.NOT_IN:
                    ids_exclude.extend(condition.value)

        if vector_filter is not None:
            # similarity search
            query_payload = {
                "where": filters,
                "query_embeddings": vector_filter.value if vector_filter is not None else None,
                "include": include + ["distances"],
            }

            if limit is not None:
                if len(ids_include) == 0 and len(ids_exclude) == 0:
                    query_payload["n_results"] = limit
                else:
                    # get more results if we have filters by id, since the id
                    # filtering happens in pandas after the query
                    query_payload["n_results"] = limit * 10

            # query() returns one result list per query embedding; only the
            # first query's results are used here
            result = collection.query(**query_payload)
            ids = result["ids"][0]
            documents = result["documents"][0]
            metadatas = result["metadatas"][0]
            distances = result["distances"][0]
            embeddings = result["embeddings"][0]

        else:
            # general get query; id includes can be pushed down to ChromaDB
            result = collection.get(
                ids=ids_include or None,
                where=filters,
                limit=limit,
                offset=offset,
                include=include,
            )
            ids = result["ids"]
            documents = result["documents"]
            metadatas = result["metadatas"]
            embeddings = result["embeddings"]
            distances = None

        # project based on columns
        payload = {
            TableField.ID.value: ids,
            TableField.CONTENT.value: documents,
            TableField.METADATA.value: metadatas,
            TableField.EMBEDDINGS.value: list(embeddings),
        }

        if columns is not None:
            payload = {column: payload[column] for column in columns if column != TableField.DISTANCE.value}

        # always include distance
        distance_filter = None
        distance_col = TableField.DISTANCE.value
        if distances is not None:
            payload[distance_col] = distances

        # remember any condition on the distance column to apply it below
        if conditions is not None:
            for cond in conditions:
                if cond.column == distance_col:
                    distance_filter = cond
                    break

        df = pd.DataFrame(payload)
        # apply id include/exclude lists in pandas, then re-apply the limit
        # (the query may have over-fetched limit * 10 rows above)
        if ids_exclude or ids_include:
            if ids_exclude:
                df = df[~df[TableField.ID.value].isin(ids_exclude)]
            if ids_include:
                df = df[df[TableField.ID.value].isin(ids_include)]
            if limit is not None:
                df = df[:limit]

        # filter on the distance column using the pandas comparison dunders
        if distance_filter is not None:
            op_map = {
                "<": "__lt__",
                "<=": "__le__",
                ">": "__gt__",
                ">=": "__ge__",
                "=": "__eq__",
            }
            op = op_map.get(distance_filter.op.value)
            if op:
                df = df[getattr(df[distance_col], op)(distance_filter.value)]
        return df
326 def _dataframe_metadata_to_chroma_metadata(self, metadata: Union[Dict[str, str], str]) -> Optional[Dict[str, str]]:
327 """Convert DataFrame metadata to ChromaDB compatible metadata format"""
328 if pd.isna(metadata) or metadata is None:
329 return None
330 if isinstance(metadata, dict):
331 if not metadata:
332 # ChromaDB does not support empty metadata dicts, but it does support None.
333 # Related: https://github.com/chroma-core/chroma/issues/791.
334 return None
335 # Filter out None values from the metadata dict
336 return {k: v for k, v in metadata.items() if pd.notna(v) and v is not None}
337 # Metadata is a string representation of a dictionary instead.
338 try:
339 parsed = ast.literal_eval(metadata)
340 if isinstance(parsed, dict):
341 # Filter out None values from the parsed dict
342 return {k: v for k, v in parsed.items() if pd.notna(v) and v is not None}
343 return None
344 except (ValueError, SyntaxError):
345 return None
347 def _process_document_ids(self, df: pd.DataFrame) -> pd.DataFrame:
348 """
349 Process document IDs for ChromaDB insertion/update.
350 Only generates IDs if none are provided, otherwise ensures IDs are strings.
352 Args:
353 df (pd.DataFrame): Input DataFrame containing document data
355 Returns:
356 pd.DataFrame: DataFrame with processed IDs
357 """
358 df = df.copy() # Create a copy to avoid modifying the original
360 if TableField.ID.value not in df.columns:
361 # No IDs provided - generate hash-based IDs from content
362 df = df.drop_duplicates(subset=[TableField.CONTENT.value])
363 df[TableField.ID.value] = df[TableField.CONTENT.value].apply(
364 lambda content: hashlib.sha256(content.encode()).hexdigest()
365 )
366 else:
367 # Convert IDs to strings and remove any duplicates
368 df[TableField.ID.value] = df[TableField.ID.value].astype(str)
369 df = df.drop_duplicates(subset=[TableField.ID.value], keep="last")
371 return df
373 def insert(self, collection_name: str, df: pd.DataFrame) -> Response:
374 """
375 Insert/Upsert data into ChromaDB collection.
376 If records with same IDs exist, they will be updated.
377 """
378 self.connect()
379 collection = self._client.get_or_create_collection(collection_name, metadata=self.create_collection_metadata)
381 # Convert metadata from string to dict if needed
382 if TableField.METADATA.value in df.columns:
383 df[TableField.METADATA.value] = df[TableField.METADATA.value].apply(
384 self._dataframe_metadata_to_chroma_metadata
385 )
386 # Drop rows where metadata conversion failed
387 df = df.dropna(subset=[TableField.METADATA.value])
389 # Convert embeddings from string to list if they are strings
390 if TableField.EMBEDDINGS.value in df.columns and df[TableField.EMBEDDINGS.value].dtype == "object":
391 df[TableField.EMBEDDINGS.value] = df[TableField.EMBEDDINGS.value].apply(
392 lambda x: ast.literal_eval(x) if isinstance(x, str) else x
393 )
395 # Process document IDs
396 df = self._process_document_ids(df)
398 # Extract data from DataFrame
399 data_dict = df.to_dict(orient="list")
401 try:
402 collection.upsert(
403 ids=data_dict[TableField.ID.value],
404 documents=data_dict[TableField.CONTENT.value],
405 embeddings=data_dict.get(TableField.EMBEDDINGS.value, None),
406 metadatas=data_dict.get(TableField.METADATA.value, None),
407 )
408 self._sync()
409 except Exception as e:
410 logger.error(f"Error during upsert operation: {str(e)}")
411 raise Exception(f"Failed to insert/update data: {str(e)}")
412 return Response(RESPONSE_TYPE.OK, affected_rows=len(df))
414 def upsert(self, table_name: str, data: pd.DataFrame):
415 """
416 Alias for insert since insert handles upsert functionality
417 """
418 return self.insert(table_name, data)
420 def update(
421 self,
422 table_name: str,
423 data: pd.DataFrame,
424 key_columns: List[str] = None,
425 ):
426 """
427 Update data in the ChromaDB database.
428 """
429 self.connect()
430 collection = self._client.get_collection(table_name)
432 # drop columns with all None values
434 data.dropna(axis=1, inplace=True)
436 data = data.to_dict(orient="list")
438 collection.update(
439 ids=data[TableField.ID.value],
440 documents=data.get(TableField.CONTENT.value),
441 embeddings=data[TableField.EMBEDDINGS.value],
442 metadatas=data.get(TableField.METADATA.value),
443 )
444 self._sync()
446 def delete(self, table_name: str, conditions: List[FilterCondition] = None):
447 self.connect()
448 filters = self._translate_metadata_condition(conditions)
449 # get id filters
450 id_filters = [condition.value for condition in conditions if condition.column == TableField.ID.value] or None
452 if filters is None and id_filters is None:
453 raise Exception("Delete query must have at least one condition!")
454 collection = self._client.get_collection(table_name)
455 collection.delete(ids=id_filters, where=filters)
456 self._sync()
458 def create_table(self, table_name: str, if_not_exists=True):
459 """
460 Create a collection with the given name in the ChromaDB database.
461 """
462 self.connect()
463 self._client.create_collection(
464 table_name, get_or_create=if_not_exists, metadata=self.create_collection_metadata
465 )
466 self._sync()
    def drop_table(self, table_name: str, if_exists=True):
        """
        Delete a collection from the ChromaDB database.

        Args:
            table_name (str): collection to drop.
            if_exists (bool): when True, a missing collection is ignored;
                otherwise an Exception is raised.
        """
        self.connect()
        try:
            # NOTE: workaround for a bug in chromadb v0.6.3 — delete_collection
            # only removes segments already loaded in memory, so persistent
            # segments are deleted manually here via chromadb's PRIVATE
            # _server._sysdb API (version-sensitive; revisit on upgrade).
            if self._client_config.get("persist_directory") is not None:
                collection = self._client.get_collection(table_name)
                segments = self._client._server._sysdb.get_segments(collection.id)
                for segment in segments:
                    self._client._server._sysdb.delete_segment(collection=collection.id, id=segment["id"])
                    # each segment has its own on-disk directory under the
                    # persist directory; remove it best-effort
                    shutil.rmtree(
                        os.path.join(self._client_config["persist_directory"], str(segment["id"])), ignore_errors=True
                    )

            self._client.delete_collection(table_name)
            self._sync()
        except ValueError:
            # chromadb raises ValueError for an unknown collection
            if if_exists:
                return
            else:
                raise Exception(f"Collection {table_name} does not exist!")
493 def get_tables(self) -> HandlerResponse:
494 """
495 Get the list of collections in the ChromaDB database.
496 """
497 self.connect()
498 collections = self._client.list_collections()
499 collections_name = pd.DataFrame(
500 columns=["table_name"],
501 data=collections,
502 )
503 return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name)
505 def get_columns(self, table_name: str) -> HandlerResponse:
506 # check if collection exists
507 self.connect()
508 try:
509 _ = self._client.get_collection(table_name)
510 except ValueError:
511 return Response(
512 resp_type=RESPONSE_TYPE.ERROR,
513 error_message=f"Table {table_name} does not exist!",
514 )
515 return super().get_columns(table_name)