Coverage for mindsdb / integrations / handlers / s3_handler / s3_handler.py: 61%

216 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List 

2from contextlib import contextmanager 

3 

4import boto3 

5import duckdb 

6from duckdb import HTTPException 

7from mindsdb_sql_parser import parse_sql 

8import pandas as pd 

9from typing import Text, Dict, Optional 

10from botocore.client import Config 

11from botocore.exceptions import ClientError 

12 

13from mindsdb_sql_parser.ast.base import ASTNode 

14from mindsdb_sql_parser.ast import Select, Identifier, Insert, Star, Constant 

15 

16from mindsdb.utilities import log 

17from mindsdb.integrations.libs.response import ( 

18 HandlerStatusResponse as StatusResponse, 

19 HandlerResponse as Response, 

20 RESPONSE_TYPE, 

21) 

22 

23from mindsdb.integrations.libs.api_handler import APIResource, APIHandler 

24from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator 

25 

26logger = log.getLogger(__name__) 

27 

28 

class ListFilesTable(APIResource):
    """Virtual `files` table listing the objects available in the S3 bucket(s)."""

    def list(
        self, targets: List[str] = None, conditions: List[FilterCondition] = None, limit: int = None, *args, **kwargs
    ) -> pd.DataFrame:
        """
        Lists objects in the S3 bucket(s).

        Args:
            targets (List[str]): Columns requested by the query; `public_url` is only
                computed when explicitly requested, since it costs an API call per row.
            conditions (List[FilterCondition]): Filter conditions; only `bucket = ...`
                and `bucket IN (...)` are applied here.
            limit (int): Maximum number of objects to return.

        Returns:
            pd.DataFrame: One row per object with the columns from `get_columns`.
        """
        buckets = None
        # Bug fix: `conditions` defaults to None, so iterate a safe empty fallback
        # instead of crashing with a TypeError when no conditions are passed.
        for condition in conditions or []:
            if condition.column == "bucket":
                if condition.op == FilterOperator.IN:
                    buckets = condition.value
                elif condition.op == FilterOperator.EQUAL:
                    buckets = [condition.value]
                condition.applied = True

        data = []
        for obj in self.handler.get_objects(limit=limit, buckets=buckets):
            path = obj["Key"]
            # Strip backticks: object names are wrapped in backticks elsewhere
            # (see get_tables) to keep them SQL-safe; recover the raw key here.
            path = path.replace("`", "")
            item = {
                "path": path,
                "bucket": obj["Bucket"],
                "name": path[path.rfind("/") + 1 :],
                "extension": path[path.rfind(".") + 1 :],
            }

            # Pre-signed URLs are generated only on demand (one API call per object).
            if targets and "public_url" in targets:
                item["public_url"] = self.handler.generate_sas_url(path, obj["Bucket"])

            data.append(item)

        return pd.DataFrame(data=data, columns=self.get_columns())

    def get_columns(self) -> List[str]:
        """Returns the column names exposed by the virtual `files` table."""
        return ["path", "name", "extension", "bucket", "content", "public_url"]

62 

63 

class FileTable(APIResource):
    """Exposes a single S3 object as a queryable and insertable table."""

    def list(self, targets: List[str] = None, table_name=None, *args, **kwargs) -> pd.DataFrame:
        """Reads the backing S3 object into a DataFrame via the handler."""
        return self.handler.read_as_table(table_name)

    def add(self, data, table_name=None):
        """Appends rows (any DataFrame-constructible payload) to the backing S3 object."""
        frame = pd.DataFrame(data)
        return self.handler.add_data_to_table(table_name, frame)

71 

72 

class S3Handler(APIHandler):
    """
    This handler handles connection and execution of the SQL statements on AWS S3.
    """

    name = "s3"
    # TODO: Can other file formats be supported?
    supported_file_formats = ["csv", "tsv", "json", "parquet"]

    def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs):
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the AWS (S3) account.
            kwargs: Arbitrary keyword arguments.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        self.cache_thread_safe = True
        # Optional fixed bucket; when absent, table keys embed the bucket as their first path part.
        self.bucket = self.connection_data.get("bucket")
        # Cache of bucket name -> AWS region, to avoid repeated get_bucket_location calls.
        self._regions = {}

        self._files_table = ListFilesTable(self)

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def connect(self):
        """
        Establishes a connection to the AWS (S3) account.

        Raises:
            ValueError: If the required connection parameters are not provided.

        Returns:
            boto3.client: A client object to the AWS (S3) account.
        """
        if self.is_connected is True:
            return self.connection

        # Validate mandatory parameters.
        if not all(key in self.connection_data for key in ["aws_access_key_id", "aws_secret_access_key"]):
            raise ValueError("Required parameters (aws_access_key_id, aws_secret_access_key) must be provided.")

        # Connect to S3 and configure mandatory credentials.
        self.connection = self._connect_boto3()
        self.is_connected = True

        return self.connection

    @contextmanager
    def _connect_duckdb(self, bucket):
        """
        Creates temporal duckdb database which is able to connect to the AWS (S3) account.
        Have to be used as context manager.

        Args:
            bucket (Text): Name of the S3 bucket the connection will read from / write to.

        Yields:
            DuckDBPyConnection: An in-memory DuckDB connection configured for S3 access.
        """
        # Connect to S3 via DuckDB.
        duckdb_conn = duckdb.connect(":memory:")
        try:
            duckdb_conn.execute("INSTALL httpfs")
        except HTTPException as http_error:
            logger.debug(f"Error installing the httpfs extension, {http_error}! Forcing installation.")
            duckdb_conn.execute("FORCE INSTALL httpfs")

        duckdb_conn.execute("LOAD httpfs")

        # Configure mandatory credentials.
        # NOTE(review): values are interpolated into the SET statements; credentials
        # containing a single quote would break the statement. DuckDB SET does not
        # accept bound parameters for these options, so keep an eye on this.
        duckdb_conn.execute(f"SET s3_access_key_id='{self.connection_data['aws_access_key_id']}'")
        duckdb_conn.execute(f"SET s3_secret_access_key='{self.connection_data['aws_secret_access_key']}'")

        # Configure optional parameters.
        if "aws_session_token" in self.connection_data:
            duckdb_conn.execute(f"SET s3_session_token='{self.connection_data['aws_session_token']}'")

        # Detect the region for the bucket (cached per handler instance).
        if bucket not in self._regions:
            client = self.connect()
            location = client.get_bucket_location(Bucket=bucket)["LocationConstraint"]
            # Bug fix: AWS returns LocationConstraint=None for buckets in us-east-1,
            # which previously produced the invalid setting s3_region='None'.
            self._regions[bucket] = location or "us-east-1"

        region = self._regions[bucket]
        duckdb_conn.execute(f"SET s3_region='{region}'")

        try:
            yield duckdb_conn
        finally:
            # Always release the in-memory database, even if the caller raises.
            duckdb_conn.close()

    def _connect_boto3(self) -> boto3.client:
        """
        Establishes a connection to the AWS (S3) account.

        Returns:
            boto3.client: A client object to the AWS (S3) account.
        """
        # Configure mandatory credentials.
        config = {
            "aws_access_key_id": self.connection_data["aws_access_key_id"],
            "aws_secret_access_key": self.connection_data["aws_secret_access_key"],
        }

        # Configure optional parameters.
        optional_parameters = ["region_name", "aws_session_token"]
        for parameter in optional_parameters:
            if parameter in self.connection_data:
                config[parameter] = self.connection_data[parameter]

        client = boto3.client("s3", **config, config=Config(signature_version="s3v4"))

        # Check the connection: probe the configured bucket if one was given,
        # otherwise verify that we can at least list buckets.
        if self.bucket is not None:
            client.head_bucket(Bucket=self.bucket)
        else:
            client.list_buckets()

        return client

    def disconnect(self):
        """
        Closes the connection to the AWS (S3) account if it's currently open.
        """
        if not self.is_connected:
            return
        self.connection.close()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the S3 bucket.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        # Check connection via boto3.
        try:
            self._connect_boto3()
            response.success = True
        except (ClientError, ValueError) as e:
            logger.error(f"Error connecting to S3 with the given credentials, {e}!")
            response.error_message = str(e)

        if response.success and need_to_close:
            self.disconnect()

        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    def _get_bucket(self, key):
        """
        Resolves the (bucket, key) pair for an object reference.

        When a fixed bucket is configured it is used as-is; otherwise the bucket
        is taken from the first path segment of the key.
        """
        if self.bucket is not None:
            return self.bucket, key

        # Get bucket from the first part of the key.
        ar = key.split("/")
        return ar[0], "/".join(ar[1:])

    def read_as_table(self, key) -> pd.DataFrame:
        """
        Read object as dataframe. Uses duckdb.
        """
        bucket, key = self._get_bucket(key)

        with self._connect_duckdb(bucket) as connection:
            cursor = connection.execute(f"SELECT * FROM 's3://{bucket}/{key}'")

            return cursor.fetchdf()

    def _read_as_content(self, key) -> bytes:
        """
        Read object as raw content (bytes).
        """
        # Annotation fix: this method returns the object's body, not None.
        bucket, key = self._get_bucket(key)

        client = self.connect()

        obj = client.get_object(Bucket=bucket, Key=key)
        content = obj["Body"].read()
        return content

    def add_data_to_table(self, key, df) -> None:
        """
        Writes the table to a file in the S3 bucket.

        Raises:
            ClientError: If the target file does not exist in the S3 bucket.
            CatalogException: If the table does not exist in the DuckDB connection.
        """

        # Check if the file exists in the S3 bucket before appending to it.
        bucket, key = self._get_bucket(key)

        try:
            client = self.connect()
            client.head_object(Bucket=bucket, Key=key)
        except ClientError as e:
            logger.error(f"Error querying the file {key} in the bucket {bucket}, {e}!")
            raise e

        with self._connect_duckdb(bucket) as connection:
            # Copy the existing object into a temporary table.
            connection.execute(f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{bucket}/{key}'")

            # Insert the new rows by column name (df is referenced by DuckDB's
            # replacement-scan of local variables).
            connection.execute("INSERT INTO tmp_table BY NAME SELECT * FROM df")

            # Upload the combined result back to S3.
            connection.execute(f"COPY tmp_table TO 's3://{bucket}/{key}'")

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Raises:
            ValueError: If the file format is not supported or the file does not exist in the S3 bucket.
            NotImplementedError: If the query is neither a SELECT nor an INSERT.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """

        self.connect()

        if isinstance(query, Select):
            table_name = query.from_table.parts[-1]

            if table_name == "files":
                table = self._files_table
                df = table.select(query)

                # Fetch object contents only when the `content` column was requested:
                # it costs one S3 GET per row.
                has_content = False
                for target in query.targets:
                    if isinstance(target, Identifier) and target.parts[-1].lower() == "content":
                        has_content = True
                        break
                if has_content:
                    df["content"] = df["path"].apply(self._read_as_content)
            else:
                extension = table_name.split(".")[-1]
                if extension not in self.supported_file_formats:
                    logger.error(f"The file format {extension} is not supported!")
                    raise ValueError(f"The file format {extension} is not supported!")

                table = FileTable(self, table_name=table_name)
                df = table.select(query)

            response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
        elif isinstance(query, Insert):
            table_name = query.table.parts[-1]
            table = FileTable(self, table_name=table_name)
            table.insert(query)
            response = Response(RESPONSE_TYPE.OK)
        else:
            raise NotImplementedError

        return response

    def native_query(self, query: str) -> Response:
        """
        Executes a SQL query and returns the result.

        Args:
            query (str): The SQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        query_ast = parse_sql(query)
        return self.query(query_ast)

    def get_objects(self, limit=None, buckets=None) -> List[dict]:
        """
        Lists objects across the configured bucket (or all accessible buckets).

        Args:
            limit (int): Maximum total number of objects to return across all buckets.
            buckets (List[str]): If given, restricts the scan to these bucket names.

        Returns:
            List[dict]: list_objects_v2 entries, each augmented with a 'Bucket' key.
        """
        client = self.connect()
        if self.bucket is not None:
            add_bucket_to_name = False
            scan_buckets = [self.bucket]
        else:
            # No fixed bucket: scan every bucket and prefix keys with the bucket name.
            add_bucket_to_name = True
            scan_buckets = [b["Name"] for b in client.list_buckets()["Buckets"]]

        objects = []
        for bucket in scan_buckets:
            if buckets is not None and bucket not in buckets:
                continue

            resp = client.list_objects_v2(Bucket=bucket)
            if "Contents" not in resp:
                continue

            for obj in resp["Contents"]:
                # Skip archived/non-standard storage classes (e.g. GLACIER) that
                # cannot be read directly.
                if obj.get("StorageClass", "STANDARD") != "STANDARD":
                    continue

                obj["Bucket"] = bucket
                if add_bucket_to_name:
                    # Bucket is part of the name.
                    obj["Key"] = f"{bucket}/{obj['Key']}"
                objects.append(obj)
                if limit is not None and len(objects) >= limit:
                    # Bug fix: previously only the inner loop was exited, so the
                    # limit was not enforced across multiple buckets.
                    return objects

        return objects

    def generate_sas_url(self, key: str, bucket: str) -> str:
        """
        Generates a pre-signed URL for accessing an object in the S3 bucket.

        Args:
            key (str): The key (path) of the object in the S3 bucket.
            bucket (str): The name of the S3 bucket.

        Returns:
            str: The pre-signed URL for accessing the object (valid for one hour).
        """
        client = self.connect()
        url = client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=3600)
        return url

    def get_tables(self) -> Response:
        """
        Retrieves a list of tables (objects) in the S3 bucket.

        Each object is considered a table. Only the supported file formats are considered as tables.

        Returns:
            Response: A response object containing the list of tables and views, formatted as per the `Response` class.
        """

        # Get only the supported file formats.
        # Wrap the object names with backticks to prevent SQL syntax errors.
        supported_names = [
            f"`{obj['Key']}`" for obj in self.get_objects() if obj["Key"].split(".")[-1] in self.supported_file_formats
        ]

        # Virtual table with the list of files.
        supported_names.insert(0, "files")

        response = Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame(supported_names, columns=["table_name"]))

        return response

    def get_columns(self, table_name: str) -> Response:
        """
        Retrieves column details for a specified table (object) in the S3 bucket.

        Args:
            table_name (Text): The name of the table for which to retrieve column information.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object containing the column details, formatted as per the `Response` class.
        """
        # Fetch a single row and derive the schema from the resulting DataFrame.
        query = Select(targets=[Star()], from_table=Identifier(parts=[table_name]), limit=Constant(1))

        result = self.query(query)

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                {
                    "column_name": result.data_frame.columns,
                    "data_type": [
                        data_type if data_type != "object" else "string" for data_type in result.data_frame.dtypes
                    ],
                }
            ),
        )

        return response