Coverage for mindsdb / api / mysql / mysql_proxy / mysql_proxy.py: 18%
508 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1"""
2*******************************************************
3 * Copyright (C) 2017 MindsDB Inc. <copyright@mindsdb.com>
4 *
5 * This file is part of MindsDB Server.
6 *
7 * MindsDB Server can not be copied and/or distributed without the express
8 * permission of MindsDB Inc
9 *******************************************************
10"""
12import atexit
13import base64
14import os
15import select
16import socket
17import socketserver as SocketServer
18import ssl
19import struct
20import sys
21import tempfile
22import traceback
23import logging
24from functools import partial
25from typing import List
26from dataclasses import dataclass
28import mindsdb.utilities.hooks as hooks
29import mindsdb.utilities.profiler as profiler
30from mindsdb.utilities.sql import clear_sql
31from mindsdb.api.mysql.mysql_proxy.classes.client_capabilities import ClentCapabilities
32from mindsdb.api.mysql.mysql_proxy.classes.server_capabilities import (
33 server_capabilities,
34)
35from mindsdb.api.executor.controllers import SessionController
36from mindsdb.api.mysql.mysql_proxy.data_types.mysql_packet import Packet
37from mindsdb.api.mysql.mysql_proxy.data_types.mysql_packets import (
38 BinaryResultsetRowPacket,
39 ColumnCountPacket,
40 ColumnDefenitionPacket,
41 CommandPacket,
42 EofPacket,
43 ErrPacket,
44 FastAuthFail,
45 HandshakePacket,
46 HandshakeResponsePacket,
47 OkPacket,
48 PasswordAnswer,
49 ResultsetRowPacket,
50 STMTPrepareHeaderPacket,
51 SwitchOutPacket,
52 SwitchOutResponse,
53)
54from mindsdb.api.mysql.mysql_proxy.executor import Executor
55from mindsdb.api.mysql.mysql_proxy.external_libs.mysql_scramble import (
56 scramble as scramble_func,
57)
58from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import (
59 DEFAULT_AUTH_METHOD,
60 CHARSET_NUMBERS,
61 SERVER_STATUS,
62 CAPABILITIES,
63 COMMANDS,
64 ERR,
65 getConstName,
66)
67from mindsdb.api.executor.data_types.answer import ExecuteAnswer
68from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
69from mindsdb.api.executor import exceptions as executor_exceptions
71from mindsdb.api.common.middleware import check_auth
72from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE
73from mindsdb.api.executor.sql_query.result_set import Column, ResultSet
74from mindsdb.utilities import log
75from mindsdb.utilities.config import config
76from mindsdb.utilities.context import context as ctx
77from mindsdb.utilities.otel import increment_otel_query_request_counter
78from mindsdb.utilities.wizards import make_ssl_cert
79from mindsdb.utilities.exception import QueryError
80from mindsdb.utilities.functions import mark_process
81from mindsdb.api.mysql.mysql_proxy.utilities.dump import (
82 dump_result_set_to_mysql,
83 column_to_mysql_column_dict,
84 dump_columns_info,
85 dump_chunks,
86)
87from mindsdb.api.executor.exceptions import WrongCharsetError
# Module-level logger shared by everything in this file.
logger = log.getLogger(__name__)


def empty_fn():
    """Do nothing.

    Installed by ``MysqlProxy.startProxy`` as the default
    ``server.hook_before_handle`` callback.
    """
    return None
@dataclass
class SQLAnswer:
    """Outcome of a single client command, ready to be serialized.

    ``resp_type`` decides which of the other fields matter:
    OK uses ``state_track``/``affected_rows``; TABLE/COLUMNS_TABLE use
    ``result_set`` (and optionally ``status``); ERROR uses
    ``error_code``/``error_message``.
    """

    resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK
    result_set: ResultSet | None = None
    status: int | None = None
    state_track: List[List] | None = None
    error_code: int | None = None
    error_message: str | None = None
    affected_rows: int | None = None
    mysql_types: list[MYSQL_DATA_TYPE] | None = None

    @property
    def type(self):
        """Read-only alias of ``resp_type``."""
        return self.resp_type

    def dump_http_response(self) -> dict:
        """Convert the answer into a plain dict for the HTTP API.

        COLUMNS_TABLE answers are reported with type TABLE.

        :raises ValueError: for response types other than OK, TABLE,
            COLUMNS_TABLE and ERROR.
        """
        kind = self.resp_type

        if kind == RESPONSE_TYPE.OK:
            return {"type": kind, "affected_rows": self.affected_rows}

        if kind in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE):
            rs = self.result_set
            rows = rs.to_lists(json_types=True)
            names = [col.alias or col.name for col in rs.columns]
            return {"type": RESPONSE_TYPE.TABLE, "data": rows, "column_names": names}

        if kind == RESPONSE_TYPE.ERROR:
            return {
                "type": RESPONSE_TYPE.ERROR,
                "error_code": self.error_code or 0,
                "error_message": self.error_message,
            }

        raise ValueError(f"Unsupported response type for dump HTTP response: {self.resp_type}")
class MysqlTCPServer(SocketServer.ThreadingTCPServer):
    """Threaded TCP server whose listen backlog is raised to 30 pending connections."""

    request_queue_size: int = 30
class MysqlProxy(SocketServer.BaseRequestHandler):
    """
    Request handler for one MySQL client connection.

    ``handle`` performs the (cloud or regular) handshake and then loops over
    incoming command packets until the client quits or the socket fails.
    """

    @staticmethod
    def server_close(srv):
        """Close ``srv``; registered with ``atexit`` by ``startProxy``."""
        srv.server_close()

    def __init__(self, request, client_address, server):
        # Per-connection defaults; BaseRequestHandler.__init__ calls handle() immediately,
        # so these must be set before super().__init__.
        self.charset = "utf8"
        self.charset_text_type = CHARSET_NUMBERS["utf8_general_ci"]
        self.session = None
        self.client_capabilities = None
        self.connection_id = None
        super().__init__(request, client_address, server)

    def init_session(self):
        """Assign a connection id, create the SQL session and choose the auth salt."""
        logger.debug("New connection [{ip}:{port}]".format(ip=self.client_address[0], port=self.client_address[1]))

        # Connection ids wrap around after 65025.
        if self.server.connection_id >= 65025:
            self.server.connection_id = 0
        self.server.connection_id += 1
        self.connection_id = self.server.connection_id
        self.session = SessionController(api_type="sql")

        # A fixed salt may be provided on the server object; otherwise generate a random one.
        if hasattr(self.server, "salt") and isinstance(self.server.salt, str):
            self.salt = self.server.salt
        else:
            self.salt = base64.b64encode(os.urandom(15)).decode()

        self.socket = self.request
        self.logging = logger

        self.current_transaction = None

        logger.debug("session salt: {salt}".format(salt=self.salt))

    def handshake(self):
        """Run the MySQL connection handshake and authenticate the client.

        Sends the HandshakePacket, reads the client's response (upgrading the
        socket to TLS when the client sends an SSLRequest), negotiates the
        auth plugin and validates credentials via ``self.server.check_auth``.

        :return: True if the client is authenticated (an OkPacket was sent),
            False otherwise (an Ok or Err packet has already been sent).
        """

        def switch_auth(method="mysql_native_password"):
            # Ask the client to resend its credentials using `method`.
            self.packet(SwitchOutPacket, seed=self.salt, method=method).send()
            switch_out_answer = self.packet(SwitchOutResponse)
            switch_out_answer.get()
            password = switch_out_answer.password
            if method == "mysql_native_password" and len(password) == 0:
                # Empty answer: substitute the scramble of an empty password
                # (presumably the form check_auth compares against).
                password = scramble_func("", self.salt)
            return password

        def get_fast_auth_password():
            # Request the clear-text password (caching_sha2 "fast auth failed" path).
            # Returns None after sending an ErrPacket if no password was received.
            logger.debug("Asking for fast auth password")
            self.packet(FastAuthFail).send()
            password_answer = self.packet(PasswordAnswer)
            password_answer.get()
            try:
                password = password_answer.password.value.decode()
            except Exception:
                logger.warning("error: no password in Fast Auth answer")
                self.packet(
                    ErrPacket,
                    err_code=ERR.ER_PASSWORD_NO_MATCH,
                    msg="Is not password in connection query.",
                ).send()
                return None
            return password

        username = None
        password = None

        logger.debug("send HandshakePacket")
        self.packet(HandshakePacket).send()

        handshake_resp = self.packet(HandshakeResponsePacket)
        handshake_resp.get()
        if handshake_resp.length == 0:
            logger.debug("HandshakeResponsePacket empty")
            self.packet(OkPacket).send()
            return False
        self.client_capabilities = ClentCapabilities(handshake_resp.capabilities.value)

        client_auth_plugin = handshake_resp.client_auth_plugin.value.decode()

        self.session.is_ssl = False

        if handshake_resp.type == "SSLRequest":
            logger.debug("switch to SSL")
            self.session.is_ssl = True

            # Server-side TLS context; a bare SSLContext() without a protocol is deprecated.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
            ssl_context.load_cert_chain(self.server.cert_path)
            ssl_socket = ssl_context.wrap_socket(self.socket, server_side=True, do_handshake_on_connect=True)

            # After the TLS upgrade the client resends its handshake response over the encrypted socket.
            self.socket = ssl_socket
            handshake_resp = self.packet(HandshakeResponsePacket)
            handshake_resp.get()
            client_auth_plugin = handshake_resp.client_auth_plugin.value.decode()

        username = handshake_resp.username.value.decode()

        if client_auth_plugin != DEFAULT_AUTH_METHOD:
            if client_auth_plugin == "mysql_native_password":
                password = switch_auth("mysql_native_password")
            else:
                new_method = (
                    "caching_sha2_password"
                    if client_auth_plugin == "caching_sha2_password"
                    else "mysql_native_password"
                )

                if new_method == "caching_sha2_password" and self.session.is_ssl is False:
                    logger.warning(
                        f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                        "error: cant switch to caching_sha2_password without SSL"
                    )
                    self.packet(
                        ErrPacket,
                        err_code=ERR.ER_PASSWORD_NO_MATCH,
                        msg="caching_sha2_password without SSL not supported",
                    ).send()
                    return False

                logger.debug(
                    f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                    f"switch auth method to {new_method}"
                )
                password = switch_auth(new_method)

                if new_method == "caching_sha2_password":
                    if password == b"\x00":
                        password = ""
                    else:
                        password = get_fast_auth_password()
                        if password is None:
                            # An ErrPacket was already sent; do not send a second reply.
                            return False
        elif "caching_sha2_password" in client_auth_plugin:
            logger.debug(
                f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                "check auth using caching_sha2_password"
            )
            password = handshake_resp.enc_password.value
            if password == b"\x00":
                password = ""
            else:
                # FIXME https://github.com/mindsdb/mindsdb/issues/1374
                # Over SSL this should use get_fast_auth_password(); for now always
                # switch the client to mysql_native_password.
                password = switch_auth()
        elif "mysql_native_password" in client_auth_plugin:
            logger.debug(
                f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                "check auth using mysql_native_password"
            )
            password = handshake_resp.enc_password.value
        else:
            logger.debug(
                f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
                "unknown method, possible ERROR. Try to switch to mysql_native_password"
            )
            password = switch_auth("mysql_native_password")

        try:
            self.session.database = handshake_resp.database.value.decode()
        except Exception:
            self.session.database = None
        logger.debug(
            f"Check auth, user={username}, ssl={self.session.is_ssl}, auth_method={client_auth_plugin}: "
            f"connecting to database {self.session.database}"
        )

        auth_data = self.server.check_auth(username, password, scramble_func, self.salt, ctx.company_id)
        if auth_data["success"]:
            self.session.username = auth_data["username"]
            self.session.auth = True
            self.packet(OkPacket).send()
            return True
        else:
            self.packet(
                ErrPacket,
                err_code=ERR.ER_PASSWORD_NO_MATCH,
                msg=f"Access denied for user {username}",
            ).send()
            logger.warning(f"Access denied for user {username}")
            return False

    def send_package_group(self, packages):
        """Serialize ``packages`` in order and send them with a single ``sendall``."""
        self.socket.sendall(b"".join(x.accum() for x in packages))

    def answer_stmt_close(self, stmt_id):
        """Handle COM_STMT_CLOSE: drop the prepared statement (no reply is sent)."""
        self.session.unregister_stmt(stmt_id)

    def send_query_answer(self, answer: SQLAnswer):
        """Send ``answer`` to the client using the packet sequence for its type.

        Tables with 1000+ rows are streamed in chunks via ``send_table_packets``;
        smaller tables are sent as one packet group. OK, ERROR and EOF
        answers map to a single packet. Other types send nothing.
        """
        if answer.type in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE):
            packages = []

            if len(answer.result_set) >= 1000:
                # for big responses leverage pandas map function to convert data to packages
                self.send_table_packets(result_set=answer.result_set)
            else:
                packages += self.get_table_packets(result_set=answer.result_set)

            if answer.status is not None:
                packages.append(self.last_packet(status=answer.status))
            else:
                packages.append(self.last_packet())
            self.send_package_group(packages)
        elif answer.type == RESPONSE_TYPE.OK:
            self.packet(OkPacket, state_track=answer.state_track, affected_rows=answer.affected_rows).send()
        elif answer.type == RESPONSE_TYPE.ERROR:
            self.packet(ErrPacket, err_code=answer.error_code, msg=answer.error_message).send()
        elif answer.type == RESPONSE_TYPE.EOF:
            self.packet(EofPacket).send()

    def _get_column_defenition_packets(self, columns: list[dict], data=None):
        """Build one ColumnDefenitionPacket per column dict.

        :param columns: column description dicts. A missing ``"size"`` is
            filled in (in place) with the longest ``str()`` width found in
            ``data`` for that column, minimum 1.
        :param data: rows, as lists (indexed by position) or dicts (indexed by alias).
        :return: list of packets, in column order.
        """
        if data is None:
            data = []
        packets = []
        for i, column in enumerate(columns):
            logger.debug(
                "%s._get_column_defenition_packets: handling column - %s of %s type",
                self.__class__.__name__,
                column,
                type(column),
            )
            table_name = column.get("table_name", "table_name")
            column_name = column.get("name", "column_name")
            column_alias = column.get("alias", column_name)
            flags = column.get("flags", 0)
            if isinstance(flags, list):
                flags = sum(flags)
            if column.get("size") is None:
                length = 1
                for row in data:
                    if isinstance(row, dict):
                        length = max(len(str(row[column_alias])), length)
                    else:
                        length = max(len(str(row[i])), length)
                # Use the inferred width; it was previously computed and then discarded.
                column["size"] = length

            packets.append(
                self.packet(
                    ColumnDefenitionPacket,
                    schema=column.get("database", "mindsdb_schema"),
                    table_alias=column.get("table_alias", table_name),
                    table_name=table_name,
                    column_alias=column_alias,
                    column_name=column_name,
                    column_type=column["type"],
                    charset=column.get("charset", CHARSET_NUMBERS["utf8_unicode_ci"]),
                    max_length=column["size"],
                    flags=flags,
                )
            )
        return packets

    def get_table_packets(self, result_set: ResultSet, status=0):
        """Build all packets for a text-protocol result set, except the final OK/EOF packet."""
        data_frame, columns_dict = dump_result_set_to_mysql(result_set)
        data = data_frame.to_dict("split")["data"]

        # TODO remove columns order
        packets = [self.packet(ColumnCountPacket, count=len(columns_dict))]
        packets.extend(self._get_column_defenition_packets(columns_dict, data))

        if self.client_capabilities.DEPRECATE_EOF is False:
            packets.append(self.packet(EofPacket, status=status))

        packets += [self.packet(ResultsetRowPacket, data=x) for x in data]
        return packets

    def send_table_packets(self, result_set: ResultSet, status: int = 0):
        """Send table packets to client, piece by piece

        The header (column count, definitions, optional EOF) is sent first,
        then rows in chunks of 1000. The final OK/EOF packet is not sent here.

        Args:
            result_set (ResultSet): the result set to send
            status (int): the status to send

        Returns:
            None
        """
        columns_dicts = dump_columns_info(result_set, infer_column_size=True)

        packets = [self.packet(ColumnCountPacket, count=len(columns_dicts))]
        packets.extend(self._get_column_defenition_packets(columns_dicts))

        if self.client_capabilities.DEPRECATE_EOF is False:
            packets.append(self.packet(EofPacket, status=status))
        self.send_package_group(packets)

        chunk_size = 1000
        df = result_set.get_raw_df()
        if len(df) > 0:
            for chunk in dump_chunks(df, columns_dicts, chunk_size):
                # Each item is a pre-serialized row body; frame it as a packet. Packets are
                # created in row order so sequence numbers stay consecutive.
                payload = b"".join(self.packet(body=row, length=len(row)).accum() for row in chunk)
                self.socket.sendall(payload)

    def decode_utf(self, text):
        """Decode ``text`` bytes as UTF-8.

        :raises WrongCharsetError: if decoding fails.
        """
        try:
            return text.decode("utf-8")
        except Exception as e:
            raise WrongCharsetError(f"SQL contains non utf-8 values: {text}") from e

    def _recv_exact(self, size: int) -> bytes:
        """Read exactly ``size`` bytes from the raw request socket.

        ``socket.recv`` may return fewer bytes than requested, so keep reading
        until the full amount has arrived.

        :raises ConnectionError: if the peer closes the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.request.recv(remaining)
            if not chunk:
                raise ConnectionError(f"Connection closed after {size - remaining} of {size} bytes")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def is_cloud_connection(self):
        """Determine the source of the connection. Must be called before the handshake.

        A real MySQL client sends nothing before the server handshake, so the
        socket has no pending input. A cloud connection instead sends four zero
        bytes right after connecting, followed by a header. Four zero bytes were
        chosen because in the MySQL protocol they would be a zero packet
        length, which is impossible.

        Only checked on Linux when ``config["cloud"]`` is enabled.

        :return: ``{"is_cloud": False}``, or a dict with ``is_cloud=True`` and the
            decoded ``client_capabilities``, ``company_id``, ``user_class``,
            ``database`` and ``email_confirmed`` fields.
        """
        is_cloud = config.get("cloud", False)

        if sys.platform != "linux" or is_cloud is False:
            return {"is_cloud": False}

        read_poller = select.poll()
        read_poller.register(self.request, select.POLLIN)
        events = read_poller.poll(30)  # milliseconds

        if len(events) == 0:
            return {"is_cloud": False}

        first_bytes = self.request.recv(4, socket.MSG_PEEK)
        if first_bytes != b"\x00\x00\x00\x00":
            return {"is_cloud": False}

        self._recv_exact(4)  # consume the zero marker
        # 8-byte native-order capability mask. "Q" is 8 bytes on every platform;
        # the former "L" is only 8 bytes on LP64 systems.
        (client_capabilities,) = struct.unpack("Q", self._recv_exact(8))
        (company_id,) = struct.unpack("I", self._recv_exact(4))

        (user_class,) = struct.unpack("B", self._recv_exact(1))
        email_confirmed = 1
        if user_class > 1:
            # Bit 2 carries the email-confirmed flag; the low two bits are the user class.
            email_confirmed = (user_class >> 2) & 1
            user_class = user_class & 3

        (database_name_len,) = struct.unpack("H", self._recv_exact(2))
        database_name = ""
        if database_name_len > 0:
            database_name = self._recv_exact(database_name_len).decode()

        return {
            "is_cloud": True,
            "client_capabilities": client_capabilities,
            "company_id": company_id,
            "user_class": user_class,
            "database": database_name,
            "email_confirmed": email_confirmed,
        }

    def to_mysql_columns(self, columns_list: list[Column]) -> list[dict[str, str | int]]:
        """Convert ``Column`` objects into MySQL column dicts for the current database."""
        # session.database may be None (e.g. no database sent in the handshake) or "";
        # both mean "no database".
        database_name = self.session.database.lower() if self.session.database else None
        return [column_to_mysql_column_dict(column, database_name=database_name) for column in columns_list]

    @profiler.profile()
    def process_query(self, sql: str) -> SQLAnswer:
        """Execute ``sql`` and wrap the executor's result in an ``SQLAnswer``.

        Returns an OK answer when the executor produced no data, otherwise a
        TABLE answer.
        """
        log.log_ram_info(logger)
        executor = Executor(session=self.session, sqlserver=self)
        executor.query_execute(sql)
        executor_answer = executor.executor_answer

        if executor_answer.data is None:
            resp = SQLAnswer(
                resp_type=RESPONSE_TYPE.OK,
                state_track=executor_answer.state_track,
                affected_rows=executor_answer.affected_rows,
            )
        else:
            resp = SQLAnswer(
                resp_type=RESPONSE_TYPE.TABLE,
                state_track=executor_answer.state_track,
                result_set=executor_answer.data,
                status=executor.server_status,
                affected_rows=executor_answer.affected_rows,
                mysql_types=executor_answer.data.mysql_types,
            )

        # Increment the counter and include metadata in attributes
        increment_otel_query_request_counter(ctx.get_metadata(query=sql))

        return resp

    def answer_stmt_prepare(self, sql):
        """Handle COM_STMT_PREPARE.

        Registers a new executor as a prepared statement and sends the prepare
        header, followed by the parameter and column definitions.
        """
        executor = Executor(session=self.session, sqlserver=self)
        stmt_id = self.session.register_stmt(executor)

        executor.stmt_prepare(sql)

        packages = [
            self.packet(
                STMTPrepareHeaderPacket,
                stmt_id=stmt_id,
                num_columns=len(executor.columns),
                num_params=len(executor.params),
            )
        ]

        if len(executor.params) > 0:
            parameters_def = self.to_mysql_columns(executor.params)
            packages.extend(self._get_column_defenition_packets(parameters_def))
            if self.client_capabilities.DEPRECATE_EOF is False:
                status = sum([SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT])
                packages.append(self.packet(EofPacket, status=status))

        if len(executor.columns) > 0:
            columns_def = self.to_mysql_columns(executor.columns)
            packages.extend(self._get_column_defenition_packets(columns_def))

            if self.client_capabilities.DEPRECATE_EOF is False:
                status = sum([SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT])
                packages.append(self.packet(EofPacket, status=status))

        self.send_package_group(packages)

    def answer_stmt_execute(self, stmt_id, parameters):
        """Handle COM_STMT_EXECUTE.

        Runs the prepared statement and sends every row in the binary protocol.
        """
        prepared_stmt = self.session.prepared_stmts[stmt_id]
        executor: Executor = prepared_stmt["statement"]

        executor.stmt_execute(parameters)

        executor_answer: ExecuteAnswer = executor.executor_answer

        if executor_answer.data is None:
            resp = SQLAnswer(resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track)
            return self.send_query_answer(resp)

        # TODO prepared_stmt['type'] == 'lock' is not used but it works
        result_set = executor_answer.data
        data_frame, columns_dict = dump_result_set_to_mysql(result_set)
        data = data_frame.to_dict("split")["data"]

        packages = [self.packet(ColumnCountPacket, count=len(columns_dict))]
        packages.extend(self._get_column_defenition_packets(columns_dict))

        if self.client_capabilities.DEPRECATE_EOF is False:
            # NOTE(review): 0x0062 presumably = CURSOR_EXISTS | QUERY_NO_INDEX_USED | AUTOCOMMIT — confirm
            packages.append(self.packet(EofPacket, status=0x0062))

        # send all
        for row in data:
            packages.append(self.packet(BinaryResultsetRowPacket, data=row, columns=columns_dict))

        server_status = executor.server_status or 0x0002
        packages.append(self.last_packet(status=server_status))
        prepared_stmt["fetched"] += len(data)

        return self.send_package_group(packages)

    def answer_stmt_fetch(self, stmt_id, limit):
        """Handle COM_STMT_FETCH: send up to ``limit`` rows after those already fetched."""
        prepared_stmt = self.session.prepared_stmts[stmt_id]
        executor = prepared_stmt["statement"]
        fetched = prepared_stmt["fetched"]
        executor_answer: ExecuteAnswer = executor.executor_answer

        if executor_answer.data is None:
            resp = SQLAnswer(resp_type=RESPONSE_TYPE.OK, state_track=executor_answer.state_track)
            return self.send_query_answer(resp)

        result_set = executor_answer.data
        # `limit` is a row count, so the window ends at fetched + limit
        # (previously sliced as [fetched:limit]).
        end = fetched + limit
        page = result_set[fetched:end]

        columns = self.to_mysql_columns(result_set.columns)
        packages = [self.packet(BinaryResultsetRowPacket, data=row, columns=columns) for row in page.to_lists()]

        prepared_stmt["fetched"] += len(page)

        if len(result_set) <= end:
            status = sum(
                [
                    SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
                    SERVER_STATUS.SERVER_STATUS_LAST_ROW_SENT,
                ]
            )
        else:
            status = sum(
                [
                    SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT,
                    SERVER_STATUS.SERVER_STATUS_CURSOR_EXISTS,
                ]
            )

        packages.append(self.last_packet(status=status))
        self.send_package_group(packages)

    def handle(self):
        """
        Handle new incoming connections

        Authenticates the client (cloud header or MySQL handshake), then
        processes command packets until the client sends COM_QUIT or the
        socket fails. Every processed command is reported to
        ``hooks.after_api_query``.
        :return:
        """
        ctx.set_default()

        self.server.hook_before_handle()

        logger.debug("handle new incoming connection")
        cloud_connection = self.is_cloud_connection()

        ctx.company_id = cloud_connection.get("company_id")

        self.init_session()
        if cloud_connection["is_cloud"] is False:
            if self.handshake() is False:
                return
        else:
            ctx.user_class = cloud_connection["user_class"]
            ctx.email_confirmed = cloud_connection["email_confirmed"]
            self.client_capabilities = ClentCapabilities(cloud_connection["client_capabilities"])
            self.session.database = cloud_connection["database"]
            self.session.username = "cloud"
            self.session.auth = True

        while True:
            logger.debug("Got a new packet")
            p = self.packet(CommandPacket)

            try:
                success = p.get()
            except Exception:
                logger.exception("Session closed, on packet read error:")
                return

            if success is False:
                logger.debug("Session closed by client")
                return

            logger.debug("Command TYPE: {type}".format(type=getConstName(COMMANDS, p.type.value)))

            sql = None
            response = None
            error_type = None
            error_code = None
            error_text = None
            error_traceback = None

            try:
                if p.type.value == COMMANDS.COM_QUERY:
                    sql = self.decode_utf(p.sql.value)
                    sql = clear_sql(sql)
                    logger.debug(f"Incoming query: {sql}")
                    profiler.set_meta(query=sql, api="mysql", environment=config.get("environment"))
                    with profiler.Context("mysql_query_processing"), mark_process("mysql_query"):
                        response = self.process_query(sql)
                elif p.type.value == COMMANDS.COM_STMT_PREPARE:
                    sql = self.decode_utf(p.sql.value)
                    self.answer_stmt_prepare(sql)
                elif p.type.value == COMMANDS.COM_STMT_EXECUTE:
                    self.answer_stmt_execute(p.stmt_id.value, p.parameters)
                elif p.type.value == COMMANDS.COM_STMT_FETCH:
                    self.answer_stmt_fetch(p.stmt_id.value, p.limit.value)
                elif p.type.value == COMMANDS.COM_STMT_CLOSE:
                    self.answer_stmt_close(p.stmt_id.value)
                elif p.type.value == COMMANDS.COM_QUIT:
                    logger.debug("Session closed, on client disconnect")
                    self.session = None
                    break
                elif p.type.value == COMMANDS.COM_INIT_DB:
                    new_database = p.database.value.decode()

                    executor = Executor(session=self.session, sqlserver=self)
                    executor.change_default_db(new_database)

                    response = SQLAnswer(RESPONSE_TYPE.OK)
                elif p.type.value == COMMANDS.COM_FIELD_LIST:
                    # this command is deprecated, but console client still use it.
                    response = SQLAnswer(RESPONSE_TYPE.OK)
                elif p.type.value == COMMANDS.COM_STMT_RESET:
                    response = SQLAnswer(RESPONSE_TYPE.OK)
                elif p.type.value == COMMANDS.COM_PING:
                    response = SQLAnswer(RESPONSE_TYPE.OK)
                elif p.type.value == COMMANDS.COM_CHANGE_USER:
                    # This package should trigger re-authentication. For now it is forbidden.
                    logger.warning("Got COM_CHANGE_USER packet that could not be processed, return error.")
                    response = SQLAnswer(
                        resp_type=RESPONSE_TYPE.ERROR,
                        error_code=None,
                        error_message="Packet COM_CHANGE_USER could not be processed",
                    )
                elif p.type.value == COMMANDS.COM_DEBUG:
                    response = SQLAnswer(resp_type=RESPONSE_TYPE.EOF)
                elif p.type.value == COMMANDS.COM_SET_OPTION:
                    # Regular MySQL options have no effect on mindsdb, so Ok is a safe reply.
                    logger.warning("Unexpected packet COM_SET_OPTION recieved, return ok.")
                    response = SQLAnswer(RESPONSE_TYPE.OK)
                elif p.type.value == COMMANDS.COM_SLEEP:
                    # error - is the only valid answer for the packet
                    response = SQLAnswer(
                        resp_type=RESPONSE_TYPE.ERROR,
                        error_code=None,
                        error_message="",
                    )
                elif p.type.value == COMMANDS.COM_PROCESS_KILL:
                    logger.warning("Unexpected packet COM_PROCESS_KILL recieved, return error.")
                    response = SQLAnswer(
                        resp_type=RESPONSE_TYPE.ERROR,
                        error_code=None,
                        error_message="Packet COM_PROCESS_KILL could not be processed",
                    )
                elif p.type.value == COMMANDS.COM_RESET_CONNECTION:
                    logger.warning("Unexpected packet COM_RESET_CONNECTION recieved, return error.")
                    response = SQLAnswer(
                        resp_type=RESPONSE_TYPE.ERROR,
                        error_code=None,
                        error_message="Packet COM_RESET_CONNECTION could not be processed",
                    )
                elif p.type.value == COMMANDS.COM_SHUTDOWN:
                    logger.warning("Unexpected packet COM_SHUTDOWN recieved, return error.")
                    response = SQLAnswer(
                        resp_type=RESPONSE_TYPE.ERROR,
                        error_code=None,
                        error_message="Packet COM_SHUTDOWN could not be processed",
                    )
                else:
                    logger.warning("Command has no specific handler, return OK msg")
                    logger.debug(str(p))
                    response = SQLAnswer(RESPONSE_TYPE.OK)

            except (QueryError, executor_exceptions.ExecutorException, executor_exceptions.UnknownError) as e:
                error_type = "expected" if e.is_expected else "unexpected"
                error_code = e.mysql_error_code
                if e.is_expected:
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.info("Query execution failed with expected error:", exc_info=True)
                    else:
                        logger.info(f"Query execution failed with expected error: {e}")
                else:
                    logger.exception("Query execution failed with error")
                response = SQLAnswer(
                    resp_type=RESPONSE_TYPE.ERROR,
                    error_code=error_code,
                    error_message=str(e),
                )

            except Exception as e:
                error_type = "unexpected"
                error_traceback = traceback.format_exc()
                logger.exception("ERROR while executing query:")
                error_code = ERR.ER_SYNTAX_ERROR
                response = SQLAnswer(
                    resp_type=RESPONSE_TYPE.ERROR,
                    error_code=error_code,
                    error_message=str(e),
                )

            # The COM_STMT_* handlers send their own replies and leave `response` as None.
            if response is not None:
                self.send_query_answer(response)
                if response.type == RESPONSE_TYPE.ERROR:
                    error_text = response.error_message
                    error_code = response.error_code
                    error_type = error_type or "expected"

            hooks.after_api_query(
                company_id=ctx.company_id,
                api="mysql",
                command=getConstName(COMMANDS, p.type.value),
                payload=sql,
                error_type=error_type,
                error_code=error_code,
                error_text=error_text,
                traceback=error_traceback,
            )

    def packet(self, packetClass=Packet, **kwargs):
        """
        Factory method for packets

        Creates a packet bound to the current socket and session, and advances
        the session's packet sequence number.

        :param packetClass: packet class to instantiate
        :param kwargs: extra constructor arguments for the packet
        :return: the new packet instance
        """
        p = packetClass(socket=self.socket, session=self.session, proxy=self, **kwargs)
        self.session.inc_packet_sequence_number()
        return p

    def last_packet(self, status=0x0002):
        """Return the result-set terminator: an EOF-flagged OkPacket if the client set DEPRECATE_EOF, else an EofPacket."""
        if self.client_capabilities.DEPRECATE_EOF is True:
            return self.packet(OkPacket, eof=True, status=status)
        else:
            return self.packet(EofPacket, status=status)

    def set_context(self, context):
        """Restore session settings from a dict produced by ``get_context``.

        A missing ``"db"`` key falls back to the configured default project.
        """
        if "db" in context:
            self.session.database = context["db"]
        else:
            self.session.database = config.get("default_project")

        if "profiling" in context:
            self.session.profiling = context["profiling"]
        if "predictor_cache" in context:
            self.session.predictor_cache = context["predictor_cache"]
        if "show_secrets" in context:
            self.session.show_secrets = context["show_secrets"]

    def get_context(self):
        """Export session settings that differ from defaults as a dict for ``set_context``."""
        context = {"show_secrets": self.session.show_secrets}
        if self.session.database is not None:
            context["db"] = self.session.database
        if self.session.profiling is True:
            context["profiling"] = True
        if self.session.predictor_cache is False:
            context["predictor_cache"] = False

        return context

    @staticmethod
    def startProxy():
        """
        Create a server and wait for incoming connections until Ctrl-C

        If no ``certificate_path`` is configured, a temporary self-signed
        certificate is generated and removed at exit.
        """
        cert_path = config["api"]["mysql"].get("certificate_path")
        if cert_path is None or cert_path == "":
            # mkstemp returns an open OS-level descriptor; close it so it is not leaked
            # (make_ssl_cert is given only the path).
            fd, cert_path = tempfile.mkstemp(prefix="mindsdb_cert_", text=True)
            os.close(fd)
            make_ssl_cert(cert_path)
            atexit.register(lambda: os.remove(cert_path))
        elif not os.path.exists(cert_path):
            logger.error("Certificate defined in 'certificate_path' setting does not exist")

        # TODO make it session local
        server_capabilities.set(CAPABILITIES.CLIENT_SSL, config["api"]["mysql"]["ssl"])

        host = config["api"]["mysql"]["host"]
        port = int(config["api"]["mysql"]["port"])

        logger.info(f"Starting MindsDB Mysql proxy server on tcp://{host}:{port}")

        SocketServer.TCPServer.allow_reuse_address = True
        server = MysqlTCPServer((host, port), MysqlProxy)
        server.mindsdb_config = config
        server.check_auth = partial(check_auth, config=config)
        server.cert_path = cert_path
        server.connection_id = 0
        server.hook_before_handle = empty_fn

        atexit.register(MysqlProxy.server_close, srv=server)

        # Activate the server; this will keep running until you
        # interrupt the program with Ctrl-C
        logger.info("Waiting for incoming connections...")
        server.serve_forever()