Coverage for mindsdb / integrations / handlers / qdrant_handler / qdrant_handler.py: 0%

159 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2from typing import Any, List, Optional 

3from itertools import zip_longest 

4 

5from qdrant_client import QdrantClient, models 

6import pandas as pd 

7 

8from mindsdb.integrations.libs.response import HandlerResponse 

9from mindsdb.integrations.libs.response import RESPONSE_TYPE 

10from mindsdb.integrations.libs.response import HandlerResponse as Response 

11from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

12from mindsdb.integrations.libs.vectordatabase_handler import ( 

13 FilterCondition, 

14 FilterOperator, 

15 TableField, 

16 VectorStoreHandler, 

17) 

18from mindsdb.utilities import log 

19 

20logger = log.getLogger(__name__) 

21 

22 

class QdrantHandler(VectorStoreHandler):
    """Handles connection and execution of the Qdrant statements."""

    name = "qdrant"

    def __init__(self, name: str, **kwargs):
        super().__init__(name)
        # Copy so that popping options below never mutates the caller's dict.
        connection_data = dict(kwargs.get("connection_data") or {})
        # Qdrant offers several configuration and optimization options at the time of collection creation.
        # Since the CREATE TABLE statement has no way to pass these options,
        # the user is required to pass them in connection_data under "collection_config".
        # These options are documented here: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
        self.collection_config = connection_data.pop("collection_config")
        # The remaining keys are QdrantClient constructor arguments. Keep them so that
        # connect() can re-create the client after disconnect() (e.g. from check_connection).
        self._connection_data = connection_data
        self._client = None
        self.connect(**connection_data)

    def connect(self, **kwargs):
        """Connect to a Qdrant instance.

        A Qdrant client can be instantiated with a REST or GRPC interface, or in-memory for testing.
        To use the in-memory instance, specify the location argument as ':memory:'.

        Args:
            **kwargs: QdrantClient constructor arguments. When omitted, the connection
                data captured in __init__ is reused.

        Returns:
            The QdrantClient on success, or None if it could not be created (the error is logged).
        """
        if self.is_connected:
            return self._client

        client_kwargs = kwargs or getattr(self, "_connection_data", {})
        try:
            self._client = QdrantClient(**client_kwargs)
        except Exception as e:
            logger.error(f"Error instantiating a Qdrant client: {e}")
            self.is_connected = False
            return None

        self.is_connected = True
        return self._client

    def disconnect(self):
        """Close the Qdrant client (if open) and mark the handler as disconnected."""
        if self.is_connected and self._client is not None:
            self._client.close()
        self._client = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Check the connection to the Qdrant database.

        Connects first if needed. A connection opened only for this check is closed
        again when the check succeeds.

        Returns:
            StatusResponse: Indicates if the connection is alive
        """
        need_to_close = not self.is_connected

        try:
            if self.connect() is None:
                raise ConnectionError("Could not instantiate a Qdrant client")
            # A trivial call stands in for a ping, since there is no universal ping method
            # for the REST, GRPC and in-memory interfaces; get_tables relies on the same call.
            self._client.get_collections()
            response_code = StatusResponse(True)
        except Exception as e:
            logger.error(f"Error connecting to a Qdrant instance: {e}")
            response_code = StatusResponse(False, error_message=str(e))

        if response_code.success and need_to_close:
            self.disconnect()
        elif not response_code.success:
            self.is_connected = False

        return response_code

    def drop_table(self, table_name: str, if_exists=True):
        """Delete a collection from the Qdrant instance.

        Args:
            table_name (str): The name of the collection to be dropped
            if_exists (bool, optional): When False, an error is raised if the collection
                could not be deleted (delete_collection returned a falsy result). Defaults to True.

        Raises:
            Exception: If if_exists is False and the collection was not deleted.
        """
        result = self._client.delete_collection(table_name)
        if not (result or if_exists):
            raise Exception(f"Table {table_name} does not exist!")

    def get_tables(self) -> HandlerResponse:
        """Get the list of collections in the Qdrant instance.

        Returns:
            HandlerResponse: The common query handler response with a list of table names
        """
        collection_response = self._client.get_collections()
        collections_name = pd.DataFrame(
            columns=["table_name"],
            data=[collection.name for collection in collection_response.collections],
        )
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name)

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the table's columns, or an error response if the collection lookup raises ValueError.

        NOTE(review): only ValueError is treated as "missing collection"; confirm which
        exception the remote (REST/GRPC) clients raise for an unknown collection.
        """
        try:
            self._client.get_collection(table_name)
        except ValueError:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Table {table_name} does not exist!",
            )
        return super().get_columns(table_name)

    @staticmethod
    def _normalize_point_id(point_id):
        """Convert digit-only ids (including "123" strings and numpy ints) to int; leave others (e.g. UUIDs) as-is."""
        return int(point_id) if str(point_id).isdigit() else point_id

    @classmethod
    def _collect_point_ids(cls, conditions: List[FilterCondition]) -> list:
        """Gather the normalized point ids referenced by conditions on the id column.

        A condition value may be a single id or a list/tuple of ids (e.g. from IN).
        NOTE(review): the condition operator is not inspected, so every id condition is treated as a match.
        """
        point_ids = []
        for condition in conditions:
            if condition.column != TableField.ID.value:
                continue
            values = condition.value if isinstance(condition.value, (list, tuple)) else [condition.value]
            point_ids.extend(cls._normalize_point_id(value) for value in values)
        return point_ids

    @staticmethod
    def _is_missing(value) -> bool:
        """True for None and scalar NA values (NaN, pd.NA, NaT); containers are never considered missing."""
        return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))

    @classmethod
    def _parse_metadata(cls, metadata) -> dict:
        """Turn a metadata cell into a dict of payload fields.

        Strings are parsed as JSON first, then as a Python literal (e.g. "{'a': 1}").
        None/NaN yields an empty dict.

        Raises:
            TypeError: If the value (after parsing) is neither a mapping nor missing.
        """
        import json
        from collections.abc import Mapping

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                metadata = ast.literal_eval(metadata)

        if isinstance(metadata, Mapping):
            return dict(metadata)
        if cls._is_missing(metadata):
            return {}
        raise TypeError(f"Metadata must be a dict or a string encoding one, got {type(metadata).__name__}")

    @classmethod
    def _build_payload(cls, document, metadata) -> dict:
        """Build one point's payload: the document under "document", plus the metadata fields unpacked at top level."""
        payload = {}
        if not cls._is_missing(document):
            payload["document"] = document
        # Metadata keys are merged last, so a metadata key named "document" wins (as before).
        payload.update(cls._parse_metadata(metadata))
        return payload

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """Handler for the insert query (performed as a Qdrant upsert).

        Args:
            table_name (str): The name of the table to be inserted into
            data (pd.DataFrame): The data to be inserted. The id and embeddings columns are
                required; the content and metadata columns are optional.
            columns (List[str], optional): Unused, as the values are derived from the "data" argument. Defaults to None.

        Raises:
            KeyError: If the id or embeddings column is missing.
        """
        records = data.to_dict(orient="list")

        # IDs can be either unsigned integers or strings (UUIDs); digit strings become ints.
        point_ids = [self._normalize_point_id(point_id) for point_id in records[TableField.ID.value]]
        vectors = records[TableField.EMBEDDINGS.value]
        if not point_ids:
            return

        row_count = len(point_ids)
        documents = records.get(TableField.CONTENT.value, [None] * row_count)
        metadatas = records.get(TableField.METADATA.value, [None] * row_count)

        # Qdrant doesn't distinguish between documents and metadata; everything goes into the
        # point "payload". Exactly one payload per row (possibly empty) keeps payloads aligned with ids.
        payloads = [self._build_payload(document, metadata) for document, metadata in zip(documents, metadatas)]

        self._client.upsert(table_name, points=models.Batch(
            ids=point_ids,
            vectors=vectors,
            payloads=payloads,
        ))

    def create_table(self, table_name: str, if_not_exists=True):
        """Create a collection with the given name in the Qdrant database.

        Args:
            table_name (str): Name of the table (collection) to be created
            if_not_exists (bool, optional): When False, a ValueError from create_collection is re-raised. Defaults to True.
        """
        try:
            # Create a collection with the collection name and the collection_config set during __init__
            self._client.create_collection(table_name, self.collection_config)
        except ValueError as e:
            if if_not_exists is False:
                raise e

    def _get_qdrant_filter(self, operator: FilterOperator, value: Any) -> dict:
        """Map a filter operator to the keyword arguments of a Qdrant FieldCondition.

        An if/elif chain is used rather than a dict lookup so that only the matching model
        is constructed: a dict would build every model eagerly, and models.Range() fails
        for a str value.

        Args:
            operator (FilterOperator): FilterOperator specified in the query. E.g. >=, <=, =
            value (Any): Value specified in the query

        Raises:
            Exception: If an unsupported operator is specified

        Returns:
            dict: Either {"match": ...} or {"range": ...}
        """
        if operator == FilterOperator.EQUAL:
            return {"match": models.MatchValue(value=value)}
        elif operator == FilterOperator.NOT_EQUAL:
            return {"match": models.MatchExcept(**{"except": [value]})}
        elif operator == FilterOperator.LESS_THAN:
            return {"range": models.Range(lt=value)}
        elif operator == FilterOperator.LESS_THAN_OR_EQUAL:
            return {"range": models.Range(lte=value)}
        elif operator == FilterOperator.GREATER_THAN:
            return {"range": models.Range(gt=value)}
        elif operator == FilterOperator.GREATER_THAN_OR_EQUAL:
            return {"range": models.Range(gte=value)}
        else:
            raise Exception(f"Operator {operator} is not supported by Qdrant!")

    def _translate_filter_conditions(
        self, conditions: List[FilterCondition]
    ) -> Optional[models.Filter]:
        """
        Translate a list of FilterCondition objects into a Filter that can be used by Qdrant.
        Only conditions on "metadata" / "metadata.<key>" columns are translated; the rest are ignored.
        Nested keys keep their dotted path (metadata.a.b -> "a.b"), matching how insert()
        unpacks metadata into the top level of the payload.
        Filtering clause docs can be found here: https://qdrant.tech/documentation/concepts/filtering/
        E.g.,
            [
                FilterCondition(
                    column="metadata.created_at",
                    op=FilterOperator.LESS_THAN,
                    value=7132423,
                ),
                FilterCondition(
                    column="metadata.created_at",
                    op=FilterOperator.GREATER_THAN,
                    value=2323432,
                )
            ]
            -->
            models.Filter(
                must=[
                    models.FieldCondition(
                        key="created_at",
                        range=models.Range(lt=7132423),
                    ),
                    models.FieldCondition(
                        key="created_at",
                        range=models.Range(gt=2323432),
                    ),
                ]
            )

        Returns:
            models.Filter, or None when no metadata conditions are present.
        """
        if not conditions:
            return None

        metadata_column = TableField.METADATA.value
        key_prefix = metadata_column + "."

        qdrant_filters = []
        for condition in conditions:
            column = condition.column
            if column == metadata_column:
                payload_key = column
            elif column.startswith(key_prefix):
                payload_key = column[len(key_prefix):]
            else:
                continue
            qdrant_filters.append(
                models.FieldCondition(key=payload_key, **self._get_qdrant_filter(condition.op, condition.value))
            )

        return models.Filter(must=qdrant_filters) if qdrant_filters else None

    @staticmethod
    def _restrict_to_ids(query_filter: models.Filter, point_ids: list) -> models.Filter:
        """AND a has-id condition with an existing metadata filter (WHERE clauses are conjunctive)."""
        return models.Filter(must=[models.HasIdCondition(has_id=point_ids), *query_filter.must])

    def _scroll(self, table_name: str, limit: int, offset: int, scroll_filter=None) -> list:
        """Scroll a collection while honouring a numeric SQL OFFSET.

        Qdrant's scroll `offset` is the id of the point to start from, not a count of points
        to skip, so request `offset + limit` points and drop the first `offset`.
        """
        records, _next_page_offset = self._client.scroll(
            table_name, scroll_filter=scroll_filter, limit=offset + limit
        )
        return records[offset:]

    def update(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ):
        """Update rows; insert performs an upsert, so it is reused directly (columns is ignored)."""
        return self.insert(table_name, data)

    def select(self, table_name: str, columns: Optional[List[str]] = None,
               conditions: Optional[List[FilterCondition]] = None, offset: int = 0, limit: int = 10) -> pd.DataFrame:
        """Select query handler
        Eg: SELECT * FROM qdrant.test_table

        Precedence: id conditions (point lookup), then search_vector (similarity search),
        then a plain scroll. Metadata conditions narrow every path; conditions on other
        columns are ignored.

        Args:
            table_name (str): The name of the table to be queried
            columns (Optional[List[str]], optional): List of column names specified in the query. Defaults to None.
            conditions (Optional[List[FilterCondition]], optional): List of "where" conditionals. Defaults to None.
            offset (int, optional): Number of results to skip. Defaults to 0. Not applied to id lookups.
            limit (int, optional): Number of results to return. Defaults to 10. Not applied to id lookups.

        Returns:
            pd.DataFrame: The processed result rows
        """
        # Validate and set offset and limit, as None is passed if they are not set in the query
        offset = offset if offset is not None else 0
        limit = limit if limit is not None else 10
        conditions = conditions or []

        search_vectors = [condition.value for condition in conditions if condition.column == TableField.SEARCH_VECTOR.value]
        has_id_condition = any(condition.column == TableField.ID.value for condition in conditions)
        point_ids = self._collect_point_ids(conditions)
        query_filter = self._translate_filter_conditions(conditions)

        if has_id_condition:
            # Prefer returning results by IDs first
            if not point_ids:
                results = []
            elif query_filter is not None:
                # At most len(point_ids) points can match, so this returns every match.
                results, _ = self._client.scroll(
                    table_name,
                    scroll_filter=self._restrict_to_ids(query_filter, point_ids),
                    limit=len(point_ids),
                )
            else:
                results = self._client.retrieve(table_name, ids=point_ids)
        elif search_vectors:
            # Perform a similarity search with the first vector filter
            results = self._client.search(
                table_name,
                query_vector=search_vectors[0],
                query_filter=query_filter,
                limit=limit,
                offset=offset,
            )
        else:
            # Full (optionally metadata-filtered) scroll
            results = self._scroll(table_name, limit=limit, offset=offset, scroll_filter=query_filter)

        return self._process_select_results(results, columns)

    def _process_select_results(self, results=None, columns=None):
        """Private method to process the results of a select query

        Args:
            results: A List[Record] or List[ScoredPoint]. Defaults to None
            columns: List of column names specified in the query. Defaults to None

        Returns:
            DataFrame: A processed pandas dataframe
        """
        ids, documents, metadata, distances = [], [], [], []

        for result in results or []:
            point_payload = result.payload or {}
            ids.append(result.id)
            # The document and metadata fields are stored together in the payload dict
            documents.append(point_payload.get("document"))
            metadata.append({k: v for k, v in point_payload.items() if k != "document"})

            # A score is only present on similarity search results (ScoredPoint). Membership
            # tests ("score" in result) don't work on pydantic models, so read the attribute.
            score = getattr(result, "score", None)
            if score is not None:
                distances.append(score)

        output = {
            TableField.ID.value: ids,
            TableField.CONTENT.value: documents,
            TableField.METADATA.value: metadata,
        }

        # Filter result columns
        if columns:
            output = {
                column: output[column]
                for column in columns
                if column != TableField.EMBEDDINGS.value and column in output
            }

        # If the distance list is empty, don't add it to the result
        if distances:
            output[TableField.DISTANCE.value] = distances

        return pd.DataFrame(output)

    def delete(
        self, table_name: str, conditions: List[FilterCondition] = None
    ):
        """Delete query handler

        Id and metadata conditions are combined with AND, so only points matching all of them are deleted.

        Args:
            table_name (str): The name of the collection to delete points from.
            conditions (List[FilterCondition], optional): List of "where" conditionals. Defaults to None.

        Raises:
            Exception: If no id or metadata condition is specified
        """
        conditions = conditions or []
        filters = self._translate_filter_conditions(conditions)
        point_ids = self._collect_point_ids(conditions)

        if filters is None and not point_ids:
            raise Exception("Delete query must have at least one condition!")

        if point_ids and filters is not None:
            selector = models.FilterSelector(filter=self._restrict_to_ids(filters, point_ids))
        elif point_ids:
            selector = models.PointIdsList(points=point_ids)
        else:
            selector = models.FilterSelector(filter=filters)

        self._client.delete(table_name, points_selector=selector)