Coverage for mindsdb / integrations / handlers / cloud_spanner_handler / cloud_spanner_handler.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2 

3from google.oauth2 import service_account 

4from google.cloud.spanner_dbapi.connection import connect, Connection 

5from google.cloud.sqlalchemy_spanner import SpannerDialect 

6 

7import pandas as pd 

8from mindsdb_sql_parser import parse_sql 

9from mindsdb_sql_parser.ast.base import ASTNode 

10from mindsdb_sql_parser.ast import CreateTable, Function 

11from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

12 

13from mindsdb.integrations.libs.base import DatabaseHandler 

14from mindsdb.integrations.libs.response import RESPONSE_TYPE 

15from mindsdb.integrations.libs.response import HandlerResponse as Response 

16from mindsdb.integrations.libs.response import ( 

17 HandlerStatusResponse as StatusResponse, 

18) 

19from mindsdb.utilities import log 

20 

# Module-level logger for this handler, named after the module.
logger = log.getLogger(__name__)

22 

23 

class CloudSpannerHandler(DatabaseHandler):
    """This handler handles connection and execution of the Cloud Spanner statements."""

    name = 'cloud_spanner'

    def __init__(self, name: str, **kwargs):
        """Initialise the handler.

        Args:
            name (str): Name of the handler instance.
            **kwargs: Expected to contain ``connection_data`` (dict) with the keys
                ``database_id``, ``instance_id``, ``project``, ``credentials`` and an
                optional ``dialect`` (``'googlesql'`` by default, or ``'postgres'``).
        """
        super().__init__(name)
        self.parser = parse_sql
        # Tolerate a missing/None connection_data so construction never crashes;
        # problems with the settings surface later from connect().
        self.connection_data = kwargs.get('connection_data') or {}
        self.dialect = self.connection_data.get('dialect', 'googlesql')

        # Render ASTs with the SQL flavour matching the database dialect.
        if self.dialect == 'postgres':
            self.renderer = SqlalchemyRender('postgres')
        else:
            self.renderer = SqlalchemyRender(SpannerDialect)

        self.connection = None
        self.is_connected = False

    def __del__(self):
        # getattr guards against __init__ having failed before is_connected was set.
        if getattr(self, 'is_connected', False) is True:
            self.disconnect()

    def connect(self) -> Connection:
        """Connect to a Cloud Spanner database.

        Credentials must be service-account info, either as a JSON string or as an
        already-parsed dict. They are never taken from the environment.

        Returns:
            Connection: The database connection.

        Raises:
            ValueError: If no credentials are configured.
        """

        if self.is_connected is True:
            return self.connection

        credentials = self.connection_data.get('credentials')
        if credentials is None:
            # Refuse to fall back to ambient/default cloud credentials of the host.
            raise ValueError(
                'Cloud Spanner connection requires "credentials" (service account info).'
            )
        if isinstance(credentials, str):
            credentials = json.loads(credentials)

        args = {
            'database_id': self.connection_data.get('database_id'),
            'instance_id': self.connection_data.get('instance_id'),
            'project': self.connection_data.get('project'),
            'credentials': service_account.Credentials.from_service_account_info(
                credentials
            ),
        }
        self.connection = connect(**args)
        self.is_connected = True

        return self.connection

    def disconnect(self):
        """Close the database connection."""

        if self.is_connected is False:
            return

        try:
            self.connection.close()
        finally:
            # Mark as disconnected even if close() raised, so a broken
            # connection object is never handed out again by connect().
            self.connection = None
            self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Check the connection to the Cloud Spanner database.

        Returns:
            StatusResponse: Connection success status and error message if an error occurs.
        """

        response = StatusResponse(False)

        try:
            self.connect()
            response.success = True
        except Exception as e:
            # .get() so a missing database_id cannot mask the original error.
            logger.error(
                f'Error connecting to Cloud Spanner {self.connection_data.get("database_id")}, {e}!'
            )
            response.error_message = str(e)
        finally:
            # A connection check must not leave a connection open.
            if self.is_connected:
                self.disconnect()

        return response

    def native_query(self, query: str) -> Response:
        """Execute a SQL query.

        A connection is opened for the query and always closed afterwards.

        Args:
            query (str): The SQL query to execute.

        Returns:
            Response: A TABLE response when the statement produced rows, OK when
                it did not, or ERROR (with the error message) on any failure,
                including failure to connect.
        """

        cursor = None
        try:
            connection = self.connect()
            cursor = connection.cursor()
            cursor.execute(query)

            # The cursor description check indicates if there are any results.
            # This is required as spanner_dbapi will fail on a fetchall() call on an empty cursor.
            if cursor.description:
                columns = [column[0] for column in cursor.description]
                response = Response(
                    RESPONSE_TYPE.TABLE,
                    data_frame=pd.DataFrame(cursor.fetchall(), columns=columns),
                )
            else:
                response = Response(RESPONSE_TYPE.OK)

            connection.commit()
        except Exception as e:
            logger.error(
                f'Error running query: {query} on {self.connection_data.get("database_id")}, {e}!'
            )
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))
        finally:
            # Cleanup runs on every path, including connect/cursor failures.
            if cursor is not None:
                cursor.close()
            if self.is_connected:
                self.disconnect()

        return response

    def query(self, query: ASTNode) -> Response:
        """Render and execute a SQL query.

        For CREATE TABLE statements without a primary key, a column named ``id``
        (case-insensitive) is promoted to primary key with a generated default.

        Args:
            query (ASTNode): The SQL query.

        Returns:
            Response: The query result.
        """

        # check primary key for table:
        if isinstance(query, CreateTable) and query.columns is not None:
            id_col = None
            has_primary = False
            for col in query.columns:
                if col.name.lower() == 'id':
                    id_col = col
                if col.is_primary_key:
                    has_primary = True
            # if no other primary keys use id
            if not has_primary and id_col:
                id_col.is_primary_key = True
                # NOTE(review): GENERATE_UUID is the GoogleSQL function name;
                # confirm how this renders/behaves for the PostgreSQL dialect.
                id_col.default = Function('GENERATE_UUID', args=[])

        query_str = self.renderer.get_string(query, with_failback=True)

        return self.native_query(query_str)

    def _default_schema(self) -> str:
        """Return the schema name that holds user tables for the configured dialect."""
        # GoogleSQL databases report the default schema as an empty string;
        # PostgreSQL-dialect databases report it as 'public'.
        return 'public' if self.dialect == 'postgres' else ''

    def _quote_literal(self, value: str) -> str:
        """Render ``value`` as a safely escaped single-quoted SQL string literal."""
        text = str(value)
        if self.dialect == 'postgres':
            # PostgreSQL standard strings escape a quote by doubling it.
            # NOTE(review): assumes backslashes are literal (standard_conforming_strings) — confirm.
            return "'" + text.replace("'", "''") + "'"
        # GoogleSQL string literals use backslash escapes.
        return "'" + text.replace('\\', '\\\\').replace("'", "\\'") + "'"

    def get_tables(self) -> Response:
        """Get a list of all the tables in the database's default schema.

        Returns:
            Response: Names of the tables in the database, in a ``table_name`` column.
        """

        query = f'''
            SELECT
                t.table_name
            FROM
                information_schema.tables AS t
            WHERE
                t.table_schema = {self._quote_literal(self._default_schema())}
        '''
        result = self.native_query(query)
        df = result.data_frame

        if df is not None:
            result.data_frame = df.rename(columns={df.columns[0]: 'table_name'})

        return result

    def get_columns(self, table_name: str) -> Response:
        """Get details about a table.

        Args:
            table_name (str): Name of the table to retrieve details of. It is
                escaped before being embedded in the SQL text.

        Returns:
            Response: Column name, Spanner type and nullability for each column,
                in column-definition order.
        """

        query = f'''
            SELECT
                t.column_name,
                t.spanner_type,
                t.is_nullable
            FROM
                information_schema.columns AS t
            WHERE
                t.table_name = {self._quote_literal(table_name)}
                AND t.table_schema = {self._quote_literal(self._default_schema())}
            ORDER BY
                t.ordinal_position
        '''
        return self.native_query(query)