Coverage for mindsdb / integrations / handlers / byom_handler / byom_handler.py: 13%
393 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1"""BYOM: Bring Your Own Model
env vars to control BYOM:
4 - MINDSDB_BYOM_ENABLED - can BYOM be used or not. Locally enabled by default.
5 - MINDSDB_BYOM_INHOUSE_ENABLED - enable or disable 'inhouse' BYOM usage. Locally enabled by default.
6 - MINDSDB_BYOM_DEFAULT_TYPE - [inhouse|venv] default byom type. Locally it is 'venv' by default.
7"""
9import os
10import re
11import sys
12import shutil
13import pickle
14import tarfile
15import tempfile
16import subprocess
17from enum import Enum
18from pathlib import Path
19from datetime import datetime
20from typing import Optional, Dict, Union
22import pandas as pd
23from pandas.api import types as pd_types
25from mindsdb.utilities import log
26from mindsdb.utilities.config import Config
27from mindsdb.utilities.fs import safe_extract
28from mindsdb.interfaces.storage import db
29from mindsdb.integrations.libs.base import BaseMLEngine
30from mindsdb.integrations.libs.const import PREDICTOR_STATUS
31from mindsdb.integrations.utilities.utils import format_exception_error
32import mindsdb.utilities.profiler as profiler
35from .proc_wrapper import (
36 pd_decode,
37 pd_encode,
38 encode,
39 decode,
40 BYOM_METHOD,
41 import_string,
42 find_model_class,
43 check_module,
44)
45from .__about__ import __version__
# Execution modes for user code: INHOUSE runs in the current process,
# VENV runs the code in an isolated virtual environment subprocess.
BYOM_TYPE = Enum("BYOM_TYPE", ["INHOUSE", "VENV"])

logger = log.getLogger(__name__)
class BYOMHandler(BaseMLEngine):
    """ML engine for user-provided ("bring your own") model code.

    Dispatches train/predict/finetune/describe calls to a model wrapper that
    executes the user code either in-process ('inhouse') or inside a dedicated
    virtual environment ('venv'). Engine code and requirements are versioned;
    each version gets its own cached wrapper instance.
    """

    name = "byom"

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        # region check availability
        is_cloud = Config().get("cloud", False)
        if is_cloud is True:
            # on cloud, BYOM must be explicitly enabled via env var
            byom_enabled = os.environ.get("MINDSDB_BYOM_ENABLED", "false").lower()
            if byom_enabled not in ("true", "1"):
                raise RuntimeError("BYOM is disabled on cloud")
        # endregion

        self.model_wrapper = None

        # single shared wrapper for all 'inhouse' versions, plus a
        # per-version cache of wrappers (version str -> wrapper)
        self.inhouse_model_wrapper = None
        self.model_wrappers = {}

        # region read and save set default byom type
        try:
            self._default_byom_type = BYOM_TYPE.VENV
            if os.environ.get("MINDSDB_BYOM_DEFAULT_TYPE") is not None:
                self._default_byom_type = BYOM_TYPE[os.environ.get("MINDSDB_BYOM_DEFAULT_TYPE").upper()]
        except KeyError:
            # unknown value in env var: fall back to VENV
            logger.warning(f"Wrong value of env var MINDSDB_BYOM_DEFAULT_TYPE, {BYOM_TYPE.VENV} will be used")
            self._default_byom_type = BYOM_TYPE.VENV
        # endregion

        # region check if 'inhouse' BYOM is enabled
        env_var = os.environ.get("MINDSDB_BYOM_INHOUSE_ENABLED")
        if env_var is None:
            # default: enabled locally, disabled on cloud
            self._inhouse_enabled = False if is_cloud else True
        else:
            self._inhouse_enabled = env_var.lower() in ("true", "1")
        # endregion

        super().__init__(model_storage, engine_storage, **kwargs)

    @staticmethod
    def normalize_engine_version(engine_version: Union[int, str, None]) -> int:
        """Cast engine version to int, or return `1` if can not be casted

        Args:
            engine_version (Union[int, str, None]): engine version

        Returns:
            int: engine version
        """
        if isinstance(engine_version, str):
            try:
                engine_version = int(engine_version)
            except Exception:
                engine_version = 1
        if isinstance(engine_version, int) is False:
            engine_version = 1
        return engine_version

    @staticmethod
    def create_validation(target: str, args: dict = None, **kwargs) -> None:
        """Resolve and pin the engine version in `args['using']` before model creation.

        If no version is given explicitly, the newest version registered in the
        engine's connection args is used, defaulting to 1.

        Args:
            target (str): prediction target (not used here)
            args (dict): model creation args; mutated in place
        """
        if isinstance(args, dict) is False:
            return
        using_args = args.get("using", {})
        engine_version = using_args.get("engine_version")
        if engine_version is not None:
            engine_version = BYOMHandler.normalize_engine_version(engine_version)
        else:
            connection_args = kwargs["handler_storage"].get_connection_args()
            versions = connection_args.get("versions")
            if isinstance(versions, dict):
                # pick the newest registered engine version
                engine_version = max([int(x) for x in versions.keys()])
            else:
                engine_version = 1
        using_args["engine_version"] = engine_version

    def get_model_engine_version(self) -> int:
        """Return current model engine version

        Returns:
            int: engine version
        """
        engine_version = self.model_storage.get_info()["learn_args"].get("using", {}).get("engine_version")
        engine_version = BYOMHandler.normalize_engine_version(engine_version)
        return engine_version

    def normalize_byom_type(self, byom_type: Optional[str]) -> BYOM_TYPE:
        """Map a type string to BYOM_TYPE, falling back to the configured default.

        Raises:
            KeyError: if the string is not a known BYOM_TYPE name
            Exception: if 'inhouse' is requested but disabled
        """
        if byom_type is not None:
            byom_type = BYOM_TYPE[byom_type.upper()]
        else:
            byom_type = self._default_byom_type
        if byom_type == BYOM_TYPE.INHOUSE and self._inhouse_enabled is False:
            raise Exception("'Inhouse' BYOM engine type can not be used")
        return byom_type

    def _get_model_proxy(self, version=None):
        """Return the model wrapper for an engine version, creating and caching it.

        Pulls the version's code and requirements from engine file storage and
        instantiates either ModelWrapperUnsafe (in-process) or ModelWrapperSafe
        (venv subprocess) based on the version's metadata.
        """
        if version is None:
            version = 1
        if isinstance(version, str):
            version = int(version)
        # version 1 files have no suffix; later versions are suffixed '_<version>'
        version_mark = ""
        if version > 1:
            version_mark = f"_{version}"
        version_str = str(version)

        self.engine_storage.fileStorage.pull()
        try:
            code = self.engine_storage.fileStorage.file_get(f"code{version_mark}")
            modules_str = self.engine_storage.fileStorage.file_get(f"modules{version_mark}")
        except FileNotFoundError:
            raise Exception(f"Engine version '{version}' does not exists")

        if version_str not in self.model_wrappers:
            connection_args = self.engine_storage.get_connection_args()
            version_meta = connection_args["versions"][version_str]

            try:
                engine_version_type = BYOM_TYPE[version_meta.get("type", self._default_byom_type.name).upper()]
            except KeyError:
                raise Exception("Unknown BYOM engine type")

            if engine_version_type == BYOM_TYPE.INHOUSE:
                if self._inhouse_enabled is False:
                    raise Exception("'Inhouse' BYOM engine type can not be used")
                if self.inhouse_model_wrapper is None:
                    self.inhouse_model_wrapper = ModelWrapperUnsafe(
                        code=code,
                        modules_str=modules_str,
                        engine_id=self.engine_storage.integration_id,
                        engine_version=version,
                    )
                self.model_wrappers[version_str] = self.inhouse_model_wrapper
            elif engine_version_type == BYOM_TYPE.VENV:
                # record venv creation status in connection args so the
                # in-progress state is visible outside this process
                if version_meta.get("venv_status") != "ready":
                    version_meta["venv_status"] = "creating"
                    self.engine_storage.update_connection_args(connection_args)
                self.model_wrappers[version_str] = ModelWrapperSafe(
                    code=code,
                    modules_str=modules_str,
                    engine_id=self.engine_storage.integration_id,
                    engine_version=version,
                )
                version_meta["venv_status"] = "ready"
                self.engine_storage.update_connection_args(connection_args)

        return self.model_wrappers[version_str]

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Return the model description produced by the user code."""
        engine_version = self.get_model_engine_version()
        mp = self._get_model_proxy(engine_version)
        model_state = self.model_storage.file_get("model")
        return mp.describe(model_state, attribute)

    def create(self, target, df=None, args=None, **kwargs):
        """Train a model with the user code and store its serialized state."""
        using_args = args.get("using", {})
        engine_version = using_args.get("engine_version")

        model_proxy = self._get_model_proxy(engine_version)
        model_state = model_proxy.train(df, target, args)

        self.model_storage.file_set("model", model_state)

        # TODO return columns?

        def convert_type(field_type):
            # map a pandas dtype to a mindsdb column type name
            if pd_types.is_integer_dtype(field_type):
                return "integer"
            elif pd_types.is_numeric_dtype(field_type):
                return "float"
            elif pd_types.is_datetime64_any_dtype(field_type):
                return "datetime"
            else:
                return "categorical"

        # NOTE(review): convert_type(object) always evaluates to 'categorical';
        # presumably the target column's dtype was intended here — confirm.
        columns = {target: convert_type(object)}

        self.model_storage.columns_set(columns)

    def predict(self, df, args=None):
        """Predict with the stored model; engine version may be overridden via predict params."""
        pred_args = args.get("predict_params", {})

        engine_version = pred_args.get("engine_version")
        if engine_version is not None:
            engine_version = int(engine_version)
        else:
            engine_version = self.get_model_engine_version()

        model_proxy = self._get_model_proxy(engine_version)
        model_state = self.model_storage.file_get("model")
        pred_df = model_proxy.predict(df, model_state, pred_args)

        return pred_df

    def create_engine(self, connection_args):
        """Create engine version 1 from user code and requirements files.

        Args:
            connection_args (dict): expects 'code' and 'modules' file paths,
                plus optional 'type' ('inhouse'|'venv') and 'mode'
        """
        code_path = Path(connection_args["code"])
        self.engine_storage.fileStorage.file_set("code", code_path.read_bytes())

        requirements_path = Path(connection_args["modules"])
        self.engine_storage.fileStorage.file_set("modules", requirements_path.read_bytes())

        self.engine_storage.fileStorage.push()

        self.engine_storage.update_connection_args(
            {
                "handler_version": __version__,
                "mode": connection_args.get("mode"),
                "versions": {
                    "1": {
                        "code": code_path.name,
                        "requirements": requirements_path.name,
                        "type": self.normalize_byom_type(connection_args.get("type")).name.lower(),
                    }
                },
            }
        )

        model_proxy = self._get_model_proxy()
        try:
            # validate the user code and remember the methods it exposes
            info = model_proxy.check(connection_args.get("mode"))
            self.engine_storage.json_set("methods", info["methods"])

        except Exception as e:
            # clean up a half-created venv on failure
            if hasattr(model_proxy, "remove_venv"):
                model_proxy.remove_venv()
            raise e

    def update_engine(self, connection_args: dict) -> None:
        """Add new version of engine

        Args:
            connection_args (dict): paths to code and requirements
        """
        code_path = Path(connection_args["code"])
        requirements_path = Path(connection_args["modules"])

        engine_connection_args = self.engine_storage.get_connection_args()
        if isinstance(engine_connection_args, dict) is False or "handler_version" not in engine_connection_args:
            # engine was created by an older handler: synthesize version-1 metadata
            engine_connection_args = {
                "handler_version": __version__,
                "versions": {
                    "1": {
                        "code": "code.py",
                        "requirements": "requirements.txt",
                        "type": self._default_byom_type.name.lower(),
                    }
                },
            }
        new_version = str(max([int(x) for x in engine_connection_args["versions"].keys()]) + 1)

        engine_connection_args["versions"][new_version] = {
            "code": code_path.name,
            "requirements": requirements_path.name,
            "type": self.normalize_byom_type(connection_args.get("type")).name.lower(),
        }

        self.engine_storage.fileStorage.file_set(f"code_{new_version}", code_path.read_bytes())

        self.engine_storage.fileStorage.file_set(f"modules_{new_version}", requirements_path.read_bytes())
        self.engine_storage.fileStorage.push()

        self.engine_storage.update_connection_args(engine_connection_args)

        model_proxy = self._get_model_proxy(new_version)
        try:
            # NOTE(review): create_engine stores check()['methods'] while this
            # stores the whole check() result — confirm which shape is intended.
            methods = model_proxy.check()
            self.engine_storage.json_set("methods", methods)

        except Exception as e:
            if hasattr(model_proxy, "remove_venv"):
                model_proxy.remove_venv()
            raise e

    def function_list(self):
        """Return the stored list of methods exposed by the engine code."""
        return self.engine_storage.json_get("methods")

    def function_call(self, name, args):
        """Call a module-level function of the engine code (version 1 wrapper)."""
        mp = self._get_model_proxy()
        return mp.func_call(name, args)

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Finetune an existing model on new data, tracking status in the db.

        Args:
            df (Optional[pd.DataFrame]): finetune data
            args (Optional[Dict]): expects 'base_model_id' and optionally
                'using.engine_version'
        """
        using_args = args.get("using", {})
        engine_version = using_args.get("engine_version")

        model_storage = self.model_storage
        # TODO: should probably refactor at some point, as a bit of the logic is shared with lightwood's finetune logic
        try:
            base_predictor_id = args["base_model_id"]
            base_predictor_record = db.Predictor.query.get(base_predictor_id)
            if base_predictor_record.status != PREDICTOR_STATUS.COMPLETE:
                raise Exception("Base model must be in status 'complete'")

            predictor_id = model_storage.predictor_id
            predictor_record = db.Predictor.query.get(predictor_id)

            predictor_record.data = {
                "training_log": "training"
            }  # TODO move to ModelStorage (don't work w/ db directly)
            predictor_record.training_start_at = datetime.now()
            predictor_record.status = PREDICTOR_STATUS.FINETUNING  # TODO: parallel execution block
            db.session.commit()

            model_proxy = self._get_model_proxy(engine_version)
            model_state = self.base_model_storage.file_get("model")
            model_state = model_proxy.finetune(df, model_state, args=args.get("using", {}))

            # region hack to speedup file saving
            with profiler.Context("finetune-byom-write-file"):
                dest_abs_path = model_storage.fileStorage.folder_path / "model"
                with open(dest_abs_path, "wb") as fd:
                    fd.write(model_state)
                model_storage.fileStorage.push(compression_level=0)
            # endregion

            predictor_record.update_status = "up_to_date"
            predictor_record.status = PREDICTOR_STATUS.COMPLETE
            predictor_record.training_stop_at = datetime.now()
            db.session.commit()

        except Exception as e:
            logger.error("Unexpected error during BYOM finetune:", exc_info=True)
            predictor_id = model_storage.predictor_id
            predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
            error_message = format_exception_error(e)
            predictor_record.data = {"error": error_message}
            predictor_record.status = PREDICTOR_STATUS.ERROR
            db.session.commit()
            raise

        finally:
            # make sure a stop time is always recorded
            if predictor_record.training_stop_at is None:
                predictor_record.training_stop_at = datetime.now()
            db.session.commit()
class ModelWrapperUnsafe:
    """Model wrapper that executes learn/predict in current process"""

    def __init__(self, code, modules_str, engine_id, engine_version: int):
        """Import the user module and instantiate its model class, if any.

        Args:
            code: source code of the user module
            modules_str: requirements content (unused here; kept for
                interface parity with ModelWrapperSafe)
            engine_id: engine id (unused here; interface parity)
            engine_version (int): engine version (unused here; interface parity)
        """
        self.module = import_string(code)

        model_instance = None
        model_class = find_model_class(self.module)
        if model_class is not None:
            model_instance = model_class()

        self.model_instance = model_instance

    def train(self, df, target, args):
        """Train the model and return its pickled state (the instance __dict__)."""
        self.model_instance.train(df, target, args)
        return pickle.dumps(self.model_instance.__dict__, protocol=5)

    def predict(self, df, model_state, args):
        """Restore model state and predict on df.

        Falls back to the single-argument predict(df) for user models whose
        predict does not accept an `args` parameter.
        """
        model_state = pickle.loads(model_state)
        self.model_instance.__dict__ = model_state
        try:
            result = self.model_instance.predict(df, args)
        except Exception:
            result = self.model_instance.predict(df)
        return result

    def finetune(self, df, model_state, args):
        """Restore model state, finetune on df, and return the new pickled state."""
        self.model_instance.__dict__ = pickle.loads(model_state)

        # fix: an unused `call_args` list used to be built here and never
        # passed anywhere; the actual call has always been finetune(df, args),
        # so the dead code was removed.
        self.model_instance.finetune(df, args)

        return pickle.dumps(self.model_instance.__dict__, protocol=5)

    def describe(self, model_state, attribute: Optional[str] = None) -> pd.DataFrame:
        """Return the user model's description, or an empty frame if unsupported."""
        if hasattr(self.model_instance, "describe"):
            model_state = pickle.loads(model_state)
            self.model_instance.__dict__ = model_state
            return self.model_instance.describe(attribute)
        return pd.DataFrame()

    def func_call(self, func_name, args):
        """Call a module-level function from the user code."""
        func = getattr(self.module, func_name)
        return func(*args)

    def check(self, mode: str = None):
        """Validate the user module and return info about its available methods."""
        methods = check_module(self.module, mode)
        return methods
class ModelWrapperSafe:
    """Model wrapper that executes learn/predict in venv"""

    def __init__(self, code, modules_str, engine_id, engine_version: int):
        """Parse requirements and prepare the virtual environment.

        Args:
            code: source code of the user module; sent to the child process
            modules_str (bytes): content of the requirements file
            engine_id: engine id, used to name the env folder
            engine_version (int): engine version, suffixed to the folder name for v>1
        """
        self.code = code
        modules = self.parse_requirements(modules_str)

        self.config = Config()
        self.is_cloud = Config().get("cloud", False)

        self.env_path = None
        self.env_storage_path = None
        self.prepare_env(modules, engine_id, engine_version)

    def prepare_env(self, modules, engine_id, engine_version: int):
        """Create (or reuse) the venv for this engine version and install modules.

        On cloud, the venv lives on local tmp disk and a tarred copy is kept
        in storage; locally the venv is created directly under the storage path.
        Falls back to the current interpreter if venv creation fails (non-cloud).
        """
        try:
            import virtualenv

            base_path = self.config.get("byom", {}).get("venv_path")
            if base_path is None:
                # create in root path
                base_path = Path(self.config.paths["root"]) / "venvs"
            else:
                base_path = Path(base_path)
            base_path.mkdir(parents=True, exist_ok=True)

            env_folder_name = f"env_{engine_id}"
            if isinstance(engine_version, int) and engine_version > 1:
                env_folder_name = f"{env_folder_name}_{engine_version}"

            self.env_storage_path = base_path / env_folder_name
            if self.is_cloud:
                # unpack a previously tarred venv from storage onto local disk
                # NOTE(review): 'bese' looks like a typo for 'base' (name only)
                bese_env_path = Path(tempfile.gettempdir()) / "mindsdb" / "venv"
                bese_env_path.mkdir(parents=True, exist_ok=True)
                self.env_path = bese_env_path / env_folder_name
                tar_path = self.env_storage_path.with_suffix(".tar")
                if self.env_path.exists() is False and tar_path.exists() is True:
                    with tarfile.open(tar_path) as tar:
                        safe_extract(tar, path=bese_env_path)
            else:
                self.env_path = self.env_storage_path

            if sys.platform in ("win32", "cygwin"):
                exectable_folder_name = "Scripts"
            else:
                exectable_folder_name = "bin"

            pip_cmd = self.env_path / exectable_folder_name / "pip"
            self.python_path = self.env_path / exectable_folder_name / "python"

            if self.env_path.exists():
                # already exists. it means requirements are already installed
                return

            # create
            logger.info(f"Creating new environment: {self.env_path}")
            virtualenv.cli_run(["-p", sys.executable, str(self.env_path)])
            logger.info(f"Created new environment: {self.env_path}")

            if len(modules) > 0:
                self.install_modules(modules, pip_cmd=pip_cmd)
        except Exception:
            # DANGER !!! VENV MUST BE CREATED
            logger.info("Can't create virtual environment. venv module should be installed")

            if self.is_cloud:
                raise

            # fall back to the current interpreter (no isolation)
            self.python_path = Path(sys.executable)

            # try to install modules everytime
            # NOTE(review): if the failure happened before `pip_cmd` was
            # assigned (e.g. `import virtualenv` raised), this line raises
            # NameError — confirm the intended fallback pip command.
            self.install_modules(modules, pip_cmd=pip_cmd)

        # fastest way to copy files if destination is NFS
        if self.is_cloud and self.env_storage_path != self.env_path:
            old_cwd = os.getcwd()
            os.chdir(str(bese_env_path))
            tar_path = self.env_path.with_suffix(".tar")
            with tarfile.open(name=str(tar_path), mode="w") as tar:
                tar.add(str(self.env_path.name))
            os.chdir(old_cwd)
            subprocess.run(
                ["cp", "-R", "--no-preserve=mode,ownership", str(tar_path), str(base_path / tar_path.name)],
                check=True,
                shell=False,
            )
            tar_path.unlink()

    def remove_venv(self):
        """Delete the created virtual environment (and its tar copy on cloud)."""
        if self.env_path is not None and self.env_path.exists():
            shutil.rmtree(str(self.env_path))

        if self.is_cloud:
            # NOTE(review): unlink() raises FileNotFoundError if the tar was
            # never written — confirm whether missing_ok=True is wanted here.
            tar_path = self.env_storage_path.with_suffix(".tar")
            tar_path.unlink()

    def parse_requirements(self, requirements):
        # get requirements from string
        # they should be located at the top of the file, before code

        pattern = "^[\w\\[\\]-]+[=!<>\s]*[\d\.]*[,=!<>\s]*[\d\.]*$"  # noqa
        modules = []
        for line in requirements.split(b"\n"):
            line = line.decode().strip()
            if line:
                if re.match(pattern, line):
                    modules.append(line)
                else:
                    raise Exception(f"Wrong requirement: {line}")

        # pin pandas/numpy if the user didn't, for dataframe exchange with the child
        is_pandas = any([m.lower().startswith("pandas") for m in modules])
        if not is_pandas:
            modules.append("pandas>=2.0.0,<2.1.0")
            modules.append("numpy<2.0.0")

        # for dataframe serialization
        modules.append("pyarrow==19.0.0")
        return modules

    def install_modules(self, modules, pip_cmd):
        # install in current environment using pip
        for module in modules:
            logger.debug(f"BYOM install module: {module}")
            p = subprocess.Popen([pip_cmd, "install", module], stderr=subprocess.PIPE)
            p.wait()
            if p.returncode != 0:
                raise Exception(f"Problem with installing module {module}: {p.stderr.read()}")

    def _run_command(self, params):
        """Run proc_wrapper.py with the venv's python, exchanging data via pipes.

        Params are encoded to the child's stdin; the decoded stdout is returned.

        Raises:
            RuntimeError: with the child's stderr if stdout can't be decoded.
        """
        logger.debug(f"BYOM run command: {params.get('method')}")
        params_enc = encode(params)

        wrapper_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "proc_wrapper.py")
        p = subprocess.Popen(
            [str(self.python_path), wrapper_path],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        p.stdin.write(params_enc)
        p.stdin.close()
        ret_enc = p.stdout.read()

        p.wait()

        try:
            ret = decode(ret_enc)
        except (pickle.UnpicklingError, EOFError):
            # child crashed or produced garbage: surface its stderr
            raise RuntimeError(p.stderr.read())
        return ret

    def check(self, mode: str = None):
        """Validate the user code in the venv and return its methods info."""
        params = {
            "method": BYOM_METHOD.CHECK.value,
            "code": self.code,
            "mode": mode,
        }
        return self._run_command(params)

    def train(self, df, target, args):
        """Train in the venv and return the serialized model state."""
        params = {
            "method": BYOM_METHOD.TRAIN.value,
            "code": self.code,
            "df": None,
            "to_predict": target,
            "args": args,
        }
        if df is not None:
            params["df"] = pd_encode(df)

        model_state = self._run_command(params)
        return model_state

    def predict(self, df, model_state, args):
        """Predict in the venv; dataframes are encoded/decoded across the pipe."""
        params = {
            "method": BYOM_METHOD.PREDICT.value,
            "code": self.code,
            "model_state": model_state,
            "df": pd_encode(df),
            "args": args,
        }
        pred_df = self._run_command(params)
        return pd_decode(pred_df)

    def finetune(self, df, model_state, args):
        """Finetune in the venv and return the new serialized model state."""
        params = {
            "method": BYOM_METHOD.FINETUNE.value,
            "code": self.code,
            "model_state": model_state,
            "df": pd_encode(df),
            "args": args,
        }

        model_state = self._run_command(params)
        return model_state

    def describe(self, model_state, attribute: Optional[str] = None) -> pd.DataFrame:
        """Run the user model's describe() in the venv and return a DataFrame."""
        params = {
            "method": BYOM_METHOD.DESCRIBE.value,
            "code": self.code,
            "model_state": model_state,
            "attribute": attribute,
        }
        enc_df = self._run_command(params)
        df = pd_decode(enc_df)
        return df

    def func_call(self, func_name, args):
        """Call a module-level function of the user code in the venv."""
        params = {
            "method": BYOM_METHOD.FUNC_CALL.value,
            "code": self.code,
            "func_name": func_name,
            "args": args,
        }
        result = self._run_command(params)
        return result