Coverage for mindsdb / integrations / handlers / gcs_handler / gcs_handler.py: 0%

170 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from contextlib import contextmanager 

2 

3import json 

4import duckdb 

5import pandas as pd 

6import fsspec 

7import google.auth 

8from google.cloud import storage 

9from typing import Text, Dict, Optional, List 

10from duckdb import DuckDBPyConnection 

11 

12from mindsdb.integrations.handlers.gcs_handler.gcs_tables import ( 

13 ListFilesTable, 

14 FileTable 

15) 

16from mindsdb_sql_parser.ast.base import ASTNode 

17from mindsdb_sql_parser.ast import Select, Identifier, Insert, Star, Constant 

18 

19from mindsdb.utilities import log 

20from mindsdb.integrations.libs.response import ( 

21 HandlerStatusResponse as StatusResponse, 

22 HandlerResponse as Response, 

23 RESPONSE_TYPE 

24) 

25 

26from mindsdb.integrations.libs.api_handler import APIHandler 

27 

# Module-level logger scoped to this handler's module name.
logger = log.getLogger(__name__)

29 

30 

class GcsHandler(APIHandler):
    """
    This handler handles connection and execution of the SQL statements on
    Google Cloud Storage (GCS).
    """

    name = 'gcs'

    # Object extensions that can be queried as tables (read via DuckDB).
    supported_file_formats = ['csv', 'tsv', 'json', 'parquet']

    def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs):
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the GCS account.
                Must contain either 'service_account_keys' (path to a key file) or
                'service_account_json' (the key contents as a dict). May also contain
                'bucket' to restrict the handler to a single bucket.
            kwargs: Arbitrary keyword arguments.

        Raises:
            ValueError: If neither 'service_account_keys' nor 'service_account_json' is provided.
        """
        super().__init__(name)
        # Normalize None to {} so the membership checks below raise the intended
        # ValueError instead of a TypeError.
        self.connection_data = connection_data or {}
        self.kwargs = kwargs
        self.is_select_query = False
        self.service_account_json = None
        self.connection = None

        if 'service_account_keys' not in self.connection_data and 'service_account_json' not in self.connection_data:
            raise ValueError('service_account_keys or service_account_json parameter must be provided.')

        if 'service_account_json' in self.connection_data:
            self.service_account_json = self.connection_data["service_account_json"]

        # A key file path takes precedence over inline JSON when both are given.
        if 'service_account_keys' in self.connection_data:
            with open(self.connection_data["service_account_keys"], "r") as f:
                self.service_account_json = json.load(f)

        self.is_connected = False

        # Virtual `files` table listing the objects in the bucket(s).
        self._files_table = ListFilesTable(self)

    def __del__(self):
        # getattr guard: __init__ may have raised before is_connected was set,
        # in which case a plain attribute access would raise again during GC.
        if getattr(self, 'is_connected', False) is True:
            self.disconnect()

    def connect(self) -> storage.Client:
        """
        Establishes a connection to the GCS account via google-cloud-storage.

        DuckDB connections are not held here; they are created per-use through
        the `_connect_duckdb` context manager.

        Returns:
            storage.Client: A client object to the GCS account.
        """
        if self.is_connected is True:
            return self.connection

        # Connect to GCS and configure mandatory credentials.
        self.connection = self._connect_storage_client()
        self.is_connected = True

        return self.connection

    @contextmanager
    def _connect_duckdb(self):
        """
        Creates a temporary in-memory DuckDB database able to read from the GCS
        account. Must be used as a context manager; the connection is closed on exit.

        Yields:
            DuckDBPyConnection
        """
        # Configure mandatory credentials and expose GCS to DuckDB as a filesystem.
        credentials, project_id = google.auth.load_credentials_from_dict(self.service_account_json)
        gcs = fsspec.filesystem("gcs", project=project_id, credentials=credentials)
        # Bug fix: the original opened duckdb.connect(":memory:") and then immediately
        # re-assigned duckdb_conn = duckdb.connect(), leaking the first connection.
        duckdb_conn = duckdb.connect(":memory:")
        duckdb_conn.register_filesystem(gcs)

        try:
            yield duckdb_conn
        finally:
            duckdb_conn.close()

    def _connect_storage_client(self) -> storage.Client:
        """
        Establishes a connection to the GCS account via google-cloud-storage.

        Returns:
            storage.Client: A client object to the GCS account.
        """
        return storage.Client.from_service_account_info(self.service_account_json)

    def disconnect(self):
        """
        Closes the connection to the GCP account if it's currently open.
        """
        if not self.is_connected:
            return
        self.connection.close()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the GCS bucket.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        # Check connection via a short-lived storage client.
        try:
            storage_client = self._connect_storage_client()
            if 'bucket' in self.connection_data:
                # A configured bucket must exist and be accessible.
                storage_client.get_bucket(self.connection_data['bucket'])
            else:
                storage_client.list_buckets()
            response.success = True
            storage_client.close()
        except Exception as e:
            logger.error(f'Error connecting to GCS with the given credentials, {e}!')
            response.error_message = str(e)

        if response.success and need_to_close:
            self.disconnect()

        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    def _get_bucket(self, key):
        """
        Resolves (bucket, object-key) for a given key.

        If a bucket is configured on the connection it is used directly;
        otherwise the first path segment of `key` is treated as the bucket.
        """
        if 'bucket' in self.connection_data:
            return self.connection_data['bucket'], key

        # get bucket from first part of the key
        ar = key.split('/')
        return ar[0], '/'.join(ar[1:])

    def read_as_table(self, key) -> pd.DataFrame:
        """
        Read object as dataframe. Uses duckdb.
        """

        bucket, key = self._get_bucket(key)

        with self._connect_duckdb() as connection:

            cursor = connection.execute(f"SELECT * FROM 'gs://{bucket}/{key}'")

            return cursor.fetchdf()

    def _read_as_content(self, key) -> bytes:
        """
        Read object as raw content (bytes), via the storage client.
        """
        bucket, key = self._get_bucket(key)

        client = self.connect()

        bucket = client.bucket(bucket)
        blob = bucket.blob(key)
        return blob.download_as_string()

    def add_data_to_table(self, key, df) -> None:
        """
        Writes the table to a file in the gcs bucket.

        Raises:
            Exception: If the target object does not exist in the bucket.
            CatalogException: If the table does not exist in the DuckDB connection.
        """

        # Check if the file exists in the gcs bucket.
        bucket, key = self._get_bucket(key)

        storage_client = self._connect_storage_client()
        bucket_obj = storage_client.bucket(bucket)
        exists = storage.Blob(bucket=bucket_obj, name=key).exists(storage_client)
        storage_client.close()
        if not exists:
            raise Exception(f'Error querying the file {key} in the bucket {bucket}!')

        with self._connect_duckdb() as connection:
            # copy the existing object into a scratch table
            connection.execute(f"CREATE TABLE tmp_table AS SELECT * FROM 'gs://{bucket}/{key}'")

            # insert new rows (df is picked up by DuckDB's replacement scan)
            connection.execute("INSERT INTO tmp_table BY NAME SELECT * FROM df")

            # upload the merged table back to the same object
            connection.execute(f"COPY tmp_table TO 'gs://{bucket}/{key}'")

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Raises:
            ValueError: If the file format is not supported or the file does not exist in the GCS bucket.
            NotImplementedError: For statements other than SELECT and INSERT.

        Returns:
            Response: The result of the SQL query execution.
        """

        self.connect()

        if isinstance(query, Select):
            table_name = query.from_table.parts[-1]

            if table_name == 'files':
                table = self._files_table
                df = table.select(query)

                # Fetch object contents only when the `content` column is requested.
                has_content = False
                for target in query.targets:
                    if isinstance(target, Identifier) and target.parts[-1].lower() == 'content':
                        has_content = True
                        break
                if has_content:
                    df['content'] = df['path'].apply(self._read_as_content)
            else:
                extension = table_name.split('.')[-1]
                if extension not in self.supported_file_formats:
                    logger.error(f'The file format {extension} is not supported!')
                    raise ValueError(f'The file format {extension} is not supported!')

                table = FileTable(self, table_name=table_name)
                df = table.select(query)

            response = Response(
                RESPONSE_TYPE.TABLE,
                data_frame=df
            )
        elif isinstance(query, Insert):
            table_name = query.table.parts[-1]
            table = FileTable(self, table_name=table_name)
            table.insert(query)
            response = Response(RESPONSE_TYPE.OK)
        else:
            raise NotImplementedError

        return response

    def get_objects(self, limit=None, buckets=None) -> List[dict]:
        """
        Lists STANDARD-class objects in the configured bucket, or in all
        accessible buckets when no bucket is configured.

        Args:
            limit (int): Maximum number of objects to return (across all buckets).
            buckets (list): If given, only buckets with these names are scanned.

        Returns:
            List[dict]: One dict per object with 'Bucket' and 'Key' entries.
        """
        storage_client = self._connect_storage_client()
        if "bucket" in self.connection_data:
            add_bucket_to_name = False
            scan_buckets = [self.connection_data["bucket"]]
        else:
            add_bucket_to_name = True
            scan_buckets = [b.name for b in storage_client.list_buckets()]

        objects = []
        for bucket in scan_buckets:
            if buckets is not None and bucket not in buckets:
                continue

            for blob in storage_client.list_blobs(bucket):
                # Archival storage classes are not directly queryable; skip them.
                if blob.storage_class != 'STANDARD':
                    continue

                obj = {'Bucket': bucket}
                if add_bucket_to_name:
                    # bucket is part of the name
                    obj['Key'] = f'{bucket}/{blob.name}'
                else:
                    # Bug fix: 'Key' was previously left unset in the single-bucket
                    # configuration, making get_tables() fail with a KeyError.
                    obj['Key'] = blob.name
                objects.append(obj)
                if limit is not None and len(objects) >= limit:
                    # Bug fix: stop scanning entirely; the original `break` only
                    # exited the inner loop, so the limit could be exceeded.
                    return objects

        return objects

    def get_tables(self) -> Response:
        """
        Retrieves a list of tables (objects) in the gcs bucket.

        Each object is considered a table. Only the supported file formats are considered as tables.

        Returns:
            Response: A response object containing the list of tables and views, formatted as per the `Response` class.
        """

        # Get only the supported file formats.
        # Wrap the object names with backticks to prevent SQL syntax errors.
        supported_names = [
            f"`{obj['Key']}`"
            for obj in self.get_objects()
            if obj['Key'].split('.')[-1] in self.supported_file_formats
        ]

        # virtual table with list of files
        supported_names.insert(0, 'files')

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                supported_names,
                columns=['table_name']
            )
        )

        return response

    def get_columns(self, table_name: str) -> Response:
        """
        Retrieves column details for a specified table (object) in the gcs bucket.

        Args:
            table_name (Text): The name of the table for which to retrieve column information.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object containing the column details, formatted as per the `Response` class.
        """
        # Select a single row; the dtypes of the result describe the columns.
        query = Select(
            targets=[Star()],
            from_table=Identifier(parts=[table_name]),
            limit=Constant(1)
        )

        result = self.query(query)

        response = Response(
            RESPONSE_TYPE.TABLE,
            data_frame=pd.DataFrame(
                {
                    'column_name': result.data_frame.columns,
                    # pandas reports strings as 'object'; normalize for callers.
                    'data_type': [data_type if data_type != 'object' else 'string' for data_type in result.data_frame.dtypes]
                }
            )
        )

        return response