Coverage for mindsdb / integrations / handlers / milvus_handler / milvus_handler.py: 0%
173 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import json
from typing import List, Optional

import pandas as pd
from pymilvus import MilvusClient, CollectionSchema, DataType, FieldSchema

from mindsdb.integrations.libs.response import (
    RESPONSE_TYPE,
    HandlerResponse,
    HandlerResponse as Response,
    HandlerStatusResponse as StatusResponse,
)
from mindsdb.integrations.libs.vectordatabase_handler import (
    FilterCondition,
    FilterOperator,
    TableField,
    VectorStoreHandler,
)
from mindsdb.utilities import log

# Module-level logger used by the handler for connection errors.
logger = log.getLogger(__name__)
class MilvusHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Milvus statements."""

    name = "milvus"

    # Upper bound (exclusive) enforced on offset + limit for searches and queries.
    _API_LIMIT = 16384

    def __init__(self, name: str, **kwargs):
        """Read handler settings from ``connection_data`` and connect to Milvus.

        Args:
            name: Name of this handler instance.
            **kwargs: Must contain ``connection_data`` (dict). The keys
                ``search_default_limit`` and the ``search_*`` / ``create_*`` keys
                listed in the defaults below tune searching and table creation;
                all remaining keys are forwarded to ``MilvusClient``.

        Raises:
            KeyError: if ``connection_data`` is missing from ``kwargs``.
        """
        super().__init__(name)
        self.milvus_client = None
        # Set before anything that can fail so that __del__ is always safe.
        self.is_connected = False
        self._connection_data = kwargs["connection_data"]

        # Parameters used while searching.
        self._search_limit = self._connection_data.get("search_default_limit", 100)
        self._search_params = {
            "search_metric_type": "L2",
            "search_ignore_growing": False,
            "search_params": {"nprobe": 10},
        }
        self._override_defaults(self._search_params)

        # Parameters used for creating tables.
        self._create_table_params = {
            "create_auto_id": False,
            "create_id_max_len": 64,
            "create_embedding_dim": 8,
            "create_dynamic_field": True,
            "create_content_max_len": 200,
            "create_content_default_value": "",
            "create_schema_description": "MindsDB generated table",
            "create_alias": "default",
            "create_index_params": {},
            "create_index_metric_type": "L2",
            "create_index_type": "AUTOINDEX",
        }
        self._override_defaults(self._create_table_params)

        self.connect()

    def _override_defaults(self, defaults: dict) -> None:
        """Overwrite, in place, every key of ``defaults`` that is also present in connection_data."""
        for key in defaults:
            if key in self._connection_data:
                defaults[key] = self._connection_data[key]

    def _get_client_kwargs(self) -> dict:
        """Return connection_data without the handler-only tuning keys (the ``MilvusClient`` arguments)."""
        handler_keys = {"search_default_limit", *self._search_params, *self._create_table_params}
        return {key: value for key, value in self._connection_data.items() if key not in handler_keys}

    @staticmethod
    def _load_json_param(value) -> dict:
        """Return a dict parameter, decoding it first when it arrives as a JSON string; None becomes {}."""
        if isinstance(value, str):
            value = json.loads(value)
        return dict(value or {})

    def __del__(self):
        # getattr: __init__ may have raised before ``is_connected`` existed.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def connect(self):
        """Connect to a Milvus database.

        Errors are logged rather than raised; ``is_connected`` reflects the outcome.
        """
        if self.is_connected is True:
            return
        try:
            self.milvus_client = MilvusClient(**self._get_client_kwargs())
            self.is_connected = True
        except Exception as e:
            logger.error(f"Error connecting to Milvus client: {e}!")
            self.is_connected = False

    def disconnect(self):
        """Close the database connection."""
        if self.is_connected is False:
            return
        try:
            self.milvus_client.close()
        finally:
            # Never retry a failed close from __del__.
            self.is_connected = False

    def check_connection(self):
        """Check the connection to the Milvus database.

        Connects if needed and performs a ``list_collections`` round-trip so that
        an unreachable server is reported as a failure.

        Returns:
            StatusResponse with ``success`` set and, on failure, ``error_message``.
        """
        response_code = StatusResponse(False)
        try:
            self.connect()
            if self.milvus_client is None:
                raise ConnectionError("Milvus client is not initialised")
            self.milvus_client.list_collections()
            response_code.success = True
        except Exception as e:
            logger.error(f"Error checking Milvus connection: {e}!")
            response_code.error_message = str(e)
        return response_code

    def get_tables(self) -> HandlerResponse:
        """Get the list of collections in the Milvus database as a TABLE_NAME frame."""
        collections = self.milvus_client.list_collections()
        collections_name = pd.DataFrame(columns=["TABLE_NAME"], data=list(collections))
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name)

    def drop_table(self, table_name: str, if_exists=True):
        """Delete a collection from the Milvus database.

        Args:
            table_name: Collection to drop.
            if_exists: When True, a missing collection is not an error.

        Raises:
            Exception: if the collection is missing and ``if_exists`` is False,
                or if Milvus fails to drop an existing collection.
        """
        if not self.milvus_client.has_collection(collection_name=table_name):
            if if_exists:
                return
            raise Exception(f"Error dropping table '{table_name}': table does not exist")
        try:
            self.milvus_client.drop_collection(collection_name=table_name)
        except Exception as e:
            raise Exception(f"Error dropping table '{table_name}': {e}") from e

    def _get_milvus_operator(self, operator: FilterOperator) -> str:
        """Map a MindsDB ``FilterOperator`` to its Milvus expression operator.

        Raises:
            Exception: if the operator has no Milvus equivalent.
        """
        mapping = {
            FilterOperator.EQUAL: "==",
            FilterOperator.NOT_EQUAL: "!=",
            FilterOperator.LESS_THAN: "<",
            FilterOperator.LESS_THAN_OR_EQUAL: "<=",
            FilterOperator.GREATER_THAN: ">",
            FilterOperator.GREATER_THAN_OR_EQUAL: ">=",
            FilterOperator.IN: "in",
            FilterOperator.NOT_IN: "not in",
            FilterOperator.LIKE: "like",
            FilterOperator.NOT_LIKE: "not like",
        }
        if operator not in mapping:
            raise Exception(f"Operator {operator} is not supported by Milvus!")
        return mapping[operator]

    def _format_milvus_value(self, value) -> str:
        """Render a filter value as a Milvus expression literal.

        Strings are single-quoted with backslashes and single quotes escaped,
        lists/tuples become ``[a, b]`` with each element rendered recursively,
        and any other value uses ``str()``.
        """
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        if isinstance(value, (list, tuple)):
            return "[" + ", ".join(self._format_milvus_value(item) for item in value) + "]"
        return str(value)

    def _translate_conditions(self, conditions: Optional[List[FilterCondition]], exclude_id: bool = True) -> Optional[str]:
        """
        Translate a list of FilterCondition objects a string that can be used by Milvus.
        E.g.,
        [
            FilterCondition(
                column="metadata.price",
                op=FilterOperator.LESS_THAN,
                value=1000,
            ),
            FilterCondition(
                column="metadata.price",
                op=FilterOperator.GREATER_THAN,
                value=300,
            )
        ]
        Is converted to: "(price < 1000) and (price > 300)"

        Only conditions whose column starts with the metadata or id prefix are
        translated; the field name is the part after the last '.'. The input
        conditions are not modified.

        Args:
            conditions: Conditions to translate; may be None or empty.
            exclude_id: Accepted for backward compatibility but currently not
                honoured - id conditions are always translated.

        Returns:
            The combined expression, or None when nothing was translated.
        """
        if not conditions:
            return None
        prefixes = (TableField.METADATA.value, TableField.ID.value)
        milvus_conditions = []
        for condition in conditions:
            # Ignore all non-metadata / non-id conditions (e.g. search_vector).
            if not condition.column.startswith(prefixes):
                continue
            field = condition.column.split(".")[-1]
            operator = self._get_milvus_operator(condition.op)
            literal = self._format_milvus_value(condition.value)
            milvus_conditions.append(f"({field} {operator} {literal})")
        return " and ".join(milvus_conditions) or None

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ):
        """Read rows from a collection, as a vector search when a search vector is given.

        A condition on ``TableField.SEARCH_VECTOR`` turns the read into an ANN
        search over ``TableField.EMBEDDINGS`` (only the first vector's hits are
        returned); id/metadata conditions become the Milvus filter expression.

        Args:
            table_name: Collection to read.
            columns: Columns to return; defaults depend on the read mode.
            conditions: Filter conditions (may include the search vector).
            offset: Number of leading results to skip.
            limit: Maximum number of results; defaults to ``search_default_limit``.

        Returns:
            pd.DataFrame of results. Vector searches always include id and distance.

        Raises:
            Exception: if offset + limit reaches the Milvus API limit.
        """
        self.milvus_client.load_collection(collection_name=table_name)
        conditions = conditions or []

        # Search vectors may arrive JSON-encoded, just like embeddings in insert().
        vector_filter = [
            json.loads(condition.value) if isinstance(condition.value, str) else condition.value
            for condition in conditions
            if condition.column == TableField.SEARCH_VECTOR.value
        ]

        effective_limit = self._search_limit if limit is None else limit
        if (offset or 0) + effective_limit >= self._API_LIMIT:
            raise Exception(f"Sum of limit and offset should be less than {self._API_LIMIT}")

        search_arguments = {
            # Milvus expects an empty string, not None, when there is no filter.
            "filter": self._translate_conditions(conditions) or "",
            "limit": effective_limit,
        }
        if offset is not None:
            search_arguments["offset"] = offset

        distance_column = TableField.DISTANCE.value
        if columns:
            requested = list(columns)
        elif vector_filter:
            requested = [field["name"] for field in self.SCHEMA]
        else:
            requested = [TableField.ID.value, TableField.CONTENT.value, TableField.EMBEDDINGS.value]
        # Distance is produced by the search itself; it is never a stored field.
        search_arguments["output_fields"] = [column for column in requested if column != distance_column]

        if not vector_filter:
            # Basic (scalar) query.
            results = self.milvus_client.query(table_name, **search_arguments)
            return pd.DataFrame.from_records(results)

        # Vector search.
        search_arguments["data"] = vector_filter
        search_arguments["anns_field"] = TableField.EMBEDDINGS.value
        search_arguments["search_params"] = {
            "metric_type": self._search_params["search_metric_type"],
            "ignore_growing": self._search_params["search_ignore_growing"],
            "params": self._load_json_param(self._search_params["search_params"]),
        }
        hits = self.milvus_client.search(table_name, **search_arguments)[0]

        columns_required = [TableField.ID.value, distance_column] + [
            column
            for column in (TableField.CONTENT.value, TableField.EMBEDDINGS.value)
            if column in requested
        ]
        data = {column: [] for column in columns_required}
        for hit in hits:
            entity = hit["entity"]
            for column in columns_required:
                data[column].append(hit["distance"] if column == distance_column else entity.get(column))
        return pd.DataFrame(data)

    def create_table(self, table_name: str, if_not_exists=True):
        """Create a collection with default parameters in the Milvus database as described in documentation.

        The collection has a VARCHAR primary key ``id``, a VARCHAR ``content``
        field and a FLOAT_VECTOR ``embeddings`` field, plus an index on the
        embeddings; sizes, index type/metric/params come from the ``create_*``
        connection parameters.

        Args:
            table_name: Collection to create.
            if_not_exists: When True, an existing collection is left untouched.

        Raises:
            Exception: if the collection exists and ``if_not_exists`` is False.
        """
        if self.milvus_client.has_collection(collection_name=table_name):
            if if_not_exists:
                return
            raise Exception(f"Table '{table_name}' already exists")

        params = self._create_table_params
        id_field = FieldSchema(
            name=TableField.ID.value,
            dtype=DataType.VARCHAR,
            is_primary=True,
            max_length=params["create_id_max_len"],
            auto_id=params["create_auto_id"],
        )
        embeddings_field = FieldSchema(
            name=TableField.EMBEDDINGS.value,
            dtype=DataType.FLOAT_VECTOR,
            dim=params["create_embedding_dim"],
        )
        content_field = FieldSchema(
            name=TableField.CONTENT.value,
            dtype=DataType.VARCHAR,
            max_length=params["create_content_max_len"],
            default_value=params["create_content_default_value"],
        )
        schema = CollectionSchema(
            fields=[id_field, content_field, embeddings_field],
            description=params["create_schema_description"],
            enable_dynamic_field=params["create_dynamic_field"],
        )
        self.milvus_client.create_collection(collection_name=table_name, schema=schema)

        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name=TableField.EMBEDDINGS.value,
            index_type=params["create_index_type"],
            metric_type=params["create_index_metric_type"],
            # Previously read the non-existent key "create_params", so user values were ignored.
            params=self._load_json_param(params["create_index_params"]),
        )
        self.milvus_client.create_index(collection_name=table_name, index_params=index_params)

    @staticmethod
    def _metadata_to_dict(value):
        """Decode one metadata cell: JSON strings are parsed, None/NaN become {}, other values pass through."""
        if isinstance(value, str):
            value = json.loads(value)
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return {}
        return value

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """Insert data into the Milvus collection.

        Metadata keys are flattened into top-level fields (stored as dynamic
        fields when enabled) and JSON-encoded embeddings are decoded to lists.

        Args:
            table_name: Target collection.
            data: Rows to insert.
            columns: Optional subset of ``data`` columns to insert.
        """
        self.milvus_client.load_collection(collection_name=table_name)
        if columns:
            data = data[columns]

        metadata_column = TableField.METADATA.value
        if metadata_column in data.columns:
            metadata_rows = [self._metadata_to_dict(value) for value in data[metadata_column].to_list()]
            # Reuse data's index so the axis=1 concat aligns rows even for a non-default index.
            metadata_frame = pd.DataFrame(metadata_rows, index=data.index)
            data = pd.concat([data.drop(columns=metadata_column), metadata_frame], axis=1)

        records = data.to_dict(orient="records")
        embeddings_column = TableField.EMBEDDINGS.value
        for record in records:
            if isinstance(record.get(embeddings_column), str):
                record[embeddings_column] = json.loads(record[embeddings_column])
        self.milvus_client.insert(table_name, records)

    def delete(
        self, table_name: str, conditions: List[FilterCondition] = None
    ):
        """Delete rows matching ``conditions`` from a collection (no-op if it does not exist).

        EQUAL conditions are rewritten as IN with a list value; the caller's
        condition objects are not modified.

        Raises:
            Exception: if no id/metadata condition is given.
        """
        # delete only supports IN operator
        normalised = []
        for condition in conditions or []:
            if condition.op in (FilterOperator.EQUAL, FilterOperator.IN):
                value = condition.value if isinstance(condition.value, list) else [condition.value]
                condition = FilterCondition(column=condition.column, op=FilterOperator.IN, value=value)
            normalised.append(condition)

        filters = self._translate_conditions(normalised, exclude_id=False)
        if not filters:
            raise Exception("Some filters are required, use DROP TABLE to delete everything")
        if self.milvus_client.has_collection(collection_name=table_name):
            self.milvus_client.delete(table_name, filter=filters)

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Get columns in a Milvus collection.

        Returns the entries of ``self.SCHEMA`` whose names exist as fields in the
        collection, as COLUMN_NAME / DATA_TYPE rows, or an ERROR response.
        """
        try:
            exists = self.milvus_client.has_collection(collection_name=table_name)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: {e}",
            )
        if not exists:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: '{table_name}' does not exist",
            )
        try:
            description = self.milvus_client.describe_collection(collection_name=table_name)
            field_names = {field["name"] for field in description["fields"]}
            # Each SCHEMA entry is (name, data_type); explicit columns keep an empty result well-formed.
            data = pd.DataFrame(
                [tuple(field.values()) for field in self.SCHEMA if field["name"] in field_names],
                columns=["COLUMN_NAME", "DATA_TYPE"],
            )
            return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=data)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: {e}",
            )