Coverage for mindsdb / integrations / handlers / ibm_cos_handler / ibm_cos_handler.py: 0%
197 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from contextlib import contextmanager
2from typing import Text, Dict, Optional, List
4import ibm_boto3
5from ibm_botocore.client import ClientError
6import pandas as pd
7import duckdb
9from mindsdb_sql_parser.ast.base import ASTNode
10from mindsdb_sql_parser.ast import Select, Identifier, Insert, Star, Constant
12from mindsdb.utilities import log
13from mindsdb.integrations.libs.response import (
14 HandlerStatusResponse as StatusResponse,
15 HandlerResponse as Response,
16 RESPONSE_TYPE,
17)
19from mindsdb.integrations.libs.api_handler import APIResource, APIHandler
20from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
22logger = log.getLogger(__name__)
class ListFilesTable(APIResource):
    """Virtual ``files`` table listing supported objects in IBM COS bucket(s)."""

    def list(
        self,
        targets: List[str] = None,
        conditions: List[FilterCondition] = None,
        limit: int = None,
        *args,
        **kwargs,
    ) -> pd.DataFrame:
        """Return metadata rows for supported files.

        Args:
            targets: Requested columns (unused; all columns are returned).
            conditions: Filters; ``bucket =`` / ``bucket IN`` are applied here
                and marked as such, other conditions are left to the caller.
            limit: Maximum number of objects to fetch from the handler.

        Returns:
            DataFrame with columns path, name, extension, bucket, content.
        """
        buckets = None
        # Guard against the default conditions=None — the original iterated
        # it directly and would raise TypeError when called without filters.
        for condition in conditions or []:
            if condition.column == "bucket":
                if condition.op == FilterOperator.IN:
                    buckets = condition.value
                elif condition.op == FilterOperator.EQUAL:
                    buckets = [condition.value]
                condition.applied = True

        data = []
        for obj in self.handler.get_objects(limit=limit, buckets=buckets):
            path = obj["Key"]
            # Use the LAST dot-token as the extension (consistent with
            # get_tables()): split(".")[1] was wrong for multi-dot names
            # ("a.b.csv" -> "b") and raised IndexError for dot-less names.
            if obj["Filename"].split(".")[-1] in self.handler.supported_file_formats:
                item = {
                    "path": path,
                    "bucket": obj["Bucket"],
                    "name": path[path.rfind("/") + 1 :],
                    "extension": path[path.rfind(".") + 1 :],
                }

                data.append(item)

        return pd.DataFrame(data=data, columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """Columns exposed by the virtual ``files`` table."""
        return ["path", "name", "extension", "bucket", "content"]
class FileTable(APIResource):
    """Exposes a single COS object as a queryable / insertable table."""

    def list(self, targets: List[str] = None, table_name=None, *args, **kwargs) -> pd.DataFrame:
        """Read the object named ``table_name`` into a DataFrame."""
        return self.handler.read_as_table(table_name)

    def add(self, data, table_name=None):
        """Append the given rows to the object named ``table_name``."""
        frame = pd.DataFrame(data)
        return self.handler.add_data_to_table(table_name, frame)
class IBMCloudObjectStorageHandler(APIHandler):
    """Integration handler for IBM Cloud Object Storage (S3-compatible).

    Connects with HMAC credentials through ``ibm_boto3`` and reads/writes
    objects (csv, tsv, json, parquet) via DuckDB's ``httpfs`` extension.
    """

    name = "ibm_cos"
    supported_file_formats = ["csv", "tsv", "json", "parquet"]

    def __init__(self, name: Text, connection_data: Optional[Dict] = None, **kwargs):
        """Initialize the handler.

        Args:
            name: Handler instance name.
            connection_data: Connection parameters; must eventually contain
                cos_hmac_access_key_id, cos_hmac_secret_access_key and
                cos_endpoint_url, and may contain an optional ``bucket``.
            **kwargs: Extra arguments kept for later use.
        """
        super().__init__(name)
        self.connection_data = connection_data or {}
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        self.cache_thread_safe = True
        self._regions = {}

        # Optional fixed bucket; when unset, object keys are addressed as
        # "<bucket>/<key>" paths (see _get_bucket).
        self.bucket = self.connection_data.get("bucket")
        self._files_table = ListFilesTable(self)

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def connect(self):
        """Create (or reuse) the ibm_boto3 S3 client.

        Returns:
            The connected client.

        Raises:
            ValueError: If a required connection parameter is missing.
        """
        if self.is_connected is True:
            return self.connection

        required_params = [
            "cos_hmac_access_key_id",
            "cos_hmac_secret_access_key",
            "cos_endpoint_url",
        ]
        if not all(key in self.connection_data for key in required_params):
            raise ValueError(
                "Required parameters (cos_hmac_access_key_id, cos_hmac_secret_access_key, cos_endpoint_url) must be provided."
            )

        self.connection = self._connect_ibm_boto3()
        self.is_connected = True

        return self.connection

    def _connect_ibm_boto3(self) -> ibm_boto3.client:
        """Build an S3 client and probe it with a cheap request.

        The probe (head_bucket / list_buckets) forces credential validation
        at connect time instead of on the first real operation.
        """
        config = {
            "aws_access_key_id": self.connection_data["cos_hmac_access_key_id"],
            "aws_secret_access_key": self.connection_data["cos_hmac_secret_access_key"],
            "endpoint_url": self.connection_data["cos_endpoint_url"],
        }

        client = ibm_boto3.client("s3", **config)

        if self.bucket is not None:
            client.head_bucket(Bucket=self.bucket)
        else:
            client.list_buckets()

        return client

    def disconnect(self):
        """Drop the cached client; a no-op when not connected."""
        if not self.is_connected:
            return
        self.connection = None
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Verify that a client can be created with the stored credentials."""
        response = StatusResponse(False)
        # Remember whether we were disconnected so we can restore that state.
        need_to_close = self.is_connected is False

        try:
            # A throwaway client is built just for the probe; the cached
            # connection (if any) is left untouched.
            self._connect_ibm_boto3()
            response.success = True
        except (ClientError, ValueError) as e:
            logger.error(f"Error connecting to IBM COS with the given credentials, {e}!")
            response.error_message = str(e)

        if response.success and need_to_close:
            self.disconnect()

        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    @contextmanager
    def _connect_duckdb(self):
        """Yield an in-memory DuckDB connection configured for COS via httpfs.

        The connection is always closed on exit.
        """
        duckdb_conn = duckdb.connect(":memory:")
        duckdb_conn.execute("INSTALL httpfs")
        duckdb_conn.execute("LOAD httpfs")

        duckdb_conn.execute(f"SET s3_access_key_id='{self.connection_data['cos_hmac_access_key_id']}'")
        duckdb_conn.execute(f"SET s3_secret_access_key='{self.connection_data['cos_hmac_secret_access_key']}'")

        # DuckDB's s3_endpoint expects a bare host, so strip the scheme.
        endpoint_url = self.connection_data["cos_endpoint_url"]
        if endpoint_url.startswith("https://"):
            endpoint_url = endpoint_url[len("https://") :]
        elif endpoint_url.startswith("http://"):
            endpoint_url = endpoint_url[len("http://") :]

        duckdb_conn.execute(f"SET s3_endpoint='{endpoint_url}'")
        duckdb_conn.execute("SET s3_url_style='path'")
        duckdb_conn.execute("SET s3_use_ssl=true")

        try:
            yield duckdb_conn
        finally:
            duckdb_conn.close()

    def _get_bucket(self, key):
        """Resolve (bucket, key) for an object path.

        With a fixed bucket configured, ``key`` is used as-is; otherwise the
        first path segment of ``key`` is treated as the bucket name.
        """
        if self.bucket is not None:
            return self.bucket, key

        ar = key.split("/")
        return ar[0], "/".join(ar[1:])

    def read_as_table(self, key) -> pd.DataFrame:
        """Read the object at ``key`` into a DataFrame through DuckDB."""
        bucket, key = self._get_bucket(key)

        with self._connect_duckdb() as connection:
            # NOTE(review): bucket/key are interpolated into SQL unescaped;
            # a key containing a single quote breaks the statement.
            cursor = connection.execute(f"SELECT * FROM 's3://{bucket}/{key}'")

            return cursor.fetchdf()

    def _read_as_content(self, key) -> bytes:
        """Fetch the raw bytes of the object at ``key``.

        The original annotation said ``-> None`` but the method returns the
        object body.
        """
        bucket, key = self._get_bucket(key)

        client = self.connect()

        obj = client.get_object(Bucket=bucket, Key=key)
        content = obj["Body"].read()
        return content

    def add_data_to_table(self, key, df) -> None:
        """Append ``df`` to the existing object at ``key``.

        Raises:
            ClientError: If the target object does not exist / is not reachable.
        """
        bucket, key = self._get_bucket(key)

        # Ensure the target file exists before rewriting it.
        try:
            client = self.connect()
            client.head_object(Bucket=bucket, Key=key)
        except ClientError as e:
            logger.error(f"Error querying the file {key} in the bucket {bucket}, {e}!")
            raise e

        with self._connect_duckdb() as connection:
            # Read existing rows, append the new ones by column name, and
            # write the merged result back to the same object.
            connection.execute(f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{bucket}/{key}'")

            connection.execute("INSERT INTO tmp_table BY NAME SELECT * FROM df")

            connection.execute(f"COPY tmp_table TO 's3://{bucket}/{key}'")

    def query(self, query: ASTNode) -> Response:
        """Execute a parsed SQL query against the virtual tables.

        SELECT runs against either the ``files`` listing table or a single
        object; INSERT appends to a single object.

        Raises:
            ValueError: If the target file format is unsupported.
            NotImplementedError: For statement types other than SELECT/INSERT.
        """
        self.connect()
        if isinstance(query, Select):
            table_name = query.from_table.parts[-1]

            if table_name == "files":
                table = self._files_table
                df = table.select(query)

                # The 'content' column is expensive (one GET per row), so it
                # is only materialized when explicitly selected.
                has_content = False
                for target in query.targets:
                    if isinstance(target, Identifier) and target.parts[-1].lower() == "content":
                        has_content = True
                        break
                if has_content:
                    df["content"] = df["path"].apply(self._read_as_content)
            else:
                extension = table_name.split(".")[-1]
                if extension not in self.supported_file_formats:
                    logger.error(f"The file format {extension} is not supported!")
                    raise ValueError(f"The file format {extension} is not supported!")

                table = FileTable(self, table_name=table_name)
                df = table.select(query)

            response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
        elif isinstance(query, Insert):
            table_name = query.table.parts[-1]
            table = FileTable(self, table_name=table_name)
            table.insert(query)
            response = Response(RESPONSE_TYPE.OK)
        else:
            raise NotImplementedError

        return response

    def get_objects(self, limit=None, buckets=None) -> List[dict]:
        """List objects across the configured or discovered bucket(s).

        Args:
            limit: Stop after collecting this many objects.
            buckets: If given, only scan these bucket names.

        Returns:
            Object dicts (as returned by list_objects_v2) augmented with
            'Bucket' and 'Filename'; 'Key' is prefixed with the bucket name
            when no fixed bucket is configured.
        """
        client = self.connect()
        if self.bucket is not None:
            add_bucket_to_name = False
            scan_buckets = [self.bucket]
        else:
            add_bucket_to_name = True
            resp = client.list_buckets()
            scan_buckets = [b["Name"] for b in resp["Buckets"]]

        objects = []
        for bucket in scan_buckets:
            if buckets is not None and bucket not in buckets:
                continue

            # NOTE(review): list_objects_v2 returns at most 1000 keys per
            # call and no continuation-token pagination is done here —
            # confirm whether larger buckets need it.
            resp = client.list_objects_v2(Bucket=bucket)
            if "Contents" not in resp:
                continue

            for obj in resp["Contents"]:
                obj["Bucket"] = bucket
                obj["Filename"] = obj["Key"]
                if add_bucket_to_name:
                    obj["Key"] = f"{bucket}/{obj['Key']}"
                objects.append(obj)
                if limit is not None and len(objects) >= limit:
                    break
            # Fix: the original only broke out of the inner loop, so later
            # buckets were still listed and the result could exceed `limit`.
            if limit is not None and len(objects) >= limit:
                break

        return objects

    def get_tables(self) -> Response:
        """Return the list of queryable tables: 'files' plus every supported object."""
        supported_names = [
            f"{obj['Key']}" for obj in self.get_objects() if obj["Key"].split(".")[-1] in self.supported_file_formats
        ]

        # 'files' is the virtual listing table and always comes first.
        supported_names.insert(0, "files")

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(supported_names, columns=["table_name"]),
        )

        return response

    def get_columns(self, table_name: str) -> Response:
        """Infer column names/types of a table by selecting a single row."""
        query = Select(
            targets=[Star()],
            from_table=Identifier(parts=[table_name]),
            limit=Constant(1),
        )

        result = self.query(query)

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                {
                    "column_name": result.data_frame.columns,
                    # pandas reports strings as dtype 'object'; normalize that
                    # to 'string' for consumers.
                    "data_type": [
                        str(dtype) if str(dtype) != "object" else "string" for dtype in result.data_frame.dtypes
                    ],
                }
            ),
        )

        return response