Coverage for mindsdb / integrations / handlers / aurora_handler / aurora_handler.py: 0%
49 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional
3import boto3
5from mindsdb_sql_parser.ast.base import ASTNode
7from mindsdb.utilities import log
8from mindsdb.integrations.libs.base import DatabaseHandler
9from mindsdb.integrations.libs.response import (
10 HandlerStatusResponse as StatusResponse
11)
12from mindsdb.integrations.handlers.mysql_handler.mysql_handler import MySQLHandler
13from mindsdb.integrations.handlers.postgres_handler.postgres_handler import PostgresHandler
# Module-level logger; AuroraHandler.get_database_engine reports AWS lookup failures through it.
logger = log.getLogger(__name__)
class AuroraHandler(DatabaseHandler):
    """
    This handler handles connection and execution of the Amazon Aurora statements.

    Aurora clusters run either a MySQL- or a PostgreSQL-compatible engine. This
    handler only selects the engine; every operation is then delegated to an inner
    ``MySQLHandler`` or ``PostgresHandler`` stored in ``self.db``.
    """

    name = 'aurora'

    def __init__(self, name: str, connection_data: Optional[dict], **kwargs):
        """
        Initialize the handler.
        Args:
            name (str): name of particular handler instance
            connection_data (dict): parameters for connecting to the database. The
                optional 'db_engine' key ('mysql' or 'postgresql') selects the engine
                explicitly; when it is absent, the engine is looked up through the
                AWS RDS API using 'aws_access_key_id', 'aws_secret_access_key' and 'host'.
            **kwargs: arbitrary keyword arguments.
        Raises:
            Exception: if the engine is neither MySQL nor PostgreSQL, or could not be
                determined.
        """
        super().__init__(name)

        self.dialect = 'aurora'
        # Treat a missing connection_data like an empty one so the engine check below
        # fails with the descriptive error instead of a TypeError on `in None`.
        self.connection_data = connection_data if connection_data is not None else {}
        self.kwargs = kwargs

        # 'db_engine' is optional: read it with .get(). Indexing it directly raised a
        # KeyError whenever the engine had to be auto-detected through AWS.
        db_engine = self.connection_data.get('db_engine')
        database_engine = ""
        if db_engine is None:
            # May return None if the lookup fails; that falls through to the error below.
            database_engine = self.get_database_engine()

        # 'db_engine' is an Aurora-only option, so it is not forwarded to either inner handler.
        inner_connection_data = {
            key: value for key, value in self.connection_data.items() if key != 'db_engine'
        }

        # RDS reports Aurora MySQL as 'aurora-mysql' (current versions) or 'aurora'
        # (legacy MySQL 5.6-compatible clusters); Aurora PostgreSQL as 'aurora-postgresql'.
        if db_engine == 'mysql' or database_engine in ('aurora', 'aurora-mysql'):
            self.db = MySQLHandler(
                name=name + 'mysql',
                connection_data=inner_connection_data
            )
        elif db_engine == 'postgresql' or database_engine == 'aurora-postgresql':
            self.db = PostgresHandler(
                name=name + 'postgresql',
                connection_data=inner_connection_data
            )
        else:
            raise Exception("The database engine should be either MySQL or PostgreSQL!")

    def get_database_engine(self):
        """
        Look up the engine of the Aurora cluster behind ``connection_data['host']``.

        The cluster identifier is taken to be the first DNS label of the host
        (e.g. 'my-cluster' in 'my-cluster.cluster-xxxx.<region>.rds.amazonaws.com').

        Returns:
            str | None: the RDS 'Engine' value of the matching cluster, or None if it
            could not be determined (missing credentials, AWS error, no matching
            cluster). Failures are logged rather than raised.
        """
        try:
            # NOTE(review): no region_name is passed, so boto3's default region
            # resolution (env vars / AWS config) applies — confirm this is intended.
            session = boto3.session.Session(
                aws_access_key_id=self.connection_data['aws_access_key_id'],
                aws_secret_access_key=self.connection_data['aws_secret_access_key']
            )
            rds = session.client('rds')

            cluster_id = self.connection_data['host'].split('.')[0]

            # describe_db_clusters returns a dict whose clusters live under 'DBClusters'
            # and is paginated; walk every page instead of iterating the dict's keys.
            paginator = rds.get_paginator('describe_db_clusters')
            for page in paginator.paginate():
                for cluster in page.get('DBClusters', []):
                    if cluster.get('DBClusterIdentifier') == cluster_id:
                        return cluster['Engine']

            logger.error(f'Error connecting to Aurora, no DB cluster named {cluster_id} was found!')
        except Exception as e:
            logger.error(f'Error connecting to Aurora, {e}!')
            logger.error('If the database engine is not provided as a parameter, please ensure that the credentials for the AWS account are passed in instead!')
        return None

    def __del__(self):
        # __init__ may raise before self.db is assigned (e.g. unknown engine); guard so
        # garbage collection of a half-built handler does not raise AttributeError.
        db = getattr(self, 'db', None)
        if db is not None:
            db.__del__()

    def connect(self) -> StatusResponse:
        """
        Set up the connection required by the handler.
        Returns:
            HandlerStatusResponse
        """

        return self.db.connect()

    def disconnect(self):
        """
        Close any existing connections.
        """

        return self.db.disconnect()

    def check_connection(self) -> StatusResponse:
        """
        Check connection to the handler.
        Returns:
            HandlerStatusResponse
        """

        return self.db.check_connection()

    def native_query(self, query: str) -> StatusResponse:
        """
        Receive raw query and act upon it somehow.
        Args:
            query (str): query in native format
        Returns:
            HandlerResponse
        """

        return self.db.native_query(query)

    def query(self, query: ASTNode) -> StatusResponse:
        """
        Receive query as AST (abstract syntax tree) and act upon it somehow.
        Args:
            query (ASTNode): sql query represented as AST. May be any kind
                of query: SELECT, INSERT, DELETE, etc
        Returns:
            HandlerResponse
        """

        return self.db.query(query)

    def get_tables(self) -> StatusResponse:
        """
        Return list of entities that will be accessible as tables.
        Returns:
            HandlerResponse
        """

        return self.db.get_tables()

    def get_columns(self, table_name: str) -> StatusResponse:
        """
        Returns a list of entity columns.
        Args:
            table_name (str): name of one of tables returned by self.get_tables()
        Returns:
            HandlerResponse
        """

        return self.db.get_columns(table_name)