Coverage for mindsdb / integrations / handlers / pinecone_handler / pinecone_handler.py: 0%
188 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import ast
2from typing import List, Optional
4import numpy as np
5from pinecone import Pinecone, ServerlessSpec
6from pinecone.core.openapi.shared.exceptions import NotFoundException, PineconeApiException
7import pandas as pd
9from mindsdb.integrations.libs.response import RESPONSE_TYPE
10from mindsdb.integrations.libs.response import HandlerResponse
11from mindsdb.integrations.libs.response import HandlerResponse as Response
12from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse
13from mindsdb.integrations.libs.vectordatabase_handler import (
14 FilterCondition,
15 FilterOperator,
16 TableField,
17 VectorStoreHandler,
18)
19from mindsdb.utilities import log
# Module-level logger, named after this module per mindsdb convention.
logger = log.getLogger(__name__)
# Defaults used by `create_table` when the corresponding key is not present in
# the handler's connection_data. "spec" is expanded into a ServerlessSpec.
DEFAULT_CREATE_TABLE_PARAMS = {
    "dimension": 8,
    "metric": "cosine",
    "spec": {
        "cloud": "aws",
        "region": "us-east-1"
    }
}
# Upper bound used as `top_k` when a SELECT has no LIMIT clause.
MAX_FETCH_LIMIT = 10000
# Upsert batch size; Pinecone's API recommendation is to keep batches small.
UPSERT_BATCH_SIZE = 99  # API recommendation
class PineconeHandler(VectorStoreHandler):
    """This handler handles connection and execution of the Pinecone statements."""

    name = "pinecone"

    def __init__(self, name: str, connection_data: dict, **kwargs):
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def _get_index_handle(self, index_name):
        """Return a handle to the index named `index_name`, or None if it is not reachable."""
        connection = self.connect()
        index = connection.Index(index_name)
        try:
            # Cheap call used purely to verify the index exists and is reachable.
            index.describe_index_stats()
        except Exception:
            index = None
        return index

    def _get_pinecone_operator(self, operator: FilterOperator) -> str:
        """Convert FilterOperator to an operator that pinecone's query language can understand.

        Raises:
            Exception: if the operator has no Pinecone equivalent.
        """
        mapping = {
            FilterOperator.EQUAL: "$eq",
            FilterOperator.NOT_EQUAL: "$ne",
            FilterOperator.GREATER_THAN: "$gt",
            FilterOperator.GREATER_THAN_OR_EQUAL: "$gte",
            FilterOperator.LESS_THAN: "$lt",
            FilterOperator.LESS_THAN_OR_EQUAL: "$lte",
            FilterOperator.IN: "$in",
            FilterOperator.NOT_IN: "$nin",
        }
        if operator not in mapping:
            raise Exception(f"Operator {operator} is not supported by Pinecone!")
        return mapping[operator]

    def _translate_metadata_condition(self, conditions: List[FilterCondition]) -> Optional[dict]:
        """
        Translate a list of FilterCondition objects into a dict that can be used by pinecone.
        E.g.,
        [
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.LESS_THAN,
                value="2020-01-01",
            ),
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.GREATER_THAN,
                value="2019-01-01",
            )
        ]
        -->
        {
            "$and": [
                {"created_at": {"$lt": "2020-01-01"}},
                {"created_at": {"$gt": "2019-01-01"}}
            ]
        }
        """
        # we ignore all non-metadata conditions
        if conditions is None:
            return None
        metadata_conditions = [
            condition
            for condition in conditions
            if condition.column.startswith(TableField.METADATA.value)
        ]
        if len(metadata_conditions) == 0:
            return None

        # we translate each metadata condition into a dict
        pinecone_conditions = []
        for condition in metadata_conditions:
            # "metadata.created_at" -> "created_at"
            metadata_key = condition.column.split(".")[-1]
            pinecone_conditions.append(
                {
                    metadata_key: {
                        self._get_pinecone_operator(condition.op): condition.value
                    }
                }
            )

        # a single condition does not need the "$and" wrapper
        metadata_condition = (
            {"$and": pinecone_conditions}
            if len(pinecone_conditions) > 1
            else pinecone_conditions[0]
        )
        return metadata_condition

    def _matches_to_dicts(self, matches: List):
        """Converts the custom pinecone response type to a list of python dict"""
        return [match.to_dict() for match in matches]

    def connect(self):
        """Connect to a pinecone database.

        Returns the cached client when already connected; otherwise creates a
        new one. Raises if the client cannot be created.
        """
        if self.is_connected is True:
            return self.connection

        if 'api_key' not in self.connection_data:
            raise ValueError('Required parameter (api_key) must be provided.')

        try:
            self.connection = Pinecone(api_key=self.connection_data['api_key'])
            # BUG FIX: the flag was never set on success, so the cached client
            # above was never reused and `disconnect`/`__del__` were no-ops.
            self.is_connected = True
            return self.connection
        except Exception as e:
            logger.error(f"Error connecting to Pinecone client, {e}!")
            self.is_connected = False
            # BUG FIX: previously the error was swallowed and None was returned,
            # making callers fail later with a confusing AttributeError.
            raise

    def disconnect(self):
        """Close the pinecone connection."""
        if self.is_connected is False:
            return
        self.connection = None
        self.is_connected = False

    def check_connection(self):
        """Check the connection to pinecone."""
        response = StatusResponse(False)
        # remember whether we opened the connection just for this check
        need_to_close = self.is_connected is False

        try:
            connection = self.connect()
            # lightweight call to confirm the credentials actually work
            connection.list_indexes()
            response.success = True
        except Exception as e:
            logger.error(f"Error connecting to pinecone , {e}!")
            response.error_message = str(e)

        if response.success is True and need_to_close:
            self.disconnect()
        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def get_tables(self) -> HandlerResponse:
        """Get the list of indexes in the pinecone database."""
        connection = self.connect()
        indexes = connection.list_indexes()
        df = pd.DataFrame(
            columns=["table_name"],
            data=[index['name'] for index in indexes],
        )
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=df)

    def create_table(self, table_name: str, if_not_exists=True):
        """Create an index with the given name in the Pinecone database.

        Args:
            table_name: name of the index to create.
            if_not_exists: when True, silently succeed if the index already exists.
        """
        connection = self.connect()

        # TODO: Should other parameters be supported? Pod indexes?
        # TODO: Should there be a better way to provide these parameters rather than when establishing the connection?
        create_table_params = {}
        for key, val in DEFAULT_CREATE_TABLE_PARAMS.items():
            # connection-level settings override the module defaults
            create_table_params[key] = self.connection_data.get(key, val)

        create_table_params["spec"] = ServerlessSpec(**create_table_params["spec"])

        try:
            connection.create_index(name=table_name, **create_table_params)
        except PineconeApiException as pinecone_error:
            # 409 = index already exists
            if pinecone_error.status == 409 and if_not_exists:
                return
            raise Exception(f"Error creating index '{table_name}': {pinecone_error}")

    def insert(self, table_name: str, data: pd.DataFrame):
        """Insert data into pinecone index passed in through `table_name` parameter."""
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        # BUG FIX: rename on a copy instead of inplace so the caller's
        # DataFrame is not mutated as a side effect.
        data = data.rename(columns={
            TableField.ID.value: "id",
            TableField.EMBEDDINGS.value: "values"})

        columns = ["id", "values"]

        if TableField.METADATA.value in data.columns:
            data = data.rename(columns={TableField.METADATA.value: "metadata"})
            # fill None and NaN values with empty dict
            if data['metadata'].isnull().any():
                data['metadata'] = data['metadata'].apply(
                    lambda x: {} if x is None or (isinstance(x, float) and np.isnan(x)) else x)
            columns.append("metadata")

        data = data[columns].copy()

        # convert the embeddings to lists if they are strings
        data["values"] = data["values"].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

        # upsert in batches of UPSERT_BATCH_SIZE, as recommended by the API
        for pos in range(0, len(data), UPSERT_BATCH_SIZE):
            chunk = data[pos:pos + UPSERT_BATCH_SIZE].to_dict(orient="records")
            index.upsert(vectors=chunk)

    def drop_table(self, table_name: str, if_exists=True):
        """Delete an index passed in through `table_name` from the pinecone .

        Args:
            table_name: name of the index to delete.
            if_exists: when True, silently succeed if the index does not exist.
        """
        connection = self.connect()
        try:
            connection.delete_index(table_name)
        except NotFoundException:
            if if_exists:
                return
            raise Exception(f"Error deleting index '{table_name}', are you sure the name is correct?")

    def delete(self, table_name: str, conditions: List[FilterCondition] = None):
        """Delete records in pinecone index `table_name` based on ids or based on metadata conditions."""
        filters = self._translate_metadata_condition(conditions)
        # BUG FIX: guard against `conditions` being None before iterating it;
        # previously this raised TypeError instead of the intended error below.
        ids = [
            condition.value
            for condition in (conditions or [])
            if condition.column == TableField.ID.value
        ] or None
        if filters is None and ids is None:
            raise Exception("Delete query must have either id condition or metadata condition!")
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        # id-based deletes and filter-based deletes are separate API calls
        if filters is None:
            index.delete(ids=ids)
        else:
            index.delete(filter=filters)

    def select(
        self,
        table_name: str,
        columns: List[str] = None,
        conditions: List[FilterCondition] = None,
        offset: int = None,
        limit: int = None,
    ):
        """Run query on pinecone index named `table_name` and get results.

        Args:
            table_name: index to query.
            columns: subset of result columns to return; all columns when None.
            conditions: id / search_vector / metadata filter conditions.
            offset: unused (Pinecone queries do not support offsets).
            limit: maximum number of matches; MAX_FETCH_LIMIT when None.
        """
        # TODO: Add support for namespaces.
        index = self._get_index_handle(table_name)
        if index is None:
            raise Exception(f"Error getting index '{table_name}', are you sure the name is correct?")

        query = {
            "include_values": True,
            "include_metadata": True
        }

        # check for metadata filter
        metadata_filters = self._translate_metadata_condition(conditions)
        if metadata_filters is not None:
            query["filter"] = metadata_filters

        # check for vector and id filters
        vector_filters = []
        id_filters = []

        if conditions:
            for condition in conditions:
                if condition.column == TableField.SEARCH_VECTOR.value:
                    vector_filters.append(condition.value)
                elif condition.column == TableField.ID.value:
                    id_filters.append(condition.value)

        if vector_filters:
            if len(vector_filters) > 1:
                raise Exception("You cannot have multiple search_vectors in query")

            query["vector"] = vector_filters[0]
            # For subqueries, the vector filter is a list of list of strings
            if isinstance(query["vector"], list) and isinstance(query["vector"][0], str):
                if len(query["vector"]) > 1:
                    raise Exception("You cannot have multiple search_vectors in query")

                try:
                    query["vector"] = ast.literal_eval(query["vector"][0])
                except Exception as e:
                    # BUG FIX: added the missing space before "into"
                    raise Exception(f"Cannot parse the search vector '{query['vector']}' into a list: {e}")

        if id_filters:
            if len(id_filters) > 1:
                raise Exception("You cannot have multiple IDs in query")

            query["id"] = id_filters[0]

        if not vector_filters and not id_filters:
            raise Exception("You must provide either a search_vector or an ID in the query")

        # check for limit
        query["top_k"] = limit if limit is not None else MAX_FETCH_LIMIT

        # exec query
        try:
            result = index.query(**query)
        except Exception as e:
            raise Exception(f"Error running SELECT query on '{table_name}': {e}")

        # convert to dataframe
        df_columns = {
            "id": TableField.ID.value,
            "metadata": TableField.METADATA.value,
            "values": TableField.EMBEDDINGS.value,
        }
        results_df = pd.DataFrame.from_records(self._matches_to_dicts(result["matches"]))
        if len(results_df.columns):
            results_df.rename(columns=df_columns, inplace=True)
        else:
            # no matches: build an empty frame with the expected columns
            results_df = pd.DataFrame(columns=list(df_columns.values()))
        results_df[TableField.CONTENT.value] = ""
        # BUG FIX: indexing with columns=None raised; return all columns instead
        return results_df[columns] if columns else results_df

    def get_columns(self, table_name: str) -> HandlerResponse:
        # Columns are dynamic in a vector store; defer to the base implementation.
        return super().get_columns(table_name)