Coverage for mindsdb/integrations/handlers/athena_handler/athena_handler.py: 0%
84 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import time
2import pandas as pd
3from boto3 import client
4from typing import Optional
6from mindsdb_sql_parser import parse_sql
7from mindsdb.integrations.libs.base import DatabaseHandler
8from mindsdb_sql_parser.ast.base import ASTNode
9from mindsdb.utilities import log
10from mindsdb.integrations.libs.response import (
11 HandlerStatusResponse as StatusResponse,
12 HandlerResponse as Response,
13 RESPONSE_TYPE
14)
16logger = log.getLogger(__name__)
class AthenaHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Athena statements.

    A query is submitted with ``start_query_execution``, polled with
    ``get_query_execution`` until it reaches a terminal state, and its rows
    are then read page by page with ``get_query_results`` and converted into
    a pandas DataFrame.
    """

    name = 'athena'

    # Seconds to sleep between status polls when ``check_interval`` is missing,
    # non-numeric or not positive. Polling without any pause would hammer the
    # Athena API in a tight loop.
    DEFAULT_CHECK_INTERVAL = 1.0

    def __init__(self, name: str, connection_data: Optional[dict], **kwargs):
        """
        Initialize the handler.
        Args:
            name (str): name of particular handler instance
            connection_data (dict): parameters for connecting to the database
            **kwargs: arbitrary keyword arguments.
        """
        super().__init__(name)
        self.parser = parse_sql
        self.dialect = 'athena'

        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False

    def connect(self) -> StatusResponse:
        """
        Set up the connection required by the handler.

        Creates a boto3 Athena client from the ``aws_access_key_id``,
        ``aws_secret_access_key`` and ``region_name`` connection parameters.
        Does nothing if the handler is already marked as connected.

        Returns:
            HandlerStatusResponse
        """
        if self.is_connected:
            return StatusResponse(success=True)

        try:
            self.connection = client(
                'athena',
                aws_access_key_id=self.connection_data['aws_access_key_id'],
                aws_secret_access_key=self.connection_data['aws_secret_access_key'],
                region_name=self.connection_data['region_name'],
            )
            self.is_connected = True
            return StatusResponse(success=True)
        except Exception as e:
            # Missing keys or a missing connection_data dict end up here too,
            # so the caller always gets a status object instead of an exception.
            logger.error(f'Failed to connect to Athena: {str(e)}')
            return StatusResponse(success=False, error_message=str(e))

    def disconnect(self):
        """
        Close any existing connections.
        """
        if self.is_connected:
            self.connection = None
            self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Check connection to the handler.

        Connects first if the handler is not connected yet.

        Returns:
            HandlerStatusResponse
        """
        if self.is_connected:
            return StatusResponse(success=True)
        return self.connect()

    def native_query(self, query: str) -> Response:
        """
        Receive raw query and act upon it somehow.

        The query runs against the configured ``database`` and ``workgroup``,
        with results written to ``results_output_location``. If this call had
        to open the connection itself, it closes it again before returning.

        Args:
            query (str): query in native format
        Returns:
            HandlerResponse: a TABLE response with the result rows, or an
                ERROR response if connecting, submitting or running the query
                failed.
        """
        need_to_close = not self.is_connected

        connection_status = self.connect()
        if not self.is_connected:
            # Report the real connection problem instead of letting the code
            # below fail on a missing client object.
            error_message = (
                getattr(connection_status, 'error_message', None)
                or 'Failed to connect to Athena'
            )
            return Response(RESPONSE_TYPE.ERROR, error_message=error_message)

        try:
            execution = self.connection.start_query_execution(
                QueryString=query,
                QueryExecutionContext={
                    'Database': self.connection_data['database'],
                },
                ResultConfiguration={
                    'OutputLocation': self.connection_data['results_output_location'],
                },
                WorkGroup=self.connection_data['workgroup'],
            )
            query_execution_id = execution['QueryExecutionId']
            status = self._wait_for_query_to_complete(query_execution_id)
            if status == 'SUCCEEDED':
                result = self._fetch_query_results(query_execution_id)
                df = self._parse_query_result(result)
                response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
            else:
                response = Response(RESPONSE_TYPE.ERROR, error_message='Query failed or was cancelled')
        except Exception as e:
            logger.error(f'Error executing query in Athena: {str(e)}')
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))
        finally:
            if need_to_close:
                self.disconnect()

        return response

    def query(self, query: ASTNode) -> Response:
        """
        Receive query as AST (abstract syntax tree) and act upon it somehow.
        Args:
            query (ASTNode): sql query represented as AST. May be any kind
                of query: SELECT, INSERT, DELETE, etc
        Returns:
            HandlerResponse
        """
        return self.native_query(query.to_string())

    def get_tables(self) -> Response:
        """
        Return list of entities that will be accessible as tables.
        Returns:
            Response: A response object containing the schema, name and type
                of every base table and view outside ``information_schema``.
        """
        query = """
            select
                table_schema,
                table_name,
                table_type
            from
                information_schema.tables
            where
                table_schema not in ('information_schema')
                and table_type in ('BASE TABLE', 'VIEW')
        """
        return self.native_query(query)

    def get_columns(self, table_name: str) -> Response:
        """
        Returns a list of entity columns.
        Args:
            table_name (str): name of one of tables returned by self.get_tables()
        Returns:
            Response: A response object containing the column details
        Raises:
            ValueError: If the 'table_name' is not a valid string.
        """
        if not table_name or not isinstance(table_name, str):
            raise ValueError("Invalid value for table name provided.")

        # The name is embedded in a SQL string literal, so double any single
        # quote to keep it from terminating the literal early.
        escaped_table_name = table_name.replace("'", "''")

        query = f"""
            select
                column_name as "Field",
                data_type as "Type"
            from
                information_schema.columns
            where
                table_name = '{escaped_table_name}'
        """
        return self.native_query(query)

    def _get_check_interval(self) -> float:
        """
        Read the polling interval from the connection parameters.
        Returns:
            float: ``check_interval`` in seconds when it is a positive, finite
                number (numeric strings such as ``"2"`` or ``"0.5"`` are
                accepted); otherwise ``DEFAULT_CHECK_INTERVAL``.
        """
        raw_interval = self.connection_data.get('check_interval')
        try:
            interval = float(raw_interval)
        except (TypeError, ValueError):
            return self.DEFAULT_CHECK_INTERVAL

        # Rejects zero, negatives, NaN and infinity in one comparison.
        if 0 < interval < float('inf'):
            return interval
        return self.DEFAULT_CHECK_INTERVAL

    def _wait_for_query_to_complete(self, query_execution_id: str) -> str:
        """
        Wait for the Athena query to complete.

        Polls the execution state, sleeping between polls, until the state is
        one of 'SUCCEEDED', 'FAILED' or 'CANCELLED'.

        Args:
            query_execution_id (str): ID of the query to wait for
        Returns:
            str: Query execution status
        """
        # The setting cannot change while waiting, so resolve it once.
        check_interval = self._get_check_interval()

        while True:
            response = self.connection.get_query_execution(QueryExecutionId=query_execution_id)
            status = response['QueryExecution']['Status']['State']
            if status in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
                return status

            time.sleep(check_interval)

    def _fetch_query_results(self, query_execution_id: str) -> dict:
        """
        Fetch every result page of a finished query and merge the rows.
        Args:
            query_execution_id (str): ID of the query whose results to read
        Returns:
            dict: The first ``get_query_results`` response, with
                ``ResultSet.Rows`` replaced by the rows of all pages in order.
        """
        first_page = self.connection.get_query_results(QueryExecutionId=query_execution_id)
        rows = list((first_page.get('ResultSet') or {}).get('Rows') or [])

        # Keep following the continuation token so large results are not
        # silently cut off after the first page.
        next_token = first_page.get('NextToken')
        while next_token:
            page = self.connection.get_query_results(
                QueryExecutionId=query_execution_id,
                NextToken=next_token,
            )
            rows.extend((page.get('ResultSet') or {}).get('Rows') or [])
            next_token = page.get('NextToken')

        merged = dict(first_page)
        result_set = dict(first_page.get('ResultSet') or {})
        result_set['Rows'] = rows
        merged['ResultSet'] = result_set
        return merged

    def _parse_query_result(self, result: dict) -> pd.DataFrame:
        """
        Parse the result of the Athena query into a DataFrame.

        The first row is used as the column header and the remaining rows as
        data. Cells without a 'VarCharValue' become None.

        Args:
            result: Result of the Athena query
        Returns:
            pd.DataFrame: Query result as a DataFrame; empty when the result
                is missing, malformed or contains no rows.
        """
        if not result or 'ResultSet' not in result or 'Rows' not in result['ResultSet']:
            return pd.DataFrame()

        rows = result['ResultSet']['Rows']
        if not rows:
            # Nothing to take a header from; avoid indexing into an empty list.
            return pd.DataFrame()

        headers = [col.get('VarCharValue') for col in rows[0].get('Data', [])]
        data = [
            [col.get('VarCharValue') for col in row.get('Data', [])]
            for row in rows[1:]
        ]
        return pd.DataFrame(data, columns=headers)