Coverage for mindsdb / integrations / handlers / athena_handler / athena_handler.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import time 

2import pandas as pd 

3from boto3 import client 

4from typing import Optional 

5 

6from mindsdb_sql_parser import parse_sql 

7from mindsdb.integrations.libs.base import DatabaseHandler 

8from mindsdb_sql_parser.ast.base import ASTNode 

9from mindsdb.utilities import log 

10from mindsdb.integrations.libs.response import ( 

11 HandlerStatusResponse as StatusResponse, 

12 HandlerResponse as Response, 

13 RESPONSE_TYPE 

14) 

15 

16logger = log.getLogger(__name__) 

17 

18 

class AthenaHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Athena statements.

    Statements are submitted through a boto3 Athena client, polled until they
    reach a terminal state, and their results are returned as a DataFrame.
    """

    name = 'athena'

    def __init__(self, name: str, connection_data: Optional[dict], **kwargs):
        """
        Initialize the handler.
        Args:
            name (str): name of particular handler instance
            connection_data (dict): parameters for connecting to the database
            **kwargs: arbitrary keyword arguments.
        """
        super().__init__(name)
        self.parser = parse_sql
        self.dialect = 'athena'

        # A missing configuration is stored as an empty dict so that later
        # lookups fail with a descriptive KeyError rather than an error on None.
        self.connection_data = connection_data if connection_data is not None else {}
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False

    def connect(self) -> StatusResponse:
        """
        Set up the connection required by the handler.
        Returns:
            HandlerStatusResponse: success=True when a client is available,
            otherwise success=False with the error message.
        """
        if self.is_connected:
            return StatusResponse(success=True)

        try:
            self.connection = client(
                'athena',
                aws_access_key_id=self.connection_data['aws_access_key_id'],
                aws_secret_access_key=self.connection_data['aws_secret_access_key'],
                region_name=self.connection_data['region_name'],
            )
            self.is_connected = True
            return StatusResponse(success=True)
        except Exception as e:
            logger.error(f'Failed to connect to Athena: {str(e)}')
            return StatusResponse(success=False, error_message=str(e))

    def disconnect(self):
        """
        Close any existing connections.
        """
        if self.is_connected:
            self.connection = None
            self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Check connection to the handler.
        Returns:
            HandlerStatusResponse
        """
        if self.is_connected:
            return StatusResponse(success=True)
        return self.connect()

    def native_query(self, query: str) -> Response:
        """
        Receive raw query and act upon it somehow.
        Args:
            query (str): query in native format
        Returns:
            HandlerResponse: a TABLE response with the result rows, or an
            ERROR response describing what went wrong.
        """
        need_to_close = not self.is_connected

        # Surface the real connection error instead of failing later on a
        # missing client object.
        connection_status = self.connect()
        if not connection_status.success:
            return Response(RESPONSE_TYPE.ERROR, error_message=connection_status.error_message)

        try:
            execution = self.connection.start_query_execution(
                QueryString=query,
                QueryExecutionContext={
                    'Database': self.connection_data['database'],
                },
                ResultConfiguration={
                    'OutputLocation': self.connection_data['results_output_location'],
                },
                WorkGroup=self.connection_data['workgroup'],
            )
            query_execution_id = execution['QueryExecutionId']
            status = self._wait_for_query_to_complete(query_execution_id)
            if status == 'SUCCEEDED':
                result = self._fetch_query_results(query_execution_id)
                df = self._parse_query_result(result)
                response = Response(RESPONSE_TYPE.TABLE, data_frame=df)
            else:
                response = Response(
                    RESPONSE_TYPE.ERROR,
                    error_message=self._get_failure_message(query_execution_id, status),
                )
        except Exception as e:
            logger.error(f'Error executing query in Athena: {str(e)}')
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))
        finally:
            # Only tear down a connection that this call opened itself.
            if need_to_close:
                self.disconnect()

        return response

    def query(self, query: ASTNode) -> Response:
        """
        Receive query as AST (abstract syntax tree) and act upon it somehow.
        Args:
            query (ASTNode): sql query represented as AST. May be any kind
                of query: SELECT, INSERT, DELETE, etc
        Returns:
            HandlerResponse
        """
        return self.native_query(query.to_string())

    def get_tables(self) -> Response:
        """
        Return list of entities that will be accessible as tables.
        Returns:
            Response: A response object containing the list of tables and
            views (schema, name and type), excluding 'information_schema'.
        """
        query = """
            select
                table_schema,
                table_name,
                table_type
            from
                information_schema.tables
            where
                table_schema not in ('information_schema')
                and table_type in ('BASE TABLE', 'VIEW')
        """
        return self.native_query(query)

    def get_columns(self, table_name: str) -> Response:
        """
        Returns a list of entity columns.
        Args:
            table_name (str): name of one of tables returned by self.get_tables()
        Returns:
            Response: A response object containing the column details
        Raises:
            ValueError: If the 'table_name' is not a valid string.
        """
        if not table_name or not isinstance(table_name, str):
            raise ValueError("Invalid value for table name provided.")

        # The name is embedded in a SQL string literal, so double any single
        # quote to keep it from terminating the literal.
        escaped_table_name = table_name.replace("'", "''")

        query = f"""
            select
                column_name as "Field",
                data_type as "Type"
            from
                information_schema.columns
            where
                table_name = '{escaped_table_name}'
        """
        return self.native_query(query)

    def _get_check_interval(self):
        """
        Read the polling interval (in seconds) from the connection data.
        Returns:
            A positive number when 'check_interval' holds a usable value
            (a number or a numeric string), otherwise 0.
        """
        raw_interval = self.connection_data.get('check_interval', 0)

        if isinstance(raw_interval, str):
            raw_interval = raw_interval.strip()
            if raw_interval.isdigit():
                raw_interval = int(raw_interval)
            else:
                try:
                    raw_interval = float(raw_interval)
                except ValueError:
                    return 0

        # Anything that is not a finite positive number (None, NaN, negative
        # values, other types) means "do not sleep".
        if isinstance(raw_interval, (int, float)) and 0 < raw_interval < float('inf'):
            return raw_interval
        return 0

    def _wait_for_query_to_complete(self, query_execution_id: str) -> str:
        """
        Wait for the Athena query to complete.
        Args:
            query_execution_id (str): ID of the query to wait for
        Returns:
            str: Query execution status ('SUCCEEDED', 'FAILED' or 'CANCELLED')
        """
        # The interval does not change while polling, so resolve it once.
        # NOTE: with the default of 0 the loop polls without any delay.
        check_interval = self._get_check_interval()

        while True:
            response = self.connection.get_query_execution(QueryExecutionId=query_execution_id)
            status = response['QueryExecution']['Status']['State']
            if status in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
                return status

            if check_interval > 0:
                time.sleep(check_interval)

    def _get_failure_message(self, query_execution_id: str, status: str) -> str:
        """
        Build the error message for a query that did not succeed.
        Args:
            query_execution_id (str): ID of the query that failed or was cancelled
            status (str): terminal status reported for the query
        Returns:
            str: the generic failure message, extended with the reason reported
            by Athena when one is available.
        """
        message = 'Query failed or was cancelled'
        try:
            execution = self.connection.get_query_execution(QueryExecutionId=query_execution_id)
            reason = execution['QueryExecution']['Status'].get('StateChangeReason')
        except Exception:
            # The reason is a nicety; never let its lookup mask the failure.
            reason = None

        if isinstance(reason, str) and reason:
            message = f'{message} ({status}): {reason}'
        return message

    def _fetch_query_results(self, query_execution_id: str) -> dict:
        """
        Fetch every page of results for a finished query.
        Args:
            query_execution_id (str): ID of the query whose results are read
        Returns:
            dict: the first 'get_query_results' response with the rows of all
            following pages appended to its 'ResultSet' -> 'Rows' list.
        """
        first_page = self.connection.get_query_results(QueryExecutionId=query_execution_id)
        if not isinstance(first_page, dict):
            return first_page

        next_token = first_page.get('NextToken')
        if not (isinstance(next_token, str) and next_token):
            return first_page

        # Results are paginated; keep following NextToken so large result
        # sets are not silently truncated to the first page.
        first_result_set = first_page.get('ResultSet') or {}
        rows = list(first_result_set.get('Rows') or [])
        while isinstance(next_token, str) and next_token:
            page = self.connection.get_query_results(
                QueryExecutionId=query_execution_id,
                NextToken=next_token,
            )
            rows.extend((page.get('ResultSet') or {}).get('Rows') or [])
            next_token = page.get('NextToken')

        merged_result_set = dict(first_result_set)
        merged_result_set['Rows'] = rows
        merged = dict(first_page)
        merged['ResultSet'] = merged_result_set
        return merged

    @staticmethod
    def _is_header_row(values: list, column_names: list) -> bool:
        """
        Tell whether a row of cell values merely repeats the column names.
        Args:
            values (list): cell values of the row being inspected
            column_names (list): column names taken from the result metadata
        Returns:
            bool: True when both lists have the same length and match
            pairwise, ignoring case.
        """
        if len(values) != len(column_names):
            return False
        return all(
            isinstance(value, str) and isinstance(column, str) and value.lower() == column.lower()
            for value, column in zip(values, column_names)
        )

    def _parse_query_result(self, result: dict) -> pd.DataFrame:
        """
        Parse the result of the Athena query into a DataFrame.
        Args:
            result: Result of the Athena query
        Returns:
            pd.DataFrame: Query result as a DataFrame
        """
        if not result or 'ResultSet' not in result or 'Rows' not in result['ResultSet']:
            return pd.DataFrame()

        result_set = result['ResultSet']
        rows = result_set['Rows']
        column_info = (result_set.get('ResultSetMetadata') or {}).get('ColumnInfo') or []
        column_names = [column.get('Name') for column in column_info]

        if not rows:
            # No header row to read; fall back to the metadata, if any.
            return pd.DataFrame(columns=column_names)

        first_row = [cell.get('VarCharValue') for cell in rows[0].get('Data', [])]

        # When metadata is available, use it to decide whether the first row
        # is a header: not every statement type appears to emit one, and
        # treating a data row as the header would drop it from the result.
        if column_names and not self._is_header_row(first_row, column_names):
            headers = column_names
            data_rows = rows
        else:
            headers = first_row
            data_rows = rows[1:]

        data = [[cell.get('VarCharValue') for cell in row.get('Data', [])] for row in data_rows]
        return pd.DataFrame(data, columns=headers)