Coverage for mindsdb / integrations / handlers / milvus_handler / milvus_handler.py: 0%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Optional 

2 

3import pandas as pd 

4import json 

5from pymilvus import MilvusClient, CollectionSchema, DataType, FieldSchema 

6 

7from mindsdb.integrations.libs.response import RESPONSE_TYPE 

8from mindsdb.integrations.libs.response import HandlerResponse 

9from mindsdb.integrations.libs.response import HandlerResponse as Response 

10from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

11from mindsdb.integrations.libs.vectordatabase_handler import FilterCondition, FilterOperator, TableField, VectorStoreHandler 

12from mindsdb.utilities import log 

13 

logger = log.getLogger(__name__)


class MilvusHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Milvus statements.

    Connection-data keys listed in ``__init__`` (``search_default_limit``,
    ``search_*`` and ``create_*``) tune searching and table creation; every
    other key is forwarded to ``pymilvus.MilvusClient``.
    """

    name = "milvus"

    # MindsDB filter operator -> Milvus boolean-expression operator.
    _OPERATOR_MAP = {
        FilterOperator.EQUAL: "==",
        FilterOperator.NOT_EQUAL: "!=",
        FilterOperator.LESS_THAN: "<",
        FilterOperator.LESS_THAN_OR_EQUAL: "<=",
        FilterOperator.GREATER_THAN: ">",
        FilterOperator.GREATER_THAN_OR_EQUAL: ">=",
        FilterOperator.IN: "in",
        FilterOperator.NOT_IN: "not in",
        FilterOperator.LIKE: "like",
        FilterOperator.NOT_LIKE: "not like",
    }

    # NOTE: According to the Milvus API, offset + limit must be less than this value.
    _MAX_RESULT_WINDOW = 16384

    def __init__(self, name: str, **kwargs):
        super().__init__(name)
        # Initialise connection state first so __del__ is safe even if anything below raises.
        self.milvus_client = None
        self.is_connected = False
        self._last_connection_error = None
        self._connection_data = kwargs["connection_data"]

        # Parameters used while searching.
        self._search_limit = self._connection_data.get("search_default_limit", 100)
        self._search_params = {
            "search_metric_type": "L2",
            "search_ignore_growing": False,
            "search_params": {"nprobe": 10},
        }
        for param_name in self._search_params:
            if param_name in self._connection_data:
                self._search_params[param_name] = self._connection_data[param_name]

        # Parameters used for creating tables.
        self._create_table_params = {
            "create_auto_id": False,
            "create_id_max_len": 64,
            "create_embedding_dim": 8,
            "create_dynamic_field": True,
            "create_content_max_len": 200,
            "create_content_default_value": "",
            "create_schema_description": "MindsDB generated table",
            "create_alias": "default",
            "create_index_params": {},
            "create_index_metric_type": "L2",
            "create_index_type": "AUTOINDEX",
        }
        for param_name in self._create_table_params:
            if param_name in self._connection_data:
                self._create_table_params[param_name] = self._connection_data[param_name]

        # Only non-handler keys are meant for establishing the connection; the caller's
        # connection_data dict itself is left untouched.
        handler_keys = {"search_default_limit", *self._search_params, *self._create_table_params}
        self._client_kwargs = {
            key: value for key, value in self._connection_data.items() if key not in handler_keys
        }
        self.connect()

    def __del__(self):
        # getattr: __init__ may have failed before is_connected was assigned.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def connect(self):
        """Connect to a Milvus database.

        Errors are logged and remembered for ``check_connection`` rather than raised;
        ``is_connected`` reflects the outcome.
        """
        if self.is_connected is True:
            return
        try:
            self.milvus_client = MilvusClient(**self._client_kwargs)
            self.is_connected = True
            self._last_connection_error = None
        except Exception as e:
            logger.error(f"Error connecting to Milvus client: {e}!")
            self._last_connection_error = str(e)
            self.is_connected = False

    def disconnect(self):
        """Close the database connection if one is open."""
        if self.is_connected is False:
            return
        self.milvus_client.close()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Check the connection to the Milvus database.

        Reconnects if needed, then performs a cheap ``list_collections`` round-trip so
        that an unreachable server is reported as a failure.
        """
        response_code = StatusResponse(False)
        if not self.is_connected:
            self.connect()
        if self.milvus_client is None or not self.is_connected:
            response_code.error_message = self._last_connection_error or "Could not connect to Milvus"
            return response_code
        try:
            self.milvus_client.list_collections()
            response_code.success = True
        except Exception as e:
            logger.error(f"Error checking Milvus connection: {e}!")
            response_code.error_message = str(e)
        return response_code

    def get_tables(self) -> HandlerResponse:
        """Get the list of collections in the Milvus database as a ``TABLE_NAME`` frame."""
        collections = self.milvus_client.list_collections()
        collections_name = pd.DataFrame(columns=["TABLE_NAME"], data=list(collections))
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name)

    def drop_table(self, table_name: str, if_exists=True):
        """Delete a collection from the Milvus database.

        A missing collection is ignored when ``if_exists`` is true and raises otherwise.
        Failures of the drop itself are always raised (they are not "missing table").
        """
        if not self.milvus_client.has_collection(collection_name=table_name):
            if if_exists:
                return
            raise Exception(f"Error dropping table '{table_name}': table does not exist")
        try:
            self.milvus_client.drop_collection(collection_name=table_name)
        except Exception as e:
            raise Exception(f"Error dropping table '{table_name}': {e}") from e

    def _get_milvus_operator(self, operator: FilterOperator) -> str:
        """Map a MindsDB ``FilterOperator`` to its Milvus expression operator."""
        try:
            return self._OPERATOR_MAP[operator]
        except KeyError:
            raise Exception(f"Operator {operator} is not supported by Milvus!") from None

    @staticmethod
    def _to_milvus_literal(value: object) -> str:
        """Render a Python value as a Milvus expression literal.

        Strings are single-quoted with backslashes and quotes escaped; lists, tuples
        and sets become ``[a, b]``; anything else uses ``str()``.
        """
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        if isinstance(value, (list, tuple, set)):
            return "[" + ", ".join(MilvusHandler._to_milvus_literal(item) for item in value) + "]"
        return str(value)

    def _translate_conditions(self, conditions: Optional[List[FilterCondition]], exclude_id: bool = True) -> Optional[str]:
        """
        Translate a list of FilterCondition objects into a string that can be used by Milvus.
        E.g.,
            [
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.LESS_THAN,
                    value=1000,
                ),
                FilterCondition(
                    column="metadata.price",
                    op=FilterOperator.GREATER_THAN,
                    value=300,
                )
            ]
        Is converted to: "(price < 1000) and (price > 300)"

        Only conditions on metadata columns (``metadata.<key>``) and the id column are
        translated; the field name is the part after the last dot. The input
        conditions are not modified. Returns None when nothing is translatable.

        ``exclude_id`` is accepted for API compatibility but currently has no effect:
        id conditions are always translated.
        """
        if not conditions:
            return None
        milvus_conditions = [
            f"({condition.column.split('.')[-1]} "
            f"{self._get_milvus_operator(condition.op)} "
            f"{self._to_milvus_literal(condition.value)})"
            for condition in conditions
            if condition.column.startswith(TableField.METADATA.value) or condition.column.startswith(TableField.ID.value)
        ]
        # Combine all conditions into a single conjunction.
        return " and ".join(milvus_conditions) or None

    @staticmethod
    def _extract_search_vectors(conditions: Optional[List[FilterCondition]]) -> list:
        """Collect ``search_vector`` condition values, decoding JSON-string vectors."""
        vectors = []
        for condition in conditions or []:
            if condition.column != TableField.SEARCH_VECTOR.value:
                continue
            value = condition.value
            if isinstance(value, str):
                value = json.loads(value)
            vectors.append(value)
        return vectors

    def select(
        self,
        table_name: str,
        columns: Optional[List[str]] = None,
        conditions: Optional[List[FilterCondition]] = None,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> pd.DataFrame:
        """Read rows from a collection.

        With a ``search_vector`` condition an ANN search over the embeddings field is run
        (hits of the first query vector only, plus a ``distance`` column); otherwise a
        scalar ``query`` is run. Metadata/id conditions become the Milvus filter in both
        cases. ``limit`` defaults to the configured ``search_default_limit``.

        Raises:
            Exception: if the effective limit plus offset reaches the Milvus result window.
        """
        self.milvus_client.load_collection(collection_name=table_name)

        effective_limit = self._search_limit if limit is None else limit
        if effective_limit + (offset or 0) >= self._MAX_RESULT_WINDOW:
            raise Exception(f"Sum of limit and offset should be less than {self._MAX_RESULT_WINDOW}")

        common_args = {
            "filter": self._translate_conditions(conditions) or "",
            "limit": effective_limit,
        }
        if offset is not None:
            common_args["offset"] = offset

        output_fields = list(columns) if columns else [
            TableField.ID.value,
            TableField.CONTENT.value,
            TableField.EMBEDDINGS.value,
        ]

        query_vectors = self._extract_search_vectors(conditions)
        if query_vectors:
            return self._vector_search(table_name, query_vectors, output_fields, common_args)

        # Basic (scalar) search.
        records = self.milvus_client.query(table_name, output_fields=output_fields, **common_args)
        if not records:
            # Keep the requested columns so downstream column selection still works.
            return pd.DataFrame(columns=output_fields)
        return pd.DataFrame.from_records(records)

    def _vector_search(self, table_name: str, query_vectors: list, output_fields: List[str], common_args: dict) -> pd.DataFrame:
        """Run an ANN search and shape the first query vector's hits into a DataFrame."""
        # distance is computed by Milvus, not stored, so it must not be requested as a field;
        # id is always needed to populate the id column.
        fetch_fields = [field for field in output_fields if field != TableField.DISTANCE.value]
        if TableField.ID.value not in fetch_fields:
            fetch_fields.insert(0, TableField.ID.value)
        search_params = {
            "metric_type": self._search_params["search_metric_type"],
            "ignore_growing": self._search_params["search_ignore_growing"],
            # Index-specific params (e.g. nprobe) configured via connection data.
            "params": self._search_params["search_params"],
        }
        hits = self.milvus_client.search(
            table_name,
            data=query_vectors,
            anns_field=TableField.EMBEDDINGS.value,
            output_fields=fetch_fields,
            search_params=search_params,
            **common_args,
        )[0]

        result_columns = [TableField.ID.value, TableField.DISTANCE.value]
        result_columns += [
            column
            for column in (TableField.CONTENT.value, TableField.EMBEDDINGS.value)
            if column in output_fields
        ]
        data = {column: [] for column in result_columns}
        for hit in hits:
            entity = hit["entity"]
            for column in result_columns:
                if column == TableField.DISTANCE.value:
                    data[column].append(hit["distance"])
                else:
                    data[column].append(entity.get(column))
        return pd.DataFrame(data)

    def create_table(self, table_name: str, if_not_exists=True):
        """Create a collection with default parameters in the Milvus database as described in documentation.

        The collection has a VARCHAR primary key ``id``, a VARCHAR ``content`` field and a
        FLOAT_VECTOR ``embeddings`` field indexed with the configured index type, metric and
        ``create_index_params``. An existing collection is left alone when
        ``if_not_exists`` is true and raises otherwise.
        """
        if self.milvus_client.has_collection(collection_name=table_name):
            if if_not_exists:
                return
            raise Exception(f"Table '{table_name}' already exists")

        params = self._create_table_params
        id_field = FieldSchema(
            name=TableField.ID.value,
            dtype=DataType.VARCHAR,
            is_primary=True,
            max_length=params["create_id_max_len"],
            auto_id=params["create_auto_id"],
        )
        embeddings_field = FieldSchema(
            name=TableField.EMBEDDINGS.value,
            dtype=DataType.FLOAT_VECTOR,
            dim=params["create_embedding_dim"],
        )
        content_field = FieldSchema(
            name=TableField.CONTENT.value,
            dtype=DataType.VARCHAR,
            max_length=params["create_content_max_len"],
            default_value=params["create_content_default_value"],
        )
        schema = CollectionSchema(
            fields=[id_field, content_field, embeddings_field],
            description=params["create_schema_description"],
            enable_dynamic_field=params["create_dynamic_field"],
        )
        self.milvus_client.create_collection(collection_name=table_name, schema=schema)

        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name=TableField.EMBEDDINGS.value,
            index_type=params["create_index_type"],
            metric_type=params["create_index_metric_type"],
            params=params["create_index_params"],
        )
        self.milvus_client.create_index(collection_name=table_name, index_params=index_params)

    @staticmethod
    def _parse_metadata(value: object) -> dict:
        """Normalise one metadata cell (dict, JSON string, None/NaN) into a dict."""
        if isinstance(value, str):
            value = json.loads(value) if value.strip() else None
        if value is None or value is pd.NA:
            return {}
        if isinstance(value, dict):
            return value
        if isinstance(value, float) and value != value:  # NaN from a missing cell
            return {}
        raise ValueError(f"Metadata must be a JSON object, got {type(value).__name__}")

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: Optional[List[str]] = None
    ):
        """Insert data into the Milvus collection.

        Each row's ``metadata`` object is flattened into top-level keys of the inserted
        record (stored by Milvus as dynamic fields); JSON-string embeddings are decoded.
        Metadata keys override same-named columns. The input frame is not modified.
        """
        self.milvus_client.load_collection(collection_name=table_name)
        if columns:
            data = data[columns]

        metadata_column = TableField.METADATA.value
        if metadata_column in data.columns:
            metadata_values = data[metadata_column].to_list()
            records = data.drop(columns=[metadata_column]).to_dict(orient="records")
            # Merge row by row: positional pairing avoids index misalignment, and rows
            # missing a key simply omit it instead of getting NaN.
            for record, metadata in zip(records, metadata_values):
                record.update(self._parse_metadata(metadata))
        else:
            records = data.to_dict(orient="records")

        embeddings_column = TableField.EMBEDDINGS.value
        for record in records:
            vector = record.get(embeddings_column)
            if isinstance(vector, str):
                record[embeddings_column] = json.loads(vector)

        if not records:
            return
        self.milvus_client.insert(table_name, records)

    def delete(
        self, table_name: str, conditions: Optional[List[FilterCondition]] = None
    ):
        """Delete rows matching ``conditions`` from a collection.

        EQUAL conditions are rewritten to IN over a one-element list (delete only
        supports IN). The caller's condition objects are not modified. Nothing happens
        if the collection does not exist.

        Raises:
            Exception: if no translatable filter remains (use DROP TABLE instead).
        """
        normalized = []
        for condition in conditions or []:
            if condition.op in (FilterOperator.EQUAL, FilterOperator.IN):
                value = condition.value
                value = list(value) if isinstance(value, (list, tuple)) else [value]
                condition = FilterCondition(column=condition.column, op=FilterOperator.IN, value=value)
            normalized.append(condition)
        filters = self._translate_conditions(normalized, exclude_id=False)
        if not filters:
            raise Exception("Some filters are required, use DROP TABLE to delete everything")
        if self.milvus_client.has_collection(collection_name=table_name):
            self.milvus_client.delete(table_name, filter=filters)

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Get columns in a Milvus collection.

        Returns the handler SCHEMA entries whose names are real fields of the
        collection, as COLUMN_NAME / DATA_TYPE, or an error response.
        """
        try:
            exists = self.milvus_client.has_collection(collection_name=table_name)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: {e}",
            )
        if not exists:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: '{table_name}' does not exist",
            )
        try:
            description = self.milvus_client.describe_collection(collection_name=table_name)
            field_names = {field["name"] for field in description["fields"]}
            schema = [schema_field for schema_field in self.SCHEMA if schema_field["name"] in field_names]
            if schema:
                data = pd.DataFrame(schema)
                data.columns = ["COLUMN_NAME", "DATA_TYPE"]
            else:
                data = pd.DataFrame(columns=["COLUMN_NAME", "DATA_TYPE"])
            return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=data)
        except Exception as e:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Error finding table: {e}",
            )