Coverage for mindsdb / integrations / handlers / pinecone_handler / pinecone_handler.py: 0%

188 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2from typing import List, Optional 

3 

4import numpy as np 

5from pinecone import Pinecone, ServerlessSpec 

6from pinecone.core.openapi.shared.exceptions import NotFoundException, PineconeApiException 

7import pandas as pd 

8 

9from mindsdb.integrations.libs.response import RESPONSE_TYPE 

10from mindsdb.integrations.libs.response import HandlerResponse 

11from mindsdb.integrations.libs.response import HandlerResponse as Response 

12from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse 

13from mindsdb.integrations.libs.vectordatabase_handler import ( 

14 FilterCondition, 

15 FilterOperator, 

16 TableField, 

17 VectorStoreHandler, 

18) 

19from mindsdb.utilities import log 

20 

logger = log.getLogger(__name__)

# Fallback parameters for create_table(); each key can be overridden by a
# same-named key in the handler's connection_data.
DEFAULT_CREATE_TABLE_PARAMS = {
    "dimension": 8,
    "metric": "cosine",
    "spec": {
        "cloud": "aws",
        "region": "us-east-1"
    }
}
# Pinecone caps `top_k` at 10000; used when a SELECT has no LIMIT.
MAX_FETCH_LIMIT = 10000
UPSERT_BATCH_SIZE = 99  # API recommendation

33 

34 

class PineconeHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Pinecone statements."""

    name = "pinecone"

    def __init__(self, name: str, connection_data: dict, **kwargs):
        """Initialize the handler.

        Args:
            name: name of this handler instance.
            connection_data: must contain `api_key`; may also contain
                `dimension`, `metric` and `spec` overrides used by create_table().
            **kwargs: extra arguments, stored but not interpreted here.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False

    def __del__(self):
        # Best-effort cleanup; disconnect() is a no-op when not connected.
        if self.is_connected is True:
            self.disconnect()

    def _get_index_handle(self, index_name):
        """Return a handle to the index named `index_name`, or None if it is unreachable."""
        connection = self.connect()
        index = connection.Index(index_name)
        try:
            # Cheap probe: raises when the index does not exist or cannot be reached.
            index.describe_index_stats()
        except Exception:
            index = None
        return index

    def _get_pinecone_operator(self, operator: FilterOperator) -> str:
        """Convert FilterOperator to an operator that pinecone's query language can understand.

        Raises:
            Exception: if the operator has no Pinecone equivalent.
        """
        mapping = {
            FilterOperator.EQUAL: "$eq",
            FilterOperator.NOT_EQUAL: "$ne",
            FilterOperator.GREATER_THAN: "$gt",
            FilterOperator.GREATER_THAN_OR_EQUAL: "$gte",
            FilterOperator.LESS_THAN: "$lt",
            FilterOperator.LESS_THAN_OR_EQUAL: "$lte",
            FilterOperator.IN: "$in",
            FilterOperator.NOT_IN: "$nin",
        }
        if operator not in mapping:
            raise Exception(f"Operator {operator} is not supported by Pinecone!")
        return mapping[operator]

    def _translate_metadata_condition(self, conditions: List[FilterCondition]) -> Optional[dict]:
        """
        Translate a list of FilterCondition objects into a dict that can be used by pinecone.
        Non-metadata conditions are ignored; returns None when there is nothing to filter on.
        E.g.,
        [
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.LESS_THAN,
                value="2020-01-01",
            ),
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.GREATER_THAN,
                value="2019-01-01",
            )
        ]
        -->
        {
            "$and": [
                {"created_at": {"$lt": "2020-01-01"}},
                {"created_at": {"$gt": "2019-01-01"}}
            ]
        }
        """
        # we ignore all non-metadata conditions
        if conditions is None:
            return None
        metadata_conditions = [
            condition
            for condition in conditions
            if condition.column.startswith(TableField.METADATA.value)
        ]
        if len(metadata_conditions) == 0:
            return None

        # we translate each metadata condition into a dict
        pinecone_conditions = []
        for condition in metadata_conditions:
            # "metadata.created_at" -> "created_at"
            metadata_key = condition.column.split(".")[-1]
            pinecone_conditions.append(
                {
                    metadata_key: {
                        self._get_pinecone_operator(condition.op): condition.value
                    }
                }
            )

        # we combine all metadata conditions into a single dict (implicit AND)
        metadata_condition = (
            {"$and": pinecone_conditions}
            if len(pinecone_conditions) > 1
            else pinecone_conditions[0]
        )
        return metadata_condition

    def _matches_to_dicts(self, matches: List):
        """Converts the custom pinecone response type to a list of python dict"""
        return [match.to_dict() for match in matches]

    def connect(self):
        """Connect to a pinecone database.

        Returns:
            The cached `Pinecone` client.

        Raises:
            ValueError: if `api_key` is missing from the connection data.
            Exception: if the client cannot be created.
        """
        if self.is_connected is True:
            return self.connection

        if 'api_key' not in self.connection_data:
            raise ValueError('Required parameter (api_key) must be provided.')

        try:
            self.connection = Pinecone(api_key=self.connection_data['api_key'])
            # BUG FIX: the original never set this flag on success, so every
            # call rebuilt the client and disconnect()/__del__ never cleaned up.
            self.is_connected = True
            return self.connection
        except Exception as e:
            logger.error(f"Error connecting to Pinecone client, {e}!")
            self.is_connected = False
            # BUG FIX: re-raise instead of silently returning None, which made
            # callers fail later with an opaque AttributeError on `None`.
            raise

    def disconnect(self):
        """Close the pinecone connection."""
        if self.is_connected is False:
            return
        self.connection = None
        self.is_connected = False

    def check_connection(self):
        """Check the connection to pinecone.

        Returns:
            StatusResponse with `success` set, and `error_message` on failure.
        """
        response = StatusResponse(False)
        # Remember whether we were connected before the probe so we can
        # restore the previous state afterwards.
        need_to_close = self.is_connected is False

        try:
            connection = self.connect()
            # A lightweight call that validates the API key.
            connection.list_indexes()
            response.success = True
        except Exception as e:
            logger.error(f"Error connecting to pinecone , {e}!")
            response.error_message = str(e)

        if response.success is True and need_to_close:
            self.disconnect()
        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def get_tables(self) -> HandlerResponse:
        """Get the list of indexes in the pinecone database."""
        connection = self.connect()
        indexes = connection.list_indexes()
        df = pd.DataFrame(
            columns=["table_name"],
            data=[index['name'] for index in indexes],
        )
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=df)

    def create_table(self, table_name: str, if_not_exists=True):
        """Create an index with the given name in the Pinecone database.

        Creation parameters (`dimension`, `metric`, `spec`) are taken from the
        connection data when present, otherwise from DEFAULT_CREATE_TABLE_PARAMS.

        Raises:
            Exception: if creation fails (except a 409 conflict when
                `if_not_exists` is True).
        """
        connection = self.connect()

        # TODO: Should other parameters be supported? Pod indexes?
        # TODO: Should there be a better way to provide these parameters rather than when establishing the connection?
        create_table_params = {}
        for key, val in DEFAULT_CREATE_TABLE_PARAMS.items():
            if key in self.connection_data:
                create_table_params[key] = self.connection_data[key]
            else:
                create_table_params[key] = val

        create_table_params["spec"] = ServerlessSpec(**create_table_params["spec"])

        try:
            connection.create_index(name=table_name, **create_table_params)
        except PineconeApiException as pinecone_error:
            # 409 means the index already exists.
            if pinecone_error.status == 409 and if_not_exists:
                return
            raise Exception(f"Error creating index '{table_name}': {pinecone_error}")

    def insert(self, table_name: str, data: pd.DataFrame):
        """Insert data into pinecone index passed in through `table_name` parameter.

        Expects `data` to contain the standard vector-store columns (id,
        embeddings, optionally metadata). NOTE: renames columns of the caller's
        DataFrame in place.
        """
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        data.rename(columns={
            TableField.ID.value: "id",
            TableField.EMBEDDINGS.value: "values"},
            inplace=True)

        columns = ["id", "values"]

        if TableField.METADATA.value in data.columns:
            data.rename(columns={TableField.METADATA.value: "metadata"}, inplace=True)
            # fill None and NaN values with empty dict — Pinecone rejects nulls
            if data['metadata'].isnull().any():
                data['metadata'] = data['metadata'].apply(lambda x: {} if x is None or (isinstance(x, float) and np.isnan(x)) else x)
            columns.append("metadata")

        data = data[columns]

        # convert the embeddings to lists if they are strings
        data["values"] = data["values"].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

        # Upsert in batches per the API recommendation.
        for chunk in (data[pos:pos + UPSERT_BATCH_SIZE] for pos in range(0, len(data), UPSERT_BATCH_SIZE)):
            chunk = chunk.to_dict(orient="records")
            index.upsert(vectors=chunk)

    def drop_table(self, table_name: str, if_exists=True):
        """Delete an index passed in through `table_name` from the pinecone database."""
        connection = self.connect()
        try:
            connection.delete_index(table_name)
        except NotFoundException:
            if if_exists:
                return
            raise Exception(f"Error deleting index '{table_name}', are you sure the name is correct?")

    def delete(self, table_name: str, conditions: List[FilterCondition] = None):
        """Delete records in pinecone index `table_name` based on ids or based on metadata conditions.

        Raises:
            Exception: if neither an id nor a metadata condition is given, or
                the index cannot be found.
        """
        filters = self._translate_metadata_condition(conditions)
        ids = [
            condition.value
            for condition in conditions
            if condition.column == TableField.ID.value
        ] or None
        if filters is None and ids is None:
            raise Exception("Delete query must have either id condition or metadata condition!")
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        # Pinecone does not allow combining ids with a metadata filter;
        # metadata filter wins when both are present (original behavior).
        if filters is None:
            index.delete(ids=ids)
        else:
            index.delete(filter=filters)

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ):
        """Run query on pinecone index named `table_name` and get results.

        Requires either a search_vector or an id condition. `offset` is
        accepted for interface compatibility but not supported by Pinecone.

        Returns:
            pd.DataFrame with the requested columns (all columns if `columns`
            is None).
        """
        # TODO: Add support for namespaces.
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        query = {
            "include_values": True,
            "include_metadata": True
        }

        # check for metadata filter
        metadata_filters = self._translate_metadata_condition(conditions)
        if metadata_filters is not None:
            query["filter"] = metadata_filters

        # check for vector and id filters
        vector_filters = []
        id_filters = []

        if conditions:
            for condition in conditions:
                if condition.column == TableField.SEARCH_VECTOR.value:
                    vector_filters.append(condition.value)
                elif condition.column == TableField.ID.value:
                    id_filters.append(condition.value)

        if vector_filters:
            if len(vector_filters) > 1:
                raise Exception("You cannot have multiple search_vectors in query")

            query["vector"] = vector_filters[0]
            # For subqueries, the vector filter is a list of list of strings
            if isinstance(query["vector"], list) and isinstance(query["vector"][0], str):
                if len(query["vector"]) > 1:
                    raise Exception("You cannot have multiple search_vectors in query")

                try:
                    query["vector"] = ast.literal_eval(query["vector"][0])
                except Exception as e:
                    # BUG FIX: original message was missing a space ("...'into a list").
                    raise Exception(f"Cannot parse the search vector '{query['vector']}' into a list: {e}")

        if id_filters:
            if len(id_filters) > 1:
                raise Exception("You cannot have multiple IDs in query")

            query["id"] = id_filters[0]

        if not vector_filters and not id_filters:
            raise Exception("You must provide either a search_vector or an ID in the query")

        # check for limit; Pinecone caps top_k at MAX_FETCH_LIMIT
        if limit is not None:
            query["top_k"] = limit
        else:
            query["top_k"] = MAX_FETCH_LIMIT

        # exec query
        try:
            result = index.query(**query)
        except Exception as e:
            raise Exception(f"Error running SELECT query on '{table_name}': {e}")

        # convert to dataframe, mapping pinecone field names back to the
        # standard vector-store column names
        df_columns = {
            "id": TableField.ID.value,
            "metadata": TableField.METADATA.value,
            "values": TableField.EMBEDDINGS.value,
        }
        results_df = pd.DataFrame.from_records(self._matches_to_dicts(result["matches"]))
        if len(results_df.columns):
            results_df.rename(columns=df_columns, inplace=True)
        else:
            # no matches: return an empty frame with the expected schema
            results_df = pd.DataFrame(columns=list(df_columns.values()))
        # Pinecone stores no document content; keep the column for interface parity.
        results_df[TableField.CONTENT.value] = ""
        # BUG FIX: original did `results_df[columns]` unconditionally, which
        # raised when `columns` was None; return the full frame in that case.
        if columns:
            return results_df[columns]
        return results_df

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the standard vector-store columns (delegates to the base class)."""
        return super().get_columns(table_name)