Coverage for mindsdb / integrations / handlers / cloud_spanner_handler / cloud_spanner_handler.py: 0%
97 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
3from google.oauth2 import service_account
4from google.cloud.spanner_dbapi.connection import connect, Connection
5from google.cloud.sqlalchemy_spanner import SpannerDialect
7import pandas as pd
8from mindsdb_sql_parser import parse_sql
9from mindsdb_sql_parser.ast.base import ASTNode
10from mindsdb_sql_parser.ast import CreateTable, Function
11from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender
13from mindsdb.integrations.libs.base import DatabaseHandler
14from mindsdb.integrations.libs.response import RESPONSE_TYPE
15from mindsdb.integrations.libs.response import HandlerResponse as Response
16from mindsdb.integrations.libs.response import (
17 HandlerStatusResponse as StatusResponse,
18)
19from mindsdb.utilities import log
# Module-level logger used by the handler class below to report connection and query errors.
logger = log.getLogger(__name__)
class CloudSpannerHandler(DatabaseHandler):
    """This handler handles connection and execution of the Cloud Spanner statements."""

    name = 'cloud_spanner'

    def __init__(self, name: str, **kwargs):
        """Initialize the handler.

        Args:
            name (str): Name of this handler instance.
            **kwargs: Handler arguments. ``connection_data`` (dict) holds the
                connection parameters read by this class: ``project``,
                ``instance_id``, ``database_id``, ``credentials`` and the
                optional ``dialect`` (``'googlesql'`` by default, or ``'postgres'``).
        """
        super().__init__(name)
        self.parser = parse_sql
        # Fall back to an empty dict so a missing `connection_data` does not crash
        # here with an AttributeError on `None.get`.
        self.connection_data = kwargs.get('connection_data') or {}
        self.dialect = self.connection_data.get('dialect', 'googlesql')

        if self.dialect == 'postgres':
            self.renderer = SqlalchemyRender('postgres')
        else:
            self.renderer = SqlalchemyRender(SpannerDialect)

        self.connection = None
        self.is_connected = False

    def __del__(self):
        # `is_connected` may not exist if __init__ raised before assigning it.
        if getattr(self, 'is_connected', False) is True:
            self.disconnect()

    def connect(self) -> Connection:
        """Connect to a Cloud Spanner database.

        The ``credentials`` connection parameter may be either the JSON text of a
        service account key or an already-parsed dict with the same content.

        Returns:
            Connection: The database connection (reused if already open).

        Raises:
            ValueError: If the ``credentials`` connection parameter is missing.
        """
        if self.is_connected is True:
            return self.connection

        credentials_info = self.connection_data.get('credentials')
        if credentials_info is None:
            raise ValueError("The 'credentials' connection parameter is required.")
        # Accept raw JSON text as well as a dict that has already been parsed.
        if not isinstance(credentials_info, dict):
            credentials_info = json.loads(credentials_info)

        self.connection = connect(
            database_id=self.connection_data.get('database_id'),
            instance_id=self.connection_data.get('instance_id'),
            project=self.connection_data.get('project'),
            credentials=service_account.Credentials.from_service_account_info(
                credentials_info
            ),
        )
        self.is_connected = True

        return self.connection

    def disconnect(self):
        """Close the database connection, if one is open."""
        if self.is_connected is False:
            return

        try:
            self.connection.close()
        finally:
            # Reset state even if close() raised, so the handler does not keep
            # reporting a connection that is no longer usable.
            self.connection = None
            self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Check the connection to the Cloud Spanner database.

        A successful check closes the connection afterwards.

        Returns:
            StatusResponse: Connection success status and error message if an error occurs.
        """
        response = StatusResponse(False)

        try:
            self.connect()
            response.success = True
        except Exception as e:
            # Use .get() so a missing key cannot raise inside this handler and mask `e`.
            logger.error(
                f'Error connecting to Cloud Spanner {self.connection_data.get("database_id")}, {e}!'
            )
            response.error_message = str(e)
        finally:
            if response.success is True and self.is_connected:
                self.disconnect()

        return response

    def native_query(self, query: str) -> Response:
        """Execute a SQL query.

        If the handler was not connected before the call, the connection opened
        here is closed afterwards. A connection opened earlier by the caller is
        left open.

        Args:
            query (str): The SQL query to execute.

        Returns:
            Response: A TABLE response for statements that produce rows, OK for
            statements that do not, or ERROR carrying the exception message.
        """
        need_to_close = self.is_connected is False

        connection = self.connect()
        cursor = connection.cursor()

        try:
            cursor.execute(query)

            # The cursor description check indicates if there are any results.
            # This is required as spanner_dbapi will fail on a fetchall() call on an empty cursor.
            if cursor.description:
                rows = cursor.fetchall()
                column_names = [column[0] for column in cursor.description]
                response = Response(
                    RESPONSE_TYPE.TABLE,
                    data_frame=pd.DataFrame(rows, columns=column_names),
                )
            else:
                response = Response(RESPONSE_TYPE.OK)

            connection.commit()
        except Exception as e:
            logger.error(
                f'Error running query: {query} on {self.connection_data.get("database_id")}: {e}!'
            )
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))
            if not need_to_close:
                # The connection stays open for the caller; do not leave a
                # failed transaction on it.
                self._rollback_quietly(connection)
        finally:
            cursor.close()
            if need_to_close and self.is_connected:
                self.disconnect()

        return response

    @staticmethod
    def _rollback_quietly(connection: Connection) -> None:
        """Best-effort rollback; failures are logged, not raised, so the original error wins."""
        try:
            connection.rollback()
        except Exception as rollback_error:
            logger.warning(f'Error rolling back Cloud Spanner transaction: {rollback_error}')

    def query(self, query: ASTNode) -> Response:
        """Render and execute a SQL query.

        For ``CREATE TABLE`` statements that declare no primary key, a column
        named ``id`` (case-insensitive) is promoted to the primary key and given
        a ``GENERATE_UUID()`` default.

        Args:
            query (ASTNode): The SQL query.

        Returns:
            Response: The query result.
        """
        # check primary key for table:
        if isinstance(query, CreateTable) and query.columns is not None:
            id_col = None
            has_primary = False
            for col in query.columns:
                if col.name.lower() == 'id':
                    id_col = col
                if col.is_primary_key:
                    has_primary = True
            # if no other primary keys use id
            # NOTE(review): GENERATE_UUID is the GoogleSQL spelling; confirm it
            # renders to a valid function for the 'postgres' dialect.
            if not has_primary and id_col:
                id_col.is_primary_key = True
                id_col.default = Function('GENERATE_UUID', args=[])

        query_str = self.renderer.get_string(query, with_failback=True)

        return self.native_query(query_str)

    def _default_schema(self) -> str:
        """Name of the default schema as reported in information_schema.

        GoogleSQL databases report an empty string; PostgreSQL-dialect
        databases report 'public'.
        """
        return 'public' if self.dialect == 'postgres' else ''

    def _quote_literal(self, value: str) -> str:
        """Render `value` as a SQL string literal for the configured dialect.

        PostgreSQL doubles single quotes; GoogleSQL uses backslash escapes.
        """
        if self.dialect == 'postgres':
            return "'" + value.replace("'", "''") + "'"
        return "'" + value.replace('\\', '\\\\').replace("'", "\\'") + "'"

    def get_tables(self) -> Response:
        """Get a list of all the tables in the default schema of the database.

        Returns:
            Response: Names of the tables, in a single ``table_name`` column.
        """
        query = f'''
            SELECT
                t.table_name
            FROM
                information_schema.tables AS t
            WHERE
                t.table_schema = {self._quote_literal(self._default_schema())}
        '''
        result = self.native_query(query)
        df = result.data_frame

        if df is not None:
            result.data_frame = df.rename(columns={df.columns[0]: 'table_name'})

        return result

    def get_columns(self, table_name: str) -> Response:
        """Get details about a table in the default schema.

        Args:
            table_name (str): Name of the table to retrieve details of.

        Returns:
            Response: Column name, Spanner type and nullability, in column order.
        """
        # Escape the user-supplied name rather than interpolating it raw into SQL.
        query = f'''
            SELECT
                t.column_name,
                t.spanner_type,
                t.is_nullable
            FROM
                information_schema.columns AS t
            WHERE
                t.table_name = {self._quote_literal(table_name)}
                AND t.table_schema = {self._quote_literal(self._default_schema())}
            ORDER BY
                t.ordinal_position
        '''
        return self.native_query(query)