Coverage for mindsdb / integrations / handlers / s3_handler / s3_handler.py: 61%
216 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import List
2from contextlib import contextmanager
4import boto3
5import duckdb
6from duckdb import HTTPException
7from mindsdb_sql_parser import parse_sql
8import pandas as pd
9from typing import Text, Dict, Optional
10from botocore.client import Config
11from botocore.exceptions import ClientError
13from mindsdb_sql_parser.ast.base import ASTNode
14from mindsdb_sql_parser.ast import Select, Identifier, Insert, Star, Constant
16from mindsdb.utilities import log
17from mindsdb.integrations.libs.response import (
18 HandlerStatusResponse as StatusResponse,
19 HandlerResponse as Response,
20 RESPONSE_TYPE,
21)
23from mindsdb.integrations.libs.api_handler import APIResource, APIHandler
24from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator
26logger = log.getLogger(__name__)
class ListFilesTable(APIResource):
    """Virtual `files` table that lists the objects available in the S3 bucket(s)."""

    def list(
        self, targets: List[str] = None, conditions: List[FilterCondition] = None, limit: int = None, *args, **kwargs
    ) -> pd.DataFrame:
        """
        Lists S3 objects as rows of the virtual `files` table.

        Args:
            targets (List[str]): Columns requested by the query; `public_url` is only generated when requested.
            conditions (List[FilterCondition]): Filter conditions; only `bucket` (EQUAL / IN) is applied here.
            limit (int): Maximum number of objects to return.

        Returns:
            pd.DataFrame: One row per object, with the columns reported by `get_columns`.
        """
        buckets = None
        # `conditions` defaults to None: iterate an empty list instead of crashing.
        for condition in conditions or []:
            if condition.column == "bucket":
                if condition.op == FilterOperator.IN:
                    buckets = condition.value
                elif condition.op == FilterOperator.EQUAL:
                    buckets = [condition.value]
                condition.applied = True

        data = []
        for obj in self.handler.get_objects(limit=limit, buckets=buckets):
            path = obj["Key"]
            # Backticks are used to quote object names in SQL; strip them so they
            # cannot break identifier quoting downstream.
            path = path.replace("`", "")
            item = {
                "path": path,
                "bucket": obj["Bucket"],
                "name": path[path.rfind("/") + 1 :],
                "extension": path[path.rfind(".") + 1 :],
            }

            # Generating a pre-signed URL requires an extra API call per object,
            # so only do it when the query actually selects `public_url`.
            if targets and "public_url" in targets:
                item["public_url"] = self.handler.generate_sas_url(path, obj["Bucket"])

            data.append(item)

        return pd.DataFrame(data=data, columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """Returns the column names of the virtual `files` table."""
        return ["path", "name", "extension", "bucket", "content", "public_url"]
class FileTable(APIResource):
    """Exposes a single S3 object (csv/tsv/json/parquet) as a queryable table."""

    def list(self, targets: List[str] = None, table_name=None, *args, **kwargs) -> pd.DataFrame:
        """Loads the object identified by `table_name` into a DataFrame."""
        return self.handler.read_as_table(table_name)

    def add(self, data, table_name=None):
        """Appends the given rows to the object identified by `table_name`."""
        frame = pd.DataFrame(data)
        return self.handler.add_data_to_table(table_name, frame)
class S3Handler(APIHandler):
    """
    This handler handles connection and execution of the SQL statements on AWS S3.
    """

    name = "s3"
    # TODO: Can other file formats be supported?
    supported_file_formats = ["csv", "tsv", "json", "parquet"]

    def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs):
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the AWS (S3) account.
            kwargs: Arbitrary keyword arguments.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        self.cache_thread_safe = True
        # If a bucket is configured, all keys are relative to it; otherwise the
        # bucket name is expected as the first path segment of each key.
        self.bucket = self.connection_data.get("bucket")
        # Cache of bucket name -> AWS region, so the region lookup runs once per bucket.
        self._regions = {}

        self._files_table = ListFilesTable(self)

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def connect(self):
        """
        Establishes a connection to the AWS (S3) account.

        Raises:
            ValueError: If the required connection parameters are not provided.

        Returns:
            boto3.client: A client object to the AWS (S3) account.
        """
        if self.is_connected is True:
            return self.connection

        # Validate mandatory parameters.
        if not all(key in self.connection_data for key in ["aws_access_key_id", "aws_secret_access_key"]):
            raise ValueError("Required parameters (aws_access_key_id, aws_secret_access_key) must be provided.")

        # Connect to S3 and configure mandatory credentials.
        self.connection = self._connect_boto3()
        self.is_connected = True

        return self.connection

    @contextmanager
    def _connect_duckdb(self, bucket):
        """
        Creates a temporary in-memory DuckDB database able to connect to the AWS (S3) account.
        Has to be used as a context manager.

        Returns:
            DuckDBPyConnection
        """
        # Connect to S3 via DuckDB.
        duckdb_conn = duckdb.connect(":memory:")
        try:
            duckdb_conn.execute("INSTALL httpfs")
        except HTTPException as http_error:
            logger.debug(f"Error installing the httpfs extension, {http_error}! Forcing installation.")
            duckdb_conn.execute("FORCE INSTALL httpfs")
        duckdb_conn.execute("LOAD httpfs")

        def _quote(value):
            # DuckDB's SET statement does not support bind parameters, so double
            # any single quotes to keep the statement well-formed even when a
            # credential or region contains one.
            return str(value).replace("'", "''")

        # Configure mandatory credentials.
        duckdb_conn.execute(f"SET s3_access_key_id='{_quote(self.connection_data['aws_access_key_id'])}'")
        duckdb_conn.execute(f"SET s3_secret_access_key='{_quote(self.connection_data['aws_secret_access_key'])}'")

        # Configure optional parameters.
        if "aws_session_token" in self.connection_data:
            duckdb_conn.execute(f"SET s3_session_token='{_quote(self.connection_data['aws_session_token'])}'")

        # Detect the region for the bucket (cached in self._regions).
        if bucket not in self._regions:
            client = self.connect()
            # get_bucket_location returns LocationConstraint=None for buckets in
            # us-east-1; fall back to the literal region name so DuckDB does not
            # end up with s3_region='None'.
            location = client.get_bucket_location(Bucket=bucket)["LocationConstraint"]
            self._regions[bucket] = location or "us-east-1"
        region = self._regions[bucket]
        duckdb_conn.execute(f"SET s3_region='{_quote(region)}'")

        try:
            yield duckdb_conn
        finally:
            duckdb_conn.close()

    def _connect_boto3(self) -> boto3.client:
        """
        Establishes a connection to the AWS (S3) account.

        Returns:
            boto3.client: A client object to the AWS (S3) account.
        """
        # Configure mandatory credentials.
        config = {
            "aws_access_key_id": self.connection_data["aws_access_key_id"],
            "aws_secret_access_key": self.connection_data["aws_secret_access_key"],
        }

        # Configure optional parameters.
        optional_parameters = ["region_name", "aws_session_token"]
        for parameter in optional_parameters:
            if parameter in self.connection_data:
                config[parameter] = self.connection_data[parameter]

        # s3v4 signing is required for pre-signed URLs in some regions.
        client = boto3.client("s3", **config, config=Config(signature_version="s3v4"))

        # Check the connection: probe the configured bucket, or list buckets
        # when no single bucket is configured.
        if self.bucket is not None:
            client.head_bucket(Bucket=self.bucket)
        else:
            client.list_buckets()

        return client

    def disconnect(self):
        """
        Closes the connection to the AWS (S3) account if it's currently open.
        """
        if not self.is_connected:
            return
        self.connection.close()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the S3 bucket.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        # Check connection via boto3.
        try:
            self._connect_boto3()
            response.success = True
        except (ClientError, ValueError) as e:
            logger.error(f"Error connecting to S3 with the given credentials, {e}!")
            response.error_message = str(e)

        if response.success and need_to_close:
            self.disconnect()
        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    def _get_bucket(self, key):
        """
        Resolves (bucket, key) for an object path. With a configured bucket the
        key is used as-is; otherwise the first path segment is the bucket name.
        """
        if self.bucket is not None:
            return self.bucket, key

        # Get the bucket from the first part of the key.
        ar = key.split("/")
        return ar[0], "/".join(ar[1:])

    def read_as_table(self, key) -> pd.DataFrame:
        """
        Read object as dataframe. Uses duckdb.
        """
        bucket, key = self._get_bucket(key)

        with self._connect_duckdb(bucket) as connection:
            cursor = connection.execute(f"SELECT * FROM 's3://{bucket}/{key}'")
            return cursor.fetchdf()

    def _read_as_content(self, key) -> bytes:
        """
        Read object as raw content (bytes).
        """
        bucket, key = self._get_bucket(key)

        client = self.connect()
        obj = client.get_object(Bucket=bucket, Key=key)
        content = obj["Body"].read()
        return content

    def add_data_to_table(self, key, df) -> None:
        """
        Writes the table to a file in the S3 bucket.

        Raises:
            ClientError: If the target file does not exist in the S3 bucket.
        """
        # Check if the file exists in the S3 bucket.
        bucket, key = self._get_bucket(key)

        try:
            client = self.connect()
            client.head_object(Bucket=bucket, Key=key)
        except ClientError as e:
            logger.error(f"Error querying the file {key} in the bucket {bucket}, {e}!")
            raise e

        with self._connect_duckdb(bucket) as connection:
            # Copy the existing object into a temporary table.
            connection.execute(f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{bucket}/{key}'")

            # Insert the new rows (`df` is picked up via DuckDB's replacement scan).
            connection.execute("INSERT INTO tmp_table BY NAME SELECT * FROM df")

            # Upload the combined table back to S3.
            connection.execute(f"COPY tmp_table TO 's3://{bucket}/{key}'")

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Raises:
            ValueError: If the file format is not supported or the file does not exist in the S3 bucket.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        self.connect()

        if isinstance(query, Select):
            table_name = query.from_table.parts[-1]

            if table_name == "files":
                table = self._files_table
                df = table.select(query)

                # Fetch the object contents only when `content` is selected.
                has_content = False
                for target in query.targets:
                    if isinstance(target, Identifier) and target.parts[-1].lower() == "content":
                        has_content = True
                        break
                if has_content:
                    df["content"] = df["path"].apply(self._read_as_content)
            else:
                extension = table_name.split(".")[-1]
                if extension not in self.supported_file_formats:
                    logger.error(f"The file format {extension} is not supported!")
                    raise ValueError(f"The file format {extension} is not supported!")

                table = FileTable(self, table_name=table_name)
                df = table.select(query)

            response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
        elif isinstance(query, Insert):
            table_name = query.table.parts[-1]
            table = FileTable(self, table_name=table_name)
            table.insert(query)
            response = Response(RESPONSE_TYPE.OK)
        else:
            raise NotImplementedError

        return response

    def native_query(self, query: str) -> Response:
        """
        Executes a SQL query and returns the result.

        Args:
            query (str): The SQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        query_ast = parse_sql(query)
        return self.query(query_ast)

    def get_objects(self, limit=None, buckets=None) -> List[dict]:
        """
        Lists objects (STANDARD storage class only) from the configured bucket,
        or from all accessible buckets when no bucket is configured.

        Args:
            limit (int): Maximum total number of objects to return across all buckets.
            buckets (List[str]): If given, restrict the scan to these bucket names.

        Returns:
            List[dict]: S3 object descriptors, each augmented with a 'Bucket' entry.
        """
        client = self.connect()
        if self.bucket is not None:
            add_bucket_to_name = False
            scan_buckets = [self.bucket]
        else:
            # No default bucket: scan every bucket and prefix keys with the bucket name.
            add_bucket_to_name = True
            scan_buckets = [b["Name"] for b in client.list_buckets()["Buckets"]]

        objects = []
        limit_reached = False
        for bucket in scan_buckets:
            if buckets is not None and bucket not in buckets:
                continue

            resp = client.list_objects_v2(Bucket=bucket)
            if "Contents" not in resp:
                continue

            for obj in resp["Contents"]:
                if obj.get("StorageClass", "STANDARD") != "STANDARD":
                    continue

                obj["Bucket"] = bucket
                if add_bucket_to_name:
                    # The bucket becomes part of the name (first path segment).
                    obj["Key"] = f"{bucket}/{obj['Key']}"
                objects.append(obj)
                if limit is not None and len(objects) >= limit:
                    # Previously the limit was only enforced per bucket, so the
                    # total could exceed it when scanning multiple buckets; stop
                    # the whole scan once the limit is reached.
                    limit_reached = True
                    break
            if limit_reached:
                break

        return objects

    def generate_sas_url(self, key: str, bucket: str) -> str:
        """
        Generates a pre-signed URL for accessing an object in the S3 bucket.

        Args:
            key (str): The key (path) of the object in the S3 bucket.
            bucket (str): The name of the S3 bucket.

        Returns:
            str: The pre-signed URL for accessing the object (valid for 1 hour).
        """
        client = self.connect()
        url = client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=3600)
        return url

    def get_tables(self) -> Response:
        """
        Retrieves a list of tables (objects) in the S3 bucket.

        Each object is considered a table. Only the supported file formats are considered as tables.

        Returns:
            Response: A response object containing the list of tables and views, formatted as per the `Response` class.
        """
        # Get only the supported file formats.
        # Wrap the object names with backticks to prevent SQL syntax errors.
        supported_names = [
            f"`{obj['Key']}`" for obj in self.get_objects() if obj["Key"].split(".")[-1] in self.supported_file_formats
        ]

        # Virtual table with the list of files.
        supported_names.insert(0, "files")

        response = Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame(supported_names, columns=["table_name"]))

        return response

    def get_columns(self, table_name: str) -> Response:
        """
        Retrieves column details for a specified table (object) in the S3 bucket.

        Args:
            table_name (Text): The name of the table for which to retrieve column information.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object containing the column details, formatted as per the `Response` class.
        """
        # Fetch a single row and derive the column names/types from the result frame.
        query = Select(targets=[Star()], from_table=Identifier(parts=[table_name]), limit=Constant(1))

        result = self.query(query)

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                {
                    "column_name": result.data_frame.columns,
                    "data_type": [
                        data_type if data_type != "object" else "string" for data_type in result.data_frame.dtypes
                    ],
                }
            ),
        )

        return response