Coverage for mindsdb / integrations / handlers / databricks_handler / databricks_handler.py: 64%

137 statements  

Report generated by coverage.py v7.13.1 at 2026-01-21 00:36 +0000.

1from typing import Text, Dict, Any, Optional 

2 

3import pandas as pd 

4from databricks.sql import connect, RequestError, ServerOperationError 

5from databricks.sql.client import Connection 

6from databricks.sqlalchemy import DatabricksDialect 

7from mindsdb_sql_parser.ast.base import ASTNode 

8 

9from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender 

10from mindsdb.integrations.libs.base import DatabaseHandler 

11from mindsdb.integrations.libs.response import ( 

12 HandlerStatusResponse as StatusResponse, 

13 HandlerResponse as Response, 

14 RESPONSE_TYPE, 

15 INF_SCHEMA_COLUMNS_NAMES_SET, 

16) 

17from mindsdb.utilities import log 

18from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE 

19 

20 

# Module-level logger scoped to this handler module.
logger = log.getLogger(__name__)

22 

23 

def _map_type(internal_type_name: str | None) -> MYSQL_DATA_TYPE:
    """Map Databricks SQL type names to MySQL types as enum.

    Args:
        internal_type_name (str | None): The name of the Databricks type to map.

    Returns:
        MYSQL_DATA_TYPE: The MySQL type enum that corresponds to the Databricks type name.
    """
    # Non-string (e.g. None) type names fall back to TEXT, the safest catch-all.
    if not isinstance(internal_type_name, str):
        return MYSQL_DATA_TYPE.TEXT

    upper_name = internal_type_name.upper()

    # Databricks-specific names that have no same-named MySQL counterpart.
    if upper_name == "STRING":
        return MYSQL_DATA_TYPE.TEXT
    if upper_name == "LONG":
        return MYSQL_DATA_TYPE.BIGINT
    if upper_name == "SHORT":
        return MYSQL_DATA_TYPE.SMALLINT

    try:
        # Most Databricks type names match a MYSQL_DATA_TYPE member by value.
        return MYSQL_DATA_TYPE(upper_name)
    except ValueError:
        # Enum lookup by value raises ValueError for unknown names.
        logger.info(f"Databricks handler: unknown type: {internal_type_name}, use TEXT as fallback.")
        return MYSQL_DATA_TYPE.TEXT

46 

47 

class DatabricksHandler(DatabaseHandler):
    """
    This handler handles the connection and execution of SQL statements on Databricks.
    """

    name = "databricks"

    def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs: Any) -> None:
        """
        Initializes the handler.

        Args:
            name (Text): The name of the handler instance.
            connection_data (Dict): The connection data required to connect to the Databricks workspace.
            kwargs: Arbitrary keyword arguments.
        """
        super().__init__(name)
        self.connection_data = connection_data
        self.kwargs = kwargs

        self.connection = None
        self.is_connected = False
        self.cache_thread_safe = True

    def __del__(self) -> None:
        """
        Closes the connection when the handler instance is deleted.
        """
        # getattr guard: __del__ may run even if __init__ raised before
        # is_connected was assigned, which would otherwise raise AttributeError.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def connect(self) -> Connection:
        """
        Establishes a connection to the Databricks workspace.

        Raises:
            ValueError: If the expected connection parameters are not provided.

        Returns:
            databricks.sql.client.Connection: A connection object to the Databricks workspace.
        """
        # Reuse the open connection if one already exists.
        if self.is_connected is True:
            return self.connection

        # Mandatory connection parameters.
        if not all(key in self.connection_data for key in ["server_hostname", "http_path", "access_token"]):
            raise ValueError("Required parameters (server_hostname, http_path, access_token) must be provided.")

        config = {
            "server_hostname": self.connection_data["server_hostname"],
            "http_path": self.connection_data["http_path"],
            "access_token": self.connection_data["access_token"],
        }

        # Optional connection parameters.
        optional_parameters = [
            "session_configuration",
            "http_headers",
            "catalog",
            "schema",
        ]
        for parameter in optional_parameters:
            if parameter in self.connection_data:
                config[parameter] = self.connection_data[parameter]

        try:
            self.connection = connect(**config)
            self.is_connected = True
            return self.connection
        except RequestError as request_error:
            logger.error(f"Request error when connecting to Databricks: {request_error}")
            raise
        except RuntimeError as runtime_error:
            logger.error(f"Runtime error when connecting to Databricks: {runtime_error}")
            raise
        except Exception as unknown_error:
            logger.error(f"Unknown error when connecting to Databricks: {unknown_error}")
            raise

    def disconnect(self):
        """
        Closes the connection to the Databricks workspace if it's currently open.

        Returns:
            bool: The (now False) connection state.
        """
        if self.is_connected is False:
            return

        self.connection.close()
        self.is_connected = False
        return self.is_connected

    def check_connection(self) -> StatusResponse:
        """
        Checks the status of the connection to the Databricks workspace.

        Returns:
            StatusResponse: An object containing the success status and an error message if an error occurs.
        """
        response = StatusResponse(False)
        need_to_close = self.is_connected is False

        # Resolve the optional schema once; used for both the probe query and
        # the error message below (avoids a KeyError when no schema was given).
        schema = (self.connection_data or {}).get("schema")

        try:
            connection = self.connect()

            # Execute a simple query to check the connection.
            query = "SELECT 1 FROM information_schema.schemata"
            if schema is not None:
                # NOTE(review): schema name is interpolated directly into the
                # SQL text; it originates from handler configuration, not end
                # users, but consider parameterizing via cursor.execute.
                query += f" WHERE schema_name = '{schema}'"

            with connection.cursor() as cursor:
                cursor.execute(query)
                result = cursor.fetchall()

                # If the query does not return a result, the schema does not exist.
                if not result:
                    raise ValueError(f"The schema {schema} does not exist!")

            response.success = True
        except (ValueError, RequestError, RuntimeError, ServerOperationError) as known_error:
            logger.error(f"Connection check to Databricks failed, {known_error}!")
            response.error_message = str(known_error)
        except Exception as unknown_error:
            logger.error(f"Connection check to Databricks failed due to an unknown error, {unknown_error}!")
            response.error_message = str(unknown_error)

        # Only close a connection this check opened itself.
        if response.success and need_to_close:
            self.disconnect()
        # A failed check invalidates any connection we thought was open.
        elif not response.success and self.is_connected:
            self.is_connected = False

        return response

    def native_query(self, query: Text) -> Response:
        """
        Executes a native SQL query on the Databricks workspace and returns the result.

        Args:
            query (Text): The SQL query to be executed.

        Returns:
            Response: A response object containing the result of the query or an error message.
        """
        need_to_close = self.is_connected is False

        connection = self.connect()
        with connection.cursor() as cursor:
            try:
                cursor.execute(query)
                result = cursor.fetchall()
                if result:
                    response = Response(
                        RESPONSE_TYPE.TABLE,
                        data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]),
                    )
                else:
                    # No rows: treat as a successful DML/DDL statement.
                    response = Response(RESPONSE_TYPE.OK)
                    connection.commit()
            except ServerOperationError as server_error:
                logger.error(f"Server error running query: {query} on Databricks, {server_error}!")
                response = Response(RESPONSE_TYPE.ERROR, error_message=str(server_error))
            except Exception as unknown_error:
                logger.error(f"Unknown error running query: {query} on Databricks, {unknown_error}!")
                response = Response(RESPONSE_TYPE.ERROR, error_message=str(unknown_error))

        # Close only if this call opened the connection.
        if need_to_close is True:
            self.disconnect()

        return response

    def query(self, query: ASTNode) -> Response:
        """
        Executes a SQL query represented by an ASTNode on the Databricks Workspace and retrieves the data.

        Args:
            query (ASTNode): An ASTNode representing the SQL query to be executed.

        Returns:
            Response: The response from the `native_query` method, containing the result of the SQL query execution.
        """
        renderer = SqlalchemyRender(DatabricksDialect)
        query_str = renderer.get_string(query, with_failback=True)
        return self.native_query(query_str)

    def get_tables(self, all: bool = False) -> Response:
        """
        Retrieves a list of all non-system tables in the connected schema of the Databricks workspace.

        Args:
            all (bool): If True - return tables from all schemas.

        Returns:
            Response: A response object containing a list of tables in the connected schema.
        """
        all_filter = "and table_schema = current_schema()"
        if all is True:
            all_filter = ""
        query = f"""
            SELECT
                table_schema,
                table_name,
                table_type
            FROM
                information_schema.tables
            WHERE
                table_schema != 'information_schema'
                {all_filter}
        """
        result = self.native_query(query)
        # An OK (row-less) response is normalized to an empty table so callers
        # always receive a TABLE-typed response.
        if result.resp_type == RESPONSE_TYPE.OK:
            result = Response(
                RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))
            )
        return result

    def get_columns(self, table_name: str, schema_name: str | None = None) -> Response:
        """
        Retrieves column details for a specified table in the Databricks workspace.

        Args:
            table_name (str): The name of the table for which to retrieve column information.
            schema_name (str|None): The name of the schema in which the table is located.

        Raises:
            ValueError: If the 'table_name' is not a valid string.

        Returns:
            Response: A response object containing the column details.
        """
        if not table_name or not isinstance(table_name, str):
            raise ValueError("Invalid table name provided.")

        # NOTE(review): table_name/schema_name are interpolated directly into
        # the SQL text; consider parameterizing via the connector's
        # cursor.execute parameters to rule out injection.
        if isinstance(schema_name, str):
            schema_name = f"'{schema_name}'"
        else:
            schema_name = "current_schema()"
        query = f"""
            SELECT
                COLUMN_NAME,
                DATA_TYPE,
                ORDINAL_POSITION,
                COLUMN_DEFAULT,
                IS_NULLABLE,
                CHARACTER_MAXIMUM_LENGTH,
                CHARACTER_OCTET_LENGTH,
                NUMERIC_PRECISION,
                NUMERIC_SCALE,
                DATETIME_PRECISION,
                null as CHARACTER_SET_NAME,
                null as COLLATION_NAME
            FROM
                information_schema.columns
            WHERE
                table_name = '{table_name}'
            AND
                table_schema = {schema_name}
        """

        result = self.native_query(query)
        # Normalize a row-less OK response to an empty table, then reshape into
        # the standard columns-table response using the Databricks type mapper.
        if result.resp_type == RESPONSE_TYPE.OK:
            result = Response(
                RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))
            )
        result.to_columns_table_response(map_type_fn=_map_type)

        return result