Coverage for mindsdb / integrations / handlers / databricks_handler / databricks_handler.py: 64%
137 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Text, Dict, Any, Optional
3import pandas as pd
4from databricks.sql import connect, RequestError, ServerOperationError
5from databricks.sql.client import Connection
6from databricks.sqlalchemy import DatabricksDialect
7from mindsdb_sql_parser.ast.base import ASTNode
9from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender
10from mindsdb.integrations.libs.base import DatabaseHandler
11from mindsdb.integrations.libs.response import (
12 HandlerStatusResponse as StatusResponse,
13 HandlerResponse as Response,
14 RESPONSE_TYPE,
15 INF_SCHEMA_COLUMNS_NAMES_SET,
16)
17from mindsdb.utilities import log
18from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
21logger = log.getLogger(__name__)
def _map_type(internal_type_name: str | None) -> MYSQL_DATA_TYPE:
    """Map Databricks SQL type names to MySQL types as enum.

    Args:
        internal_type_name (str | None): The name of the Databricks type to map.

    Returns:
        MYSQL_DATA_TYPE: The MySQL type enum that corresponds to the Databricks type name.
        Falls back to TEXT for non-string input or unknown type names.
    """
    if not isinstance(internal_type_name, str):
        return MYSQL_DATA_TYPE.TEXT

    # Normalize once instead of calling .upper() in every comparison.
    type_name = internal_type_name.upper()

    # Databricks-specific names that have no identically-named MySQL counterpart.
    aliases = {
        "STRING": MYSQL_DATA_TYPE.TEXT,
        "LONG": MYSQL_DATA_TYPE.BIGINT,
        "SHORT": MYSQL_DATA_TYPE.SMALLINT,
    }
    if type_name in aliases:
        return aliases[type_name]

    try:
        # Enum lookup by value; raises ValueError for unknown names.
        return MYSQL_DATA_TYPE(type_name)
    except ValueError:
        logger.info(f"Databricks handler: unknown type: {internal_type_name}, use TEXT as fallback.")
        return MYSQL_DATA_TYPE.TEXT
class DatabricksHandler(DatabaseHandler):
    """
    This handler handles the connection and execution of SQL statements on Databricks.
    """

    name = "databricks"

    def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs: Any) -> None:
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the Databricks workspace.
            kwargs: Arbitrary keyword arguments.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        # NOTE(review): presumably signals to the caching layer that this handler's
        # cached results may be shared across threads — confirm against the base class.
        self.cache_thread_safe = True

    def __del__(self) -> None:
        """
        Closes the connection when the handler instance is deleted.
        """
        if self.is_connected is True:
            self.disconnect()

    def connect(self) -> Connection:
        """
        Establishes a connection to the Databricks workspace.

        Raises:
            ValueError: If the expected connection parameters are not provided.

        Returns:
            databricks.sql.client.Connection: A connection object to the Databricks workspace.
        """
        # Reuse the existing connection if one is already open.
        if self.is_connected is True:
            return self.connection

        # Mandatory connection parameters.
        if not all(key in self.connection_data for key in ["server_hostname", "http_path", "access_token"]):
            raise ValueError("Required parameters (server_hostname, http_path, access_token) must be provided.")

        config = {
            "server_hostname": self.connection_data["server_hostname"],
            "http_path": self.connection_data["http_path"],
            "access_token": self.connection_data["access_token"],
        }

        # Optional connection parameters, forwarded only when supplied.
        optional_parameters = [
            "session_configuration",
            "http_headers",
            "catalog",
            "schema",
        ]
        for parameter in optional_parameters:
            if parameter in self.connection_data:
                config[parameter] = self.connection_data[parameter]

        try:
            self.connection = connect(**config)
            self.is_connected = True
            return self.connection
        except RequestError as request_error:
            logger.error(f"Request error when connecting to Databricks: {request_error}")
            raise
        except RuntimeError as runtime_error:
            logger.error(f"Runtime error when connecting to Databricks: {runtime_error}")
            raise
        except Exception as unknown_error:
            logger.error(f"Unknown error when connecting to Databricks: {unknown_error}")
            raise

    def disconnect(self):
        """
        Closes the connection to the Databricks workspace if it's currently open.

        Returns:
            bool: False (the new `is_connected` state) after closing; None when
            there was no open connection to close.
        """
        if self.is_connected is False:
            return

        self.connection.close()
        self.is_connected = False
        return self.is_connected

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the Databricks workspace.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        # Remember whether we opened the connection here, so we can close it again.
        need_to_close = self.is_connected is False

        try:
            connection = self.connect()

            # Execute a simple query to check the connection.
            query = "SELECT 1 FROM information_schema.schemata"
            if "schema" in self.connection_data:
                # NOTE(review): schema name is interpolated into the SQL string;
                # it comes from the stored connection config, but parameterizing
                # it would be safer.
                query += f" WHERE schema_name = '{self.connection_data['schema']}'"

            with connection.cursor() as cursor:
                cursor.execute(query)
                result = cursor.fetchall()

                # If the query does not return a result, the schema does not exist.
                # Use .get() because 'schema' is optional: indexing would raise a
                # KeyError here and mask the intended ValueError.
                if not result:
                    raise ValueError(f"The schema {self.connection_data.get('schema')} does not exist!")

                response.success = True
        except (ValueError, RequestError, RuntimeError, ServerOperationError) as known_error:
            logger.error(f"Connection check to Databricks failed, {known_error}!")
            response.error_message = str(known_error)
        except Exception as unknown_error:
            logger.error(f"Connection check to Databricks failed due to an unknown error, {unknown_error}!")
            response.error_message = str(unknown_error)

        # Close the probe connection if we opened it; on failure just flip the flag.
        if response.success and need_to_close:
            self.disconnect()

        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    def native_query(self, query: Text) -> Response:
        """
        Executes a native SQL query on the Databricks workspace and returns the result.

        Args:
            query (Text): The SQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        need_to_close = self.is_connected is False

        connection = self.connect()
        with connection.cursor() as cursor:
            try:
                cursor.execute(query)
                result = cursor.fetchall()
                if result:
                    response = Response(
                        RESPONSE_TYPE.TABLE,
                        data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]),
                    )

                else:
                    # No rows: treat as a statement (e.g. DML/DDL) and commit.
                    response = Response(RESPONSE_TYPE.OK)
                    connection.commit()
            except ServerOperationError as server_error:
                logger.error(f"Server error running query: {query} on Databricks, {server_error}!")
                response = Response(RESPONSE_TYPE.ERROR, error_message=str(server_error))
            except Exception as unknown_error:
                logger.error(f"Unknown error running query: {query} on Databricks, {unknown_error}!")
                response = Response(RESPONSE_TYPE.ERROR, error_message=str(unknown_error))

        if need_to_close is True:
            self.disconnect()

        return response

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode on the Databricks Workspace and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Returns:
            Response: The response from the `native_query` method, containing the result of the SQL query execution.
        """
        renderer = SqlalchemyRender(DatabricksDialect)
        query_str = renderer.get_string(query, with_failback=True)
        return self.native_query(query_str)

    def get_tables(self, all: bool = False) -> Response:
        """
        Retrieves a list of all non-system tables in the connected schema of the Databricks workspace.

        Args:
            all (bool): If True - return tables from all schemas.

        Returns:
            Response: A response object containing a list of tables in the connected schema.
        """
        all_filter = "and table_schema = current_schema()"
        if all is True:
            all_filter = ""
        query = f"""
            SELECT
                table_schema,
                table_name,
                table_type
            FROM
                information_schema.tables
            WHERE
                table_schema != 'information_schema'
                {all_filter}
        """
        result = self.native_query(query)
        # native_query returns OK when there are no rows; normalize to an empty table.
        if result.resp_type == RESPONSE_TYPE.OK:
            result = Response(
                RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))
            )
        return result

    def get_columns(self, table_name: str, schema_name: str | None = None) -> Response:
        """
        Retrieves column details for a specified table in the Databricks workspace.

        Args:
            table_name (str): The name of the table for which to retrieve column information.
            schema_name (str|None): The name of the schema in which the table is located.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object containing the column details.
        """
        if not table_name or not isinstance(table_name, str):
            raise ValueError("Invalid table name provided.")

        if isinstance(schema_name, str):
            schema_name = f"'{schema_name}'"
        else:
            # Fall back to the session's current schema.
            schema_name = "current_schema()"
        # NOTE(review): table_name/schema_name are interpolated into the SQL
        # string; callers are internal, but parameterized queries would be safer.
        query = f"""
            SELECT
                COLUMN_NAME,
                DATA_TYPE,
                ORDINAL_POSITION,
                COLUMN_DEFAULT,
                IS_NULLABLE,
                CHARACTER_MAXIMUM_LENGTH,
                CHARACTER_OCTET_LENGTH,
                NUMERIC_PRECISION,
                NUMERIC_SCALE,
                DATETIME_PRECISION,
                null as CHARACTER_SET_NAME,
                null as COLLATION_NAME
            FROM
                information_schema.columns
            WHERE
                table_name = '{table_name}'
            AND
                table_schema = {schema_name}
        """

        result = self.native_query(query)
        # native_query returns OK when there are no rows; normalize to an empty table.
        if result.resp_type == RESPONSE_TYPE.OK:
            result = Response(
                RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))
            )
        result.to_columns_table_response(map_type_fn=_map_type)

        return result