Coverage for mindsdb / api / mysql / mysql_proxy / mysql_proxy.py: 18%

508 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1""" 

2******************************************************* 

3 * Copyright (C) 2017 MindsDB Inc. <copyright@mindsdb.com> 

4 * 

5 * This file is part of MindsDB Server. 

6 * 

7 * MindsDB Server can not be copied and/or distributed without the express 

8 * permission of MindsDB Inc 

9 ******************************************************* 

10""" 

11 

12import atexit 

13import base64 

14import os 

15import select 

16import socket 

17import socketserver as SocketServer 

18import ssl 

19import struct 

20import sys 

21import tempfile 

22import traceback 

23import logging 

24from functools import partial 

25from typing import List 

26from dataclasses import dataclass 

27 

28import mindsdb.utilities.hooks as hooks 

29import mindsdb.utilities.profiler as profiler 

30from mindsdb.utilities.sql import clear_sql 

31from mindsdb.api.mysql.mysql_proxy.classes.client_capabilities import ClentCapabilities 

32from mindsdb.api.mysql.mysql_proxy.classes.server_capabilities import ( 

33 server_capabilities, 

34) 

35from mindsdb.api.executor.controllers import SessionController 

36from mindsdb.api.mysql.mysql_proxy.data_types.mysql_packet import Packet 

37from mindsdb.api.mysql.mysql_proxy.data_types.mysql_packets import ( 

38 BinaryResultsetRowPacket, 

39 ColumnCountPacket, 

40 ColumnDefenitionPacket, 

41 CommandPacket, 

42 EofPacket, 

43 ErrPacket, 

44 FastAuthFail, 

45 HandshakePacket, 

46 HandshakeResponsePacket, 

47 OkPacket, 

48 PasswordAnswer, 

49 ResultsetRowPacket, 

50 STMTPrepareHeaderPacket, 

51 SwitchOutPacket, 

52 SwitchOutResponse, 

53) 

54from mindsdb.api.mysql.mysql_proxy.executor import Executor 

55from mindsdb.api.mysql.mysql_proxy.external_libs.mysql_scramble import ( 

56 scramble as scramble_func, 

57) 

58from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import ( 

59 DEFAULT_AUTH_METHOD, 

60 CHARSET_NUMBERS, 

61 SERVER_STATUS, 

62 CAPABILITIES, 

63 COMMANDS, 

64 ERR, 

65 getConstName, 

66) 

67from mindsdb.api.executor.data_types.answer import ExecuteAnswer 

68from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE 

69from mindsdb.api.executor import exceptions as executor_exceptions 

70 

71from mindsdb.api.common.middleware import check_auth 

72from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE 

73from mindsdb.api.executor.sql_query.result_set import Column, ResultSet 

74from mindsdb.utilities import log 

75from mindsdb.utilities.config import config 

76from mindsdb.utilities.context import context as ctx 

77from mindsdb.utilities.otel import increment_otel_query_request_counter 

78from mindsdb.utilities.wizards import make_ssl_cert 

79from mindsdb.utilities.exception import QueryError 

80from mindsdb.utilities.functions import mark_process 

81from mindsdb.api.mysql.mysql_proxy.utilities.dump import ( 

82 dump_result_set_to_mysql, 

83 column_to_mysql_column_dict, 

84 dump_columns_info, 

85 dump_chunks, 

86) 

87from mindsdb.api.executor.exceptions import WrongCharsetError 

88 

89logger = log.getLogger(__name__) 

90 

91 

def empty_fn():
    """Do nothing and return ``None``; a no-op callable."""
    return None

94 

95 

@dataclass
class SQLAnswer:
    """Outcome of a processed command, in a form that can be serialized for a client.

    ``resp_type`` decides which of the other fields are read when the answer
    is dumped.
    """

    # Kind of response (OK / TABLE / COLUMNS_TABLE / ERROR / ...).
    resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK
    # Tabular payload, read for TABLE and COLUMNS_TABLE responses.
    result_set: ResultSet | None = None
    status: int | None = None
    state_track: List[List] | None = None
    # Error details, read for ERROR responses.
    error_code: int | None = None
    error_message: str | None = None
    affected_rows: int | None = None
    mysql_types: list[MYSQL_DATA_TYPE] | None = None

    @property
    def type(self):
        """Read-only alias for ``resp_type``."""
        return self.resp_type

    def dump_http_response(self) -> dict:
        """Convert the answer to a plain dict for an HTTP response.

        Returns:
            dict: ``type`` plus the fields relevant to that response type.

        Raises:
            ValueError: if ``resp_type`` is not OK, TABLE, COLUMNS_TABLE or ERROR.
        """
        kind = self.resp_type

        if kind == RESPONSE_TYPE.OK:
            return {
                "type": kind,
                "affected_rows": self.affected_rows,
            }

        if kind in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE):
            # Both table flavours are reported to HTTP clients as TABLE.
            rows = self.result_set.to_lists(json_types=True)
            names = [col.alias or col.name for col in self.result_set.columns]
            return {
                "type": RESPONSE_TYPE.TABLE,
                "data": rows,
                "column_names": names,
            }

        if kind == RESPONSE_TYPE.ERROR:
            return {
                "type": RESPONSE_TYPE.ERROR,
                # A missing error code is reported as 0.
                "error_code": self.error_code or 0,
                "error_message": self.error_message,
            }

        raise ValueError(f"Unsupported response type for dump HTTP response: {self.resp_type}")

132 

133 

class MysqlTCPServer(SocketServer.ThreadingTCPServer):
    """Threading TCP server with a request queue size of 30."""

    # Overrides the socketserver default; used as the listen backlog of the server socket.
    request_queue_size: int = 30

140 

141 

142class MysqlProxy(SocketServer.BaseRequestHandler): 

143 """ 

144 The Main Server controller class 

145 """ 

146 

@staticmethod
def server_close(srv):
    """Close the given server by calling its ``server_close()`` method."""
    shutdown = srv.server_close
    shutdown()

150 

def __init__(self, request, client_address, server):
    """Set per-connection defaults, then hand over to the base request handler.

    The attributes are assigned before the base constructor is called, so
    they already exist by the time the base class starts processing the request.
    """
    # Connection state; filled in later while the connection is served.
    self.connection_id = None
    self.client_capabilities = None
    self.session = None

    # Default text encoding of the connection.
    self.charset_text_type = CHARSET_NUMBERS["utf8_general_ci"]
    self.charset = "utf8"

    super().__init__(request, client_address, server)

158 

def init_session(self):
    """Prepare per-connection state: connection id, session, salt and socket."""
    host, port = self.client_address[0], self.client_address[1]
    logger.debug("New connection [{ip}:{port}]".format(ip=host, port=port))

    # The connection counter lives on the server object and restarts
    # from 1 once it has reached 65025.
    srv = self.server
    if srv.connection_id >= 65025:
        srv.connection_id = 0
    srv.connection_id += 1
    self.connection_id = srv.connection_id
    self.session = SessionController(api_type="sql")

    # Prefer a string salt configured on the server; otherwise generate a random one.
    configured_salt = getattr(srv, "salt", None)
    if isinstance(configured_salt, str):
        self.salt = configured_salt
    else:
        random_bytes = os.urandom(15)
        self.salt = base64.b64encode(random_bytes).decode()

    self.socket = self.request
    self.logging = logger

    self.current_transaction = None

    logger.debug("session salt: {salt}".format(salt=self.salt))

179 

def handshake(self):
    """Run the connection handshake and authenticate the client.

    Sends the server handshake packet, reads the client's response,
    upgrades the socket to TLS when the client asks for it, obtains the
    password according to the client's auth plugin and validates the
    credentials with ``self.server.check_auth``.

    Returns:
        bool: ``True`` when authentication succeeded, ``False`` when the
        handshake response was empty, the auth method could not be used,
        the password could not be read, or the credentials were rejected.
    """

    def switch_auth(method="mysql_native_password"):
        # Ask the client to switch to `method` and return the password it answers with.
        self.packet(SwitchOutPacket, seed=self.salt, method=method).send()
        switch_out_answer = self.packet(SwitchOutResponse)
        switch_out_answer.get()
        password = switch_out_answer.password
        if method == "mysql_native_password" and len(password) == 0:
            # An empty answer is replaced by the scramble of an empty password.
            password = scramble_func("", self.salt)
        return password

    def get_fast_auth_password():
        # Request the password after a failed fast auth.
        # Returns None (after sending an ErrPacket) when it cannot be decoded.
        logger.debug("Asking for fast auth password")
        self.packet(FastAuthFail).send()
        password_answer = self.packet(PasswordAnswer)
        password_answer.get()
        try:
            password = password_answer.password.value.decode()
        except Exception:
            logger.warning("error: no password in Fast Auth answer")
            self.packet(
                ErrPacket,
                err_code=ERR.ER_PASSWORD_NO_MATCH,
                msg="Is not password in connection query.",
            ).send()
            return None
        return password

    username = None
    password = None

    logger.debug("send HandshakePacket")
    self.packet(HandshakePacket).send()

    handshake_resp = self.packet(HandshakeResponsePacket)
    handshake_resp.get()
    if handshake_resp.length == 0:
        logger.debug("HandshakeResponsePacket empty")
        self.packet(OkPacket).send()
        return False
    self.client_capabilities = ClentCapabilities(handshake_resp.capabilities.value)

    client_auth_plugin = handshake_resp.client_auth_plugin.value.decode()

    self.session.is_ssl = False

    if handshake_resp.type == "SSLRequest":
        logger.debug("switch to SSL")
        self.session.is_ssl = True

        # Constructing SSLContext without a protocol is deprecated;
        # PROTOCOL_TLS_SERVER is the explicit choice for a server-side socket.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
        ssl_context.load_cert_chain(self.server.cert_path)
        ssl_socket = ssl_context.wrap_socket(self.socket, server_side=True, do_handshake_on_connect=True)

        # The client repeats its handshake response over the encrypted channel.
        self.socket = ssl_socket
        handshake_resp = self.packet(HandshakeResponsePacket)
        handshake_resp.get()
        client_auth_plugin = handshake_resp.client_auth_plugin.value.decode()

    username = handshake_resp.username.value.decode()

    if client_auth_plugin != DEFAULT_AUTH_METHOD:
        if client_auth_plugin == "mysql_native_password":
            password = switch_auth("mysql_native_password")
        else:
            new_method = (
                "caching_sha2_password"
                if client_auth_plugin == "caching_sha2_password"
                else "mysql_native_password"
            )

            if new_method == "caching_sha2_password" and self.session.is_ssl is False:
                logger.warning(
                    f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                    "error: cant switch to caching_sha2_password without SSL"
                )
                self.packet(
                    ErrPacket,
                    err_code=ERR.ER_PASSWORD_NO_MATCH,
                    msg="caching_sha2_password without SSL not supported",
                ).send()
                return False

            logger.debug(
                f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                f"switch auth method to {new_method}"
            )
            password = switch_auth(new_method)

            if new_method == "caching_sha2_password":
                if password == b"\x00":
                    password = ""
                else:
                    password = get_fast_auth_password()
                    if password is None:
                        # An ErrPacket was already sent by the helper; stop here
                        # instead of running check_auth and replying a second time.
                        return False
    elif "caching_sha2_password" in client_auth_plugin:
        logger.debug(
            f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
            "check auth using caching_sha2_password"
        )
        password = handshake_resp.enc_password.value
        if password == b"\x00":
            password = ""
        else:
            # FIXME https://github.com/mindsdb/mindsdb/issues/1374
            # if self.session.is_ssl:
            #     password = get_fast_auth_password()
            # else:
            password = switch_auth()
    elif "mysql_native_password" in client_auth_plugin:
        logger.debug(
            f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
            "check auth using mysql_native_password"
        )
        password = handshake_resp.enc_password.value
    else:
        logger.debug(
            f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
            "unknown method, possible ERROR. Try to switch to mysql_native_password"
        )
        password = switch_auth("mysql_native_password")

    # The database field is optional; fall back to None when it cannot be read.
    try:
        self.session.database = handshake_resp.database.value.decode()
    except Exception:
        self.session.database = None
    logger.debug(
        f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
        f"connecting to database {self.session.database}"
    )

    auth_data = self.server.check_auth(username, password, scramble_func, self.salt, ctx.company_id)
    if auth_data["success"]:
        self.session.username = auth_data["username"]
        self.session.auth = True
        self.packet(OkPacket).send()
        return True

    self.packet(
        ErrPacket,
        err_code=ERR.ER_PASSWORD_NO_MATCH,
        msg=f"Access denied for user {username}",
    ).send()
    logger.warning(f"Access denied for user {username}")
    return False

324 

def send_package_group(self, packages):
    """Serialize every packet with ``accum()`` and send them in one ``sendall`` call."""
    payload = b"".join(package.accum() for package in packages)
    self.socket.sendall(payload)

328 

def answer_stmt_close(self, stmt_id):
    """Drop the prepared statement ``stmt_id`` from the session; nothing is sent back."""
    session = self.session
    session.unregister_stmt(stmt_id)

331 

def send_query_answer(self, answer: SQLAnswer):
    """Send ``answer`` to the client using the packets that match its type.

    TABLE / COLUMNS_TABLE answers are sent as a result set followed by a
    closing packet; OK, ERROR and EOF answers are sent as a single packet.
    Any other type sends nothing.
    """
    kind = answer.type

    if kind in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE):
        result_set = answer.result_set
        if len(result_set) >= 1000:
            # for big responses leverage pandas map function to convert data to packages
            self.send_table_packets(result_set=result_set)
            pending = []
        else:
            pending = list(self.get_table_packets(result_set=result_set))

        # Close the result set, forwarding the status only when one was provided.
        if answer.status is None:
            closing = self.last_packet()
        else:
            closing = self.last_packet(status=answer.status)
        pending.append(closing)
        self.send_package_group(pending)
        return

    if kind == RESPONSE_TYPE.OK:
        self.packet(OkPacket, state_track=answer.state_track, affected_rows=answer.affected_rows).send()
    elif kind == RESPONSE_TYPE.ERROR:
        self.packet(ErrPacket, err_code=answer.error_code, msg=answer.error_message).send()
    elif kind == RESPONSE_TYPE.EOF:
        self.packet(EofPacket).send()

353 

def _get_column_defenition_packets(self, columns: list[dict], data=None):
    """Build one ``ColumnDefenitionPacket`` per column description.

    Args:
        columns (list[dict]): column descriptions. Keys read here: ``table_name``,
            ``table_alias``, ``name``, ``alias``, ``flags``, ``size``,
            ``database``, ``type`` (required) and ``charset``.
        data: accepted for backward compatibility and ignored. The rows used
            to be scanned for the widest value, but the result was never used:
            a missing size has always been set to 1.

    Returns:
        list: the column definition packets, in the order of ``columns``.

    Note:
        A column without a ``size`` gets ``column["size"] = 1`` written back
        into the passed dict.
    """
    packets = []
    for column in columns:
        logger.debug(
            "%s._get_column_defenition_packets: handling column - %s of %s type",
            self.__class__.__name__,
            column,
            type(column),
        )
        table_name = column.get("table_name", "table_name")
        column_name = column.get("name", "column_name")
        column_alias = column.get("alias", column_name)

        # Flags may be given as a list of individual flag values.
        flags = column.get("flags", 0)
        if isinstance(flags, list):
            flags = sum(flags)

        if column.get("size") is None:
            column["size"] = 1

        packets.append(
            self.packet(
                ColumnDefenitionPacket,
                schema=column.get("database", "mindsdb_schema"),
                table_alias=column.get("table_alias", table_name),
                table_name=table_name,
                column_alias=column_alias,
                column_name=column_name,
                column_type=column["type"],
                charset=column.get("charset", CHARSET_NUMBERS["utf8_unicode_ci"]),
                max_length=column["size"],
                flags=flags,
            )
        )
    return packets

395 

def get_table_packets(self, result_set: ResultSet, status=0):
    """Build, without sending, the packets that describe and carry ``result_set``.

    The list holds the column count, the column definitions, an EOF packet
    (only when the client has not set DEPRECATE_EOF) and one row packet per
    data row.
    """
    data_frame, columns_dict = dump_result_set_to_mysql(result_set)
    rows = data_frame.to_dict("split")["data"]

    # TODO remove columns order
    column_count = self.packet(ColumnCountPacket, count=len(columns_dict))
    definitions = self._get_column_defenition_packets(columns_dict, rows)
    packets = [column_count, *definitions]

    if self.client_capabilities.DEPRECATE_EOF is False:
        packets.append(self.packet(EofPacket, status=status))

    for row in rows:
        packets.append(self.packet(ResultsetRowPacket, data=row))
    return packets

409 

def send_table_packets(self, result_set: ResultSet, status: int = 0):
    """Send a result set to the client: header first, then rows chunk by chunk.

    Args:
        result_set (ResultSet): the result set to send
        status (int): status passed to the EOF packet that follows the column definitions

    Returns:
        None
    """
    columns_dicts = dump_columns_info(result_set, infer_column_size=True)

    # Header: column count, column definitions and (for older clients) an EOF packet.
    header = [self.packet(ColumnCountPacket, count=len(columns_dicts))]
    header += self._get_column_defenition_packets(columns_dicts)
    if self.client_capabilities.DEPRECATE_EOF is False:
        header.append(self.packet(EofPacket, status=status))
    self.send_package_group(header)

    raw_df = result_set.get_raw_df()
    if len(raw_df) == 0:
        return

    rows_per_chunk = 1000
    for chunk in dump_chunks(raw_df, columns_dicts, rows_per_chunk):
        # Wrap every row body into a packet in place, then send the chunk at once.
        for position, body in enumerate(chunk):
            chunk[position] = self.packet(body=body, length=len(body)).accum()
        self.socket.sendall(b"".join(chunk))

436 

def decode_utf(self, text):
    """Decode ``text`` as UTF-8.

    Raises:
        WrongCharsetError: if decoding fails for any reason.
    """
    try:
        decoded = text.decode("utf-8")
    except Exception:
        raise WrongCharsetError(f"SQL contains non utf-8 values: {text}")
    return decoded

442 

def is_cloud_connection(self):
    """Determine the source of the connection. Must be called before the handshake.

    The idea: a real MySQL client does not send anything before the server
    handshake, so the socket has nothing to read. A cloud connection, in
    contrast, sends four zero bytes right after connecting. Zeros were chosen
    because in a real MySQL connection these bytes would be the packet
    length, which can not be 0.

    Returns:
        dict: ``{"is_cloud": False}`` for a regular connection; for a cloud
        connection ``is_cloud`` is True and the dict also carries
        ``client_capabilities``, ``company_id``, ``user_class``, ``database``
        and ``email_confirmed`` read from the cloud header.
    """
    is_cloud = config.get("cloud", False)

    if sys.platform != "linux" or is_cloud is False:
        return {"is_cloud": False}

    connection = self.request

    # Wait up to 30 ms for the client to send something first.
    read_poller = select.poll()
    read_poller.register(connection, select.POLLIN)
    events = read_poller.poll(30)

    if len(events) == 0:
        return {"is_cloud": False}

    def read_exact(size):
        # MSG_WAITALL asks recv to block until `size` bytes have arrived,
        # so a header split across TCP segments is not truncated.
        return connection.recv(size, socket.MSG_WAITALL)

    # Peek so that a regular MySQL stream is left untouched.
    marker = connection.recv(4, socket.MSG_PEEK)
    if marker != b"\x00\x00\x00\x00":
        return {"is_cloud": False}

    read_exact(4)  # consume the marker

    # "=" selects native byte order with standard sizes, so every field has
    # a fixed width that matches the number of bytes read.
    client_capabilities = struct.unpack("=Q", read_exact(8))[0]
    company_id = struct.unpack("=I", read_exact(4))[0]

    user_class = struct.unpack("=B", read_exact(1))[0]
    email_confirmed = 1
    if user_class > 1:
        # Bit 2 carries the e-mail confirmation flag, the two lowest bits the user class.
        email_confirmed = (user_class >> 2) & 1
        user_class = user_class & 3

    database_name_len = struct.unpack("=H", read_exact(2))[0]

    database_name = ""
    if database_name_len > 0:
        database_name = read_exact(database_name_len).decode()

    return {
        "is_cloud": True,
        "client_capabilities": client_capabilities,
        "company_id": company_id,
        "user_class": user_class,
        "database": database_name,
        "email_confirmed": email_confirmed,
    }

495 

def to_mysql_columns(self, columns_list: "list[Column]") -> "list[dict[str, str | int]]":
    """Convert result-set columns to MySQL column description dicts.

    The session database name, lower-cased, is passed to every conversion.
    When the session has no database (empty string or ``None``), ``None`` is
    passed instead.

    Args:
        columns_list: columns to convert.

    Returns:
        One dict per column, in the same order as ``columns_list``.
    """
    current_database = self.session.database
    # The database may be None as well as ""; both mean "no database".
    database_name = current_database.lower() if current_database else None
    return [column_to_mysql_column_dict(column, database_name=database_name) for column in columns_list]

499 

@profiler.profile()
def process_query(self, sql: str) -> SQLAnswer:
    """Execute ``sql`` in the current session and wrap the outcome in an ``SQLAnswer``.

    A query without result data gives an OK answer; otherwise a TABLE answer
    carrying the result set, the executor's server status and the MySQL types.
    """
    log.log_ram_info(logger)

    executor = Executor(session=self.session, sqlserver=self)
    executor.query_execute(sql)
    outcome = executor.executor_answer

    # Fields present in both kinds of answer.
    shared = {
        "state_track": outcome.state_track,
        "affected_rows": outcome.affected_rows,
    }
    result_set = outcome.data
    if result_set is None:
        resp = SQLAnswer(resp_type=RESPONSE_TYPE.OK, **shared)
    else:
        resp = SQLAnswer(
            resp_type=RESPONSE_TYPE.TABLE,
            result_set=result_set,
            status=executor.server_status,
            mysql_types=result_set.mysql_types,
            **shared,
        )

    # Increment the counter and include metadata in attributes
    increment_otel_query_request_counter(ctx.get_metadata(query=sql))

    return resp

527 

def answer_stmt_prepare(self, sql):
    """Prepare ``sql`` as a statement, register it in the session and send its description.

    The reply is a prepare header followed by the parameter definitions and
    then the column definitions. Each non-empty group is closed with an EOF
    packet when the client has not set DEPRECATE_EOF.
    """
    executor = Executor(session=self.session, sqlserver=self)
    stmt_id = self.session.register_stmt(executor)

    executor.stmt_prepare(sql)

    header = self.packet(
        STMTPrepareHeaderPacket,
        stmt_id=stmt_id,
        num_columns=len(executor.columns),
        num_params=len(executor.params),
    )
    packages = [header]

    # Parameters first, then columns; both groups are serialized the same way.
    for definitions in (executor.params, executor.columns):
        if len(definitions) == 0:
            continue
        mysql_columns = self.to_mysql_columns(definitions)
        packages.extend(self._get_column_defenition_packets(mysql_columns))
        if self.client_capabilities.DEPRECATE_EOF is False:
            status = sum([SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT])
            packages.append(self.packet(EofPacket, status=status))

    self.send_package_group(packages)

559 

def answer_stmt_execute(self, stmt_id, parameters):
    """Execute the prepared statement ``stmt_id`` with ``parameters`` and send the result.

    Without result data an OK answer is sent. Otherwise the whole result set
    is sent as binary rows and the statement's ``fetched`` counter is
    increased by the number of rows sent.
    """
    prepared_stmt = self.session.prepared_stmts[stmt_id]
    executor: Executor = prepared_stmt["statement"]

    executor.stmt_execute(parameters)

    executor_answer: ExecuteAnswer = executor.executor_answer
    result_set = executor_answer.data

    if result_set is None:
        ok_answer = SQLAnswer(resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track)
        return self.send_query_answer(ok_answer)

    # TODO prepared_stmt['type'] == 'lock' is not used but it works
    data_frame, columns_dict = dump_result_set_to_mysql(result_set)
    rows = data_frame.to_dict("split")["data"]

    packages = [self.packet(ColumnCountPacket, count=len(columns_dict))]
    packages += self._get_column_defenition_packets(columns_dict)

    if self.client_capabilities.DEPRECATE_EOF is False:
        packages.append(self.packet(EofPacket, status=0x0062))

    # send all
    packages.extend(self.packet(BinaryResultsetRowPacket, data=row, columns=columns_dict) for row in rows)

    # Fall back to 0x0002 when the executor reports no server status.
    closing_status = executor.server_status or 0x0002
    packages.append(self.last_packet(status=closing_status))
    prepared_stmt["fetched"] += len(rows)

    return self.send_package_group(packages)

592 

def answer_stmt_fetch(self, stmt_id, limit):
    """Send the next ``limit`` rows of an already executed prepared statement.

    Args:
        stmt_id: id of the prepared statement in ``self.session.prepared_stmts``.
        limit: maximum number of rows to send in this call.

    The rows are taken starting at the statement's ``fetched`` offset, which
    is then advanced by the number of rows actually sent. The closing packet
    carries LAST_ROW_SENT when no rows remain after this call, otherwise
    CURSOR_EXISTS. If the statement produced no result data, an OK answer is
    sent instead.
    """
    prepared_stmt = self.session.prepared_stmts[stmt_id]
    executor = prepared_stmt["statement"]
    fetched = prepared_stmt["fetched"]
    executor_answer: ExecuteAnswer = executor.executor_answer

    if executor_answer.data is None:
        resp = SQLAnswer(resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track)
        return self.send_query_answer(resp)

    columns = self.to_mysql_columns(executor_answer.data.columns)

    # `limit` is a row count, not an end index: the window starts at the
    # rows already fetched and spans `limit` rows from there.
    page = executor_answer.data[fetched : fetched + limit]
    packages = [self.packet(BinaryResultsetRowPacket, data=row, columns=columns) for row in page.to_lists()]

    prepared_stmt["fetched"] += len(page)

    if len(executor_answer.data) <= limit + fetched:
        status = sum(
            [
                SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
                SERVER_STATUS.SERVER_STATUS_LAST_ROW_SENT,
            ]
        )
    else:
        status = sum(
            [
                SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
                SERVER_STATUS.SERVER_STATUS_CURSOR_EXISTS,
            ]
        )

    packages.append(self.last_packet(status=status))
    self.send_package_group(packages)

627 

def handle(self):
    """
    Serve one client connection: set it up, then answer commands until it ends.

    The loop stops when a packet cannot be read, when the client closes the
    connection, or when COM_QUIT is received. After every other command the
    ``after_api_query`` hook is called with the command and its error details.
    :return:
    """
    ctx.set_default()

    self.server.hook_before_handle()

    logger.debug("handle new incoming connection")
    cloud_connection = self.is_cloud_connection()

    ctx.company_id = cloud_connection.get("company_id")

    self.init_session()
    if cloud_connection["is_cloud"] is not False:
        # Cloud connections skip the handshake and arrive already authenticated.
        ctx.user_class = cloud_connection["user_class"]
        ctx.email_confirmed = cloud_connection["email_confirmed"]
        self.client_capabilities = ClentCapabilities(cloud_connection["client_capabilities"])
        self.session.database = cloud_connection["database"]
        self.session.username = "cloud"
        self.session.auth = True
    elif self.handshake() is False:
        return

    def rejected(message):
        # Error answer without an error code, used for commands that are not served.
        return SQLAnswer(
            resp_type=RESPONSE_TYPE.ERROR,
            error_code=None,
            error_message=message,
        )

    while True:
        logger.debug("Got a new packet")
        p = self.packet(CommandPacket)

        try:
            success = p.get()
        except Exception:
            logger.exception("Session closed, on packet read error:")
            return

        if success is False:
            logger.debug("Session closed by client")
            return

        command = p.type.value
        logger.debug("Command TYPE: {type}".format(type=getConstName(COMMANDS, command)))

        sql = None
        response = None
        error_type = None
        error_code = None
        error_text = None
        error_traceback = None

        try:
            if command == COMMANDS.COM_QUERY:
                sql = clear_sql(self.decode_utf(p.sql.value))
                logger.debug(f"Incoming query: {sql}")
                profiler.set_meta(query=sql, api="mysql", environment=config.get("environment"))
                with profiler.Context("mysql_query_processing"), mark_process("mysql_query"):
                    response = self.process_query(sql)
            elif command == COMMANDS.COM_STMT_PREPARE:
                sql = self.decode_utf(p.sql.value)
                self.answer_stmt_prepare(sql)
            elif command == COMMANDS.COM_STMT_EXECUTE:
                self.answer_stmt_execute(p.stmt_id.value, p.parameters)
            elif command == COMMANDS.COM_STMT_FETCH:
                self.answer_stmt_fetch(p.stmt_id.value, p.limit.value)
            elif command == COMMANDS.COM_STMT_CLOSE:
                self.answer_stmt_close(p.stmt_id.value)
            elif command == COMMANDS.COM_QUIT:
                logger.debug("Session closed, on client disconnect")
                self.session = None
                break
            elif command == COMMANDS.COM_INIT_DB:
                new_database = p.database.value.decode()

                executor = Executor(session=self.session, sqlserver=self)
                executor.change_default_db(new_database)

                response = SQLAnswer(RESPONSE_TYPE.OK)
            elif command in (COMMANDS.COM_FIELD_LIST, COMMANDS.COM_STMT_RESET, COMMANDS.COM_PING):
                # COM_FIELD_LIST is deprecated, but console client still use it.
                response = SQLAnswer(RESPONSE_TYPE.OK)
            elif command == COMMANDS.COM_CHANGE_USER:
                # This package should trigger re-authentication. For now it is forbidden.
                logger.warning("Got COM_CHANGE_USER packet that could not be processed, return error.")
                response = rejected("Packet COM_CHANGE_USER could not be processed")
            elif command == COMMANDS.COM_DEBUG:
                response = SQLAnswer(resp_type=RESPONSE_TYPE.EOF)
            elif command == COMMANDS.COM_SET_OPTION:
                # While regular MySQL options have no effect on mindsdb, we can safely return Ok.
                logger.warning("Unexpected packet COM_SET_OPTION recieved, return ok.")
                response = SQLAnswer(RESPONSE_TYPE.OK)
            elif command == COMMANDS.COM_SLEEP:
                # error - is the only valid answer for the packet
                response = rejected("")
            elif command == COMMANDS.COM_PROCESS_KILL:
                logger.warning("Unexpected packet COM_PROCESS_KILL recieved, return error.")
                response = rejected("Packet COM_PROCESS_KILL could not be processed")
            elif command == COMMANDS.COM_RESET_CONNECTION:
                logger.warning("Unexpected packet COM_RESET_CONNECTION recieved, return error.")
                response = rejected("Packet COM_RESET_CONNECTION could not be processed")
            elif command == COMMANDS.COM_SHUTDOWN:
                logger.warning("Unexpected packet COM_SHUTDOWN recieved, return error.")
                response = rejected("Packet COM_SHUTDOWN could not be processed")
            else:
                logger.warning("Command has no specific handler, return OK msg")
                logger.debug(str(p))
                response = SQLAnswer(RESPONSE_TYPE.OK)

        except (QueryError, executor_exceptions.ExecutorException, executor_exceptions.UnknownError) as e:
            error_type = "expected" if e.is_expected else "unexpected"
            error_code = e.mysql_error_code
            if not e.is_expected:
                logger.exception("Query execution failed with error")
            elif logger.isEnabledFor(logging.DEBUG):
                # Include the traceback only when debug logging is on.
                logger.info("Query execution failed with expected error:", exc_info=True)
            else:
                logger.info(f"Query execution failed with expected error: {e}")
            response = SQLAnswer(
                resp_type=RESPONSE_TYPE.ERROR,
                error_code=error_code,
                error_message=str(e),
            )

        except Exception as e:
            error_type = "unexpected"
            error_traceback = traceback.format_exc()
            logger.exception("ERROR while executing query:")
            error_code = ERR.ER_SYNTAX_ERROR
            response = SQLAnswer(
                resp_type=RESPONSE_TYPE.ERROR,
                error_code=error_code,
                error_message=str(e),
            )

        # Prepared-statement handlers send their own packets and leave response as None.
        if response is not None:
            self.send_query_answer(response)
            if response.type == RESPONSE_TYPE.ERROR:
                error_text = response.error_message
                error_code = response.error_code
                error_type = error_type or "expected"

        hooks.after_api_query(
            company_id=ctx.company_id,
            api="mysql",
            command=getConstName(COMMANDS, p.type.value),
            payload=sql,
            error_type=error_type,
            error_code=error_code,
            error_text=error_text,
            traceback=error_traceback,
        )

803 

def packet(self, packetClass=Packet, **kwargs):
    """
    Factory method for packets

    Creates the packet bound to this connection's socket and session, then
    advances the session's packet sequence number.

    :param packetClass: packet class to instantiate
    :param kwargs: extra keyword arguments passed to the packet constructor
    :return: the new packet instance
    """
    created = packetClass(socket=self.socket, session=self.session, proxy=self, **kwargs)
    self.session.inc_packet_sequence_number()
    return created

815 

def last_packet(self, status=0x0002):
    """Create the packet that closes a result set.

    Clients with DEPRECATE_EOF set get an OK packet flagged as EOF; all
    other clients get an EOF packet.
    """
    if self.client_capabilities.DEPRECATE_EOF is not True:
        return self.packet(EofPacket, status=status)
    return self.packet(OkPacket, eof=True, status=status)

821 

def set_context(self, context):
    """Apply ``context`` to the session.

    ``db`` selects the session database, falling back to the configured
    ``default_project`` when absent. ``profiling``, ``predictor_cache`` and
    ``show_secrets`` are copied to the session only when present.
    """
    session = self.session
    session.database = context["db"] if "db" in context else config.get("default_project")

    # Optional settings: leave the session attribute untouched when the key is missing.
    for option in ("profiling", "predictor_cache", "show_secrets"):
        if option in context:
            setattr(session, option, context[option])

834 

def get_context(self):
    """Return the session settings as a context dict.

    ``show_secrets`` is always included. ``db`` is included when a database
    is set, ``profiling`` only when it is exactly ``True`` and
    ``predictor_cache`` only when it is exactly ``False``.
    """
    session = self.session
    context = {"show_secrets": session.show_secrets}

    database = session.database
    if database is not None:
        context["db"] = database
    # Only non-default values are reported for the two flags below.
    if session.profiling is True:
        context.update(profiling=True)
    if session.predictor_cache is False:
        context.update(predictor_cache=False)

    return context

845 

@staticmethod
def startProxy():
    """
    Create a server and wait for incoming connections until Ctrl-C

    Reads host, port, SSL flag and certificate path from ``config["api"]["mysql"]``.
    When no certificate path is configured, a temporary certificate file is
    generated and removed again at interpreter exit.
    """
    cert_path = config["api"]["mysql"].get("certificate_path")
    if cert_path is None or cert_path == "":
        # mkstemp returns an open OS-level handle together with the path;
        # only the path is needed, so close the handle to avoid leaking it.
        cert_fd, cert_path = tempfile.mkstemp(prefix="mindsdb_cert_", text=True)
        os.close(cert_fd)
        make_ssl_cert(cert_path)
        atexit.register(lambda: os.remove(cert_path))
    elif not os.path.exists(cert_path):
        logger.error("Certificate defined in 'certificate_path' setting does not exist")

    # TODO make it session local
    server_capabilities.set(CAPABILITIES.CLIENT_SSL, config["api"]["mysql"]["ssl"])

    host = config["api"]["mysql"]["host"]
    port = int(config["api"]["mysql"]["port"])

    logger.info(f"Starting MindsDB Mysql proxy server on tcp://{host}:{port}")

    SocketServer.TCPServer.allow_reuse_address = True
    server = MysqlTCPServer((host, port), MysqlProxy)
    # Shared state that request handlers read from ``self.server``.
    server.mindsdb_config = config
    server.check_auth = partial(check_auth, config=config)
    server.cert_path = cert_path
    server.connection_id = 0
    server.hook_before_handle = empty_fn

    atexit.register(MysqlProxy.server_close, srv=server)

    # Activate the server; this will keep running until you
    # interrupt the program with Ctrl-C
    logger.info("Waiting for incoming connections...")
    server.serve_forever()