Coverage for mindsdb / integrations / handlers / bigquery_handler / bigquery_handler.py: 77%

162 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2from typing import Any, Dict, Optional, Text 

3 

4from google.cloud.bigquery import Client, QueryJobConfig, DEFAULT_RETRY 

5from google.api_core.exceptions import BadRequest, NotFound 

6import pandas as pd 

7from sqlalchemy_bigquery.base import BigQueryDialect 

8 

9from mindsdb.utilities import log 

10from mindsdb_sql_parser.ast.base import ASTNode 

11from mindsdb.integrations.libs.base import MetaDatabaseHandler 

12from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

13from mindsdb.integrations.utilities.handlers.auth_utilities.google import ( 

14 GoogleServiceAccountOAuth2Manager, 

15) 

16from mindsdb.integrations.libs.response import ( 

17 HandlerStatusResponse as StatusResponse, 

18 HandlerResponse as Response, 

19 RESPONSE_TYPE, 

20) 

21 

22logger = log.getLogger(__name__) 

23 

24 

class BigQueryHandler(MetaDatabaseHandler):
    """
    This handler handles connection and execution of Google BigQuery statements.
    """

    name = "bigquery"

    def __init__(self, name: Text, connection_data: Dict, **kwargs: Any):
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): Connection parameters; must include
                'project_id' and 'dataset', plus one form of credentials
                ('service_account_json' or 'service_account_keys').
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.client = None
        # `connect` assigns the BigQuery client here; initialize it so the
        # attribute always exists (disconnect/__del__ read it indirectly).
        self.connection = None
        self.is_connected = False

    def __del__(self):
        if self.is_connected is True:
            self.disconnect()

    def connect(self):
        """
        Establishes a connection to a BigQuery warehouse.

        Raises:
            ValueError: If the required connection parameters are not provided or if the credentials cannot be parsed.
            mindsdb.integrations.utilities.handlers.auth_utilities.exceptions.NoCredentialsException: If none of the required forms of credentials are provided.
            mindsdb.integrations.utilities.handlers.auth_utilities.exceptions.AuthException: If authentication fails.

        Returns:
            google.cloud.bigquery.client.Client: The client object for the BigQuery connection.
        """
        if self.is_connected is True:
            return self.connection

        # Mandatory connection parameters
        if not all(key in self.connection_data for key in ["project_id", "dataset"]):
            raise ValueError("Required parameters (project_id, dataset) must be provided.")

        service_account_json = self.connection_data.get("service_account_json")
        if isinstance(service_account_json, str):
            # GUI send it as str
            try:
                service_account_json = json.loads(service_account_json)
            except json.decoder.JSONDecodeError:
                raise ValueError("'service_account_json' is not valid JSON")
        if isinstance(service_account_json, dict) and isinstance(service_account_json.get("private_key"), str):
            # some editors may escape new line symbol, also replace windows-like newlines
            service_account_json["private_key"] = (
                service_account_json["private_key"].replace("\\n", "\n").replace("\r\n", "\n")
            )

        google_sa_oauth2_manager = GoogleServiceAccountOAuth2Manager(
            credentials_file=self.connection_data.get("service_account_keys"),
            credentials_json=service_account_json,
        )
        credentials = google_sa_oauth2_manager.get_oauth2_credentials()

        client = Client(project=self.connection_data["project_id"], credentials=credentials)
        self.is_connected = True
        self.connection = client
        return self.connection

    def disconnect(self):
        """
        Closes the connection to the BigQuery warehouse if it's currently open.
        """
        if self.is_connected is False:
            return
        self.connection.close()
        self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the BigQuery warehouse.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)

        try:
            connection = self.connect()
            # Bound both the job timeout and the retry deadline so the check
            # cannot hang on an unreachable project.
            connection.query("SELECT 1;", timeout=10, retry=DEFAULT_RETRY.with_deadline(10))

            # Check if the dataset exists
            connection.get_dataset(self.connection_data["dataset"])

            response.success = True
        except (BadRequest, ValueError) as e:
            logger.error(f"Error connecting to BigQuery {self.connection_data['project_id']}, {e}!")
            # Store a string (not the exception object) so error_message has a
            # consistent type with the NotFound branch below.
            response.error_message = str(e)
        except NotFound:
            response.error_message = (
                f"Error connecting to BigQuery {self.connection_data['project_id']}: "
                f"dataset '{self.connection_data['dataset']}' not found"
            )

        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def native_query(self, query: str) -> Response:
        """
        Executes a SQL query on the BigQuery warehouse and returns the result.

        Args:
            query (str): The SQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        connection = self.connect()
        try:
            job_config = QueryJobConfig(
                default_dataset=f"{self.connection_data['project_id']}.{self.connection_data['dataset']}"
            )
            # Keep the original SQL string in `query` so the except-branch log
            # below reports the statement text rather than the QueryJob repr.
            query_job = connection.query(query, job_config=job_config)
            result = query_job.to_dataframe()
            if not result.empty:
                response = Response(RESPONSE_TYPE.TABLE, result)
            else:
                # Statements with no rows (DDL/DML, empty SELECT) report plain success.
                response = Response(RESPONSE_TYPE.OK)
        except Exception as e:
            logger.error(f"Error running query: {query} on {self.connection_data['project_id']}!")
            response = Response(RESPONSE_TYPE.ERROR, error_message=str(e))
        return response

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Returns:
            Response: The response from the `native_query` method, containing the result of the SQL query execution.
        """
        renderer = SqlalchemyRender(BigQueryDialect)
        query_str = renderer.get_string(query, with_failback=True)
        return self.native_query(query_str)

    def get_tables(self) -> Response:
        """
        Retrieves a list of all non-system tables and views in the configured dataset of the BigQuery warehouse.

        Returns:
            Response: A response object containing the list of tables and views, formatted as per the `Response` class.
        """
        query = f"""
            SELECT table_name, table_schema, table_type
            FROM `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.TABLES`
            WHERE table_type IN ('BASE TABLE', 'VIEW')
        """
        result = self.native_query(query)
        return result

    def get_columns(self, table_name) -> Response:
        """
        Retrieves column details for a specified table in the configured dataset of the BigQuery warehouse.

        Args:
            table_name (str): The name of the table for which to retrieve column information.

        Returns:
            Response: A response object containing the column details, formatted as per the `Response` class.
        Raises:
            ValueError: If the 'table_name' is not a valid string.
        """
        # NOTE(review): `table_name` is interpolated directly into the SQL; it
        # comes from internal metadata lookups, but a parameterized query would
        # be safer if callers can pass user-controlled names.
        query = f"""
            SELECT column_name AS Field, data_type as Type
            FROM `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.COLUMNS`
            WHERE table_name = '{table_name}'
        """
        result = self.native_query(query)
        return result

    def meta_get_tables(self, table_names: Optional[list] = None) -> Response:
        """
        Retrieves table metadata for the specified tables (or all tables if no list is provided).

        Args:
            table_names (list): A list of table names for which to retrieve metadata information.

        Returns:
            Response: A response object containing the metadata information, formatted as per the `Response` class.
        """
        # Join INFORMATION_SCHEMA.TABLES with the legacy __TABLES__ view to get
        # row counts alongside the table type.
        query = f"""
            SELECT
                t.table_name,
                t.table_schema,
                t.table_type,
                st.row_count
            FROM 
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.TABLES` AS t
            JOIN 
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.__TABLES__` AS st
            ON 
                t.table_name = st.table_id
            WHERE 
                t.table_type IN ('BASE TABLE', 'VIEW')
        """

        if table_names is not None and len(table_names) > 0:
            table_names = [f"'{t}'" for t in table_names]
            query += f" AND t.table_name IN ({','.join(table_names)})"

        result = self.native_query(query)
        return result

    def meta_get_columns(self, table_names: Optional[list] = None) -> Response:
        """
        Retrieves column metadata for the specified tables (or all tables if no list is provided).

        Args:
            table_names (list): A list of table names for which to retrieve column metadata.

        Returns:
            Response: A response object containing the column metadata.
        """
        query = f"""
            SELECT 
                table_name,
                column_name,
                data_type,
                column_default,
                CASE is_nullable
                    WHEN 'YES' THEN TRUE
                    ELSE FALSE
                END AS is_nullable
            FROM 
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.COLUMNS`
        """

        if table_names is not None and len(table_names) > 0:
            table_names = [f"'{t}'" for t in table_names]
            query += f" WHERE table_name IN ({','.join(table_names)})"

        result = self.native_query(query)
        return result

    def meta_get_column_statistics_for_table(self, table_name: str, columns: list) -> Response:
        """
        Retrieves statistics for the specified columns in a table.

        Args:
            table_name (str): The name of the table.
            columns (list): A list of column names to retrieve statistics for.

        Returns:
            Response: A response object containing the column statistics.
        """
        # Check column data types
        column_types_query = f"""
            SELECT column_name, data_type
            FROM `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.COLUMNS`
            WHERE table_name = '{table_name}'
        """
        column_types_result = self.native_query(column_types_query)

        if column_types_result.resp_type != RESPONSE_TYPE.TABLE:
            logger.error(f"Error retrieving column types for table {table_name}")
            return Response(
                RESPONSE_TYPE.ERROR,
                error_message=f"Could not retrieve column types for table {table_name}",
            )

        column_type_map = dict(
            zip(
                column_types_result.data_frame["column_name"],
                column_types_result.data_frame["data_type"],
            )
        )

        # Types that don't support MIN/MAX aggregations
        UNSUPPORTED_MINMAX_PREFIXES = ("ARRAY", "STRUCT", "RECORD")
        UNSUPPORTED_MINMAX_TYPES = ("GEOGRAPHY", "JSON", "BYTES")

        def supports_minmax(data_type: str) -> bool:
            """Check if a BigQuery data type supports MIN/MAX operations."""
            if data_type is None:
                return False
            data_type_upper = data_type.upper()
            # Parameterized complex types come back as e.g. 'ARRAY<INT64>',
            # hence the prefix check rather than exact match.
            if any(data_type_upper.startswith(prefix) for prefix in UNSUPPORTED_MINMAX_PREFIXES):
                return False
            if data_type_upper in UNSUPPORTED_MINMAX_TYPES:
                return False
            return True

        # To avoid hitting BigQuery's query size limits, we will chunk the columns into batches.
        BATCH_SIZE = 20

        def chunked(lst, n):
            """Yields successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i : i + n]

        queries = []
        for column_batch in chunked(columns, BATCH_SIZE):
            batch_queries = []
            for column in column_batch:
                data_type = column_type_map.get(column)

                if supports_minmax(data_type):
                    # Full statistics for supported types
                    batch_queries.append(
                        f"""
                        SELECT
                            '{table_name}' AS table_name,
                            '{column}' AS column_name,
                            SAFE_DIVIDE(COUNTIF(`{column}` IS NULL), COUNT(*)) * 100 AS null_percentage,
                            CAST(MIN(`{column}`) AS STRING) AS minimum_value,
                            CAST(MAX(`{column}`) AS STRING) AS maximum_value,
                            COUNT(DISTINCT `{column}`) AS distinct_values_count
                        FROM
                            `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.{table_name}`
                        """
                    )
                else:
                    # Limited statistics for complex types (no MIN/MAX/COUNT DISTINCT)
                    logger.info(f"Skipping MIN/MAX for column {column} with unsupported type: {data_type}")
                    batch_queries.append(
                        f"""
                        SELECT
                            '{table_name}' AS table_name,
                            '{column}' AS column_name,
                            SAFE_DIVIDE(COUNTIF(`{column}` IS NULL), COUNT(*)) * 100 AS null_percentage,
                            CAST(NULL AS STRING) AS minimum_value,
                            CAST(NULL AS STRING) AS maximum_value,
                            CAST(NULL AS INT64) AS distinct_values_count
                        FROM
                            `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.{table_name}`
                        """
                    )

            if batch_queries:
                query = " UNION ALL ".join(batch_queries)
                queries.append(query)

        # Run each batch independently; a failed batch is logged and skipped so
        # partial statistics are still returned.
        results = []
        for query in queries:
            try:
                result = self.native_query(query)
                if result.resp_type == RESPONSE_TYPE.TABLE:
                    results.append(result.data_frame)
                else:
                    logger.error(f"Error retrieving column statistics for table {table_name}: {result.error_message}")
            except Exception as e:
                logger.error(f"Exception occurred while retrieving column statistics for table {table_name}: {e}")

        if not results:
            logger.warning(f"No column statistics could be retrieved for table {table_name}.")
            return Response(
                RESPONSE_TYPE.ERROR,
                error_message=f"No column statistics could be retrieved for table {table_name}.",
            )
        return Response(
            RESPONSE_TYPE.TABLE,
            pd.concat(results, ignore_index=True) if results else pd.DataFrame(),
        )

    def meta_get_primary_keys(self, table_names: Optional[list] = None) -> Response:
        """
        Retrieves primary key information for the specified tables (or all tables if no list is provided).

        Args:
            table_names (list): A list of table names for which to retrieve primary key information.

        Returns:
            Response: A response object containing the primary key information.
        """
        # NOTE: the trailing comma before FROM is valid BigQuery Standard SQL.
        query = f"""
            SELECT
                tc.table_name,
                kcu.column_name,
                kcu.ordinal_position,
                tc.constraint_name,
            FROM
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` AS tc
            JOIN
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE` AS kcu
            ON
                tc.constraint_name = kcu.constraint_name
            WHERE
                tc.constraint_type = 'PRIMARY KEY'
        """

        if table_names is not None and len(table_names) > 0:
            table_names = [f"'{t}'" for t in table_names]
            query += f" AND tc.table_name IN ({','.join(table_names)})"

        result = self.native_query(query)
        return result

    def meta_get_foreign_keys(self, table_names: Optional[list] = None) -> Response:
        """
        Retrieves foreign key information for the specified tables (or all tables if no list is provided).

        Args:
            table_names (list): A list of table names for which to retrieve foreign key information.

        Returns:
            Response: A response object containing the foreign key information.
        """
        query = f"""
            SELECT
                ccu.table_name AS parent_table_name,
                ccu.column_name AS parent_column_name,
                kcu.table_name AS child_table_name,
                kcu.column_name AS child_column_name,
                tc.constraint_name
            FROM
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` AS tc
            JOIN
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE` AS kcu
            ON
                tc.constraint_name = kcu.constraint_name
            JOIN
                `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE` AS ccu
            ON
                tc.constraint_name = ccu.constraint_name
            WHERE
                tc.constraint_type = 'FOREIGN KEY'
        """

        if table_names is not None and len(table_names) > 0:
            table_names = [f"'{t}'" for t in table_names]
            query += f" AND tc.table_name IN ({','.join(table_names)})"

        result = self.native_query(query)
        return result