Coverage for mindsdb / integrations / handlers / byom_handler / byom_handler.py: 13%

393 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1"""BYOM: Bring Your Own Model 

2 

env vars to control BYOM:

4 - MINDSDB_BYOM_ENABLED - can BYOM be used or not. Locally enabled by default. 

5 - MINDSDB_BYOM_INHOUSE_ENABLED - enable or disable 'inhouse' BYOM usage. Locally enabled by default. 

6 - MINDSDB_BYOM_DEFAULT_TYPE - [inhouse|venv] default byom type. Locally it is 'venv' by default. 

7""" 

8 

9import os 

10import re 

11import sys 

12import shutil 

13import pickle 

14import tarfile 

15import tempfile 

16import subprocess 

17from enum import Enum 

18from pathlib import Path 

19from datetime import datetime 

20from typing import Optional, Dict, Union 

21 

22import pandas as pd 

23from pandas.api import types as pd_types 

24 

25from mindsdb.utilities import log 

26from mindsdb.utilities.config import Config 

27from mindsdb.utilities.fs import safe_extract 

28from mindsdb.interfaces.storage import db 

29from mindsdb.integrations.libs.base import BaseMLEngine 

30from mindsdb.integrations.libs.const import PREDICTOR_STATUS 

31from mindsdb.integrations.utilities.utils import format_exception_error 

32import mindsdb.utilities.profiler as profiler 

33 

34 

35from .proc_wrapper import ( 

36 pd_decode, 

37 pd_encode, 

38 encode, 

39 decode, 

40 BYOM_METHOD, 

41 import_string, 

42 find_model_class, 

43 check_module, 

44) 

45from .__about__ import __version__ 

46 

47 

class BYOM_TYPE(Enum):
    """How a BYOM engine executes user code: in-process or in a virtualenv."""

    INHOUSE = 1
    VENV = 2

49 

# Module-level logger via mindsdb's logging wrapper (imported above).
logger = log.getLogger(__name__)

51 

52 

class BYOMHandler(BaseMLEngine):
    """ML engine that runs user-supplied ("bring your own") model code.

    Engine code and requirements are stored per-version in engine storage.
    Each version executes either in-process ('inhouse') or inside a dedicated
    virtualenv ('venv'); see the module docstring for the env vars that
    control which modes are available.
    """

    name = "byom"

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        # region check availability
        is_cloud = Config().get("cloud", False)
        if is_cloud is True:
            byom_enabled = os.environ.get("MINDSDB_BYOM_ENABLED", "false").lower()
            if byom_enabled not in ("true", "1"):
                raise RuntimeError("BYOM is disabled on cloud")
        # endregion

        self.model_wrapper = None

        self.inhouse_model_wrapper = None
        # cache: engine version (as str) -> wrapper instance
        self.model_wrappers = {}

        # region read and save default byom type
        try:
            self._default_byom_type = BYOM_TYPE.VENV
            if os.environ.get("MINDSDB_BYOM_DEFAULT_TYPE") is not None:
                self._default_byom_type = BYOM_TYPE[os.environ.get("MINDSDB_BYOM_DEFAULT_TYPE").upper()]
        except KeyError:
            logger.warning(f"Wrong value of env var MINDSDB_BYOM_DEFAULT_TYPE, {BYOM_TYPE.VENV} will be used")
            self._default_byom_type = BYOM_TYPE.VENV
        # endregion

        # region check if 'inhouse' BYOM is enabled
        env_var = os.environ.get("MINDSDB_BYOM_INHOUSE_ENABLED")
        if env_var is None:
            # enabled locally by default, disabled on cloud
            self._inhouse_enabled = False if is_cloud else True
        else:
            self._inhouse_enabled = env_var.lower() in ("true", "1")
        # endregion

        super().__init__(model_storage, engine_storage, **kwargs)

    @staticmethod
    def normalize_engine_version(engine_version: Union[int, str, None]) -> int:
        """Cast engine version to int, or return `1` if it can not be cast

        Args:
            engine_version (Union[int, str, None]): engine version

        Returns:
            int: engine version
        """
        if isinstance(engine_version, str):
            try:
                engine_version = int(engine_version)
            except Exception:
                engine_version = 1
        if isinstance(engine_version, int) is False:
            engine_version = 1
        return engine_version

    @staticmethod
    def create_validation(target: str, args: dict = None, **kwargs) -> None:
        """Resolve and pin 'engine_version' in args['using'] before model creation.

        Defaults to the latest known engine version, or 1 when no version
        metadata exists.
        """
        if isinstance(args, dict) is False:
            return
        # Bug fix: use setdefault so the pinned engine_version is persisted in
        # `args` even when the 'using' key was absent (args.get('using', {})
        # wrote into a throwaway dict).
        using_args = args.setdefault("using", {})
        engine_version = using_args.get("engine_version")
        if engine_version is not None:
            engine_version = BYOMHandler.normalize_engine_version(engine_version)
        else:
            connection_args = kwargs["handler_storage"].get_connection_args()
            versions = connection_args.get("versions")
            if isinstance(versions, dict):
                engine_version = max([int(x) for x in versions.keys()])
            else:
                engine_version = 1
        using_args["engine_version"] = engine_version

    def get_model_engine_version(self) -> int:
        """Return the engine version the current model was created with

        Returns:
            int: engine version
        """
        engine_version = self.model_storage.get_info()["learn_args"].get("using", {}).get("engine_version")
        engine_version = BYOMHandler.normalize_engine_version(engine_version)
        return engine_version

    def normalize_byom_type(self, byom_type: Optional[str]) -> BYOM_TYPE:
        """Map a user-supplied type string to BYOM_TYPE, enforcing 'inhouse' availability."""
        if byom_type is not None:
            byom_type = BYOM_TYPE[byom_type.upper()]
        else:
            byom_type = self._default_byom_type
        if byom_type == BYOM_TYPE.INHOUSE and self._inhouse_enabled is False:
            raise Exception("'Inhouse' BYOM engine type can not be used")
        return byom_type

    def _get_model_proxy(self, version=None):
        """Return (and cache) the wrapper that executes code of engine `version`."""
        if version is None:
            version = 1
        if isinstance(version, str):
            version = int(version)
        version_mark = ""
        if version > 1:
            # version 1 files are stored without a suffix
            version_mark = f"_{version}"
        version_str = str(version)

        self.engine_storage.fileStorage.pull()
        try:
            code = self.engine_storage.fileStorage.file_get(f"code{version_mark}")
            modules_str = self.engine_storage.fileStorage.file_get(f"modules{version_mark}")
        except FileNotFoundError:
            raise Exception(f"Engine version '{version}' does not exists")

        if version_str not in self.model_wrappers:
            connection_args = self.engine_storage.get_connection_args()
            version_meta = connection_args["versions"][version_str]

            try:
                engine_version_type = BYOM_TYPE[version_meta.get("type", self._default_byom_type.name).upper()]
            except KeyError:
                raise Exception("Unknown BYOM engine type")

            if engine_version_type == BYOM_TYPE.INHOUSE:
                if self._inhouse_enabled is False:
                    raise Exception("'Inhouse' BYOM engine type can not be used")
                # NOTE(review): a single inhouse wrapper is shared across all
                # versions, so the first version loaded wins — confirm intended.
                if self.inhouse_model_wrapper is None:
                    self.inhouse_model_wrapper = ModelWrapperUnsafe(
                        code=code,
                        modules_str=modules_str,
                        engine_id=self.engine_storage.integration_id,
                        engine_version=version,
                    )
                self.model_wrappers[version_str] = self.inhouse_model_wrapper
            elif engine_version_type == BYOM_TYPE.VENV:
                # mark 'creating' so concurrent readers see the venv is in progress
                if version_meta.get("venv_status") != "ready":
                    version_meta["venv_status"] = "creating"
                    self.engine_storage.update_connection_args(connection_args)
                self.model_wrappers[version_str] = ModelWrapperSafe(
                    code=code,
                    modules_str=modules_str,
                    engine_id=self.engine_storage.integration_id,
                    engine_version=version,
                )
                version_meta["venv_status"] = "ready"
                self.engine_storage.update_connection_args(connection_args)

        return self.model_wrappers[version_str]

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Proxy `describe` to the user model, restoring its saved state first."""
        engine_version = self.get_model_engine_version()
        mp = self._get_model_proxy(engine_version)
        model_state = self.model_storage.file_get("model")
        return mp.describe(model_state, attribute)

    def create(self, target, df=None, args=None, **kwargs):
        """Train the user model, persist its state and the target column type."""
        # Bug fix: `args` defaults to None but was read unconditionally below.
        args = args or {}
        using_args = args.get("using", {})
        engine_version = using_args.get("engine_version")

        model_proxy = self._get_model_proxy(engine_version)
        model_state = model_proxy.train(df, target, args)

        self.model_storage.file_set("model", model_state)

        # TODO return columns?

        def convert_type(field_type):
            # map a pandas dtype to mindsdb's column type name
            if pd_types.is_integer_dtype(field_type):
                return "integer"
            elif pd_types.is_numeric_dtype(field_type):
                return "float"
            elif pd_types.is_datetime64_any_dtype(field_type):
                return "datetime"
            else:
                return "categorical"

        # NOTE(review): `object` always maps to "categorical"; the real target
        # dtype is never inspected — confirm whether that is intended.
        columns = {target: convert_type(object)}

        self.model_storage.columns_set(columns)

    def predict(self, df, args=None):
        """Predict with the engine version the model was trained with.

        The version can be overridden via `predict_params.engine_version`.
        """
        # Bug fix: guard against args=None before reading predict_params.
        pred_args = (args or {}).get("predict_params", {})

        engine_version = pred_args.get("engine_version")
        if engine_version is not None:
            engine_version = int(engine_version)
        else:
            engine_version = self.get_model_engine_version()

        model_proxy = self._get_model_proxy(engine_version)
        model_state = self.model_storage.file_get("model")
        pred_df = model_proxy.predict(df, model_state, pred_args)

        return pred_df

    def create_engine(self, connection_args):
        """Create engine version 1 from the 'code' and 'modules' file paths."""
        code_path = Path(connection_args["code"])
        self.engine_storage.fileStorage.file_set("code", code_path.read_bytes())

        requirements_path = Path(connection_args["modules"])
        self.engine_storage.fileStorage.file_set("modules", requirements_path.read_bytes())

        self.engine_storage.fileStorage.push()

        self.engine_storage.update_connection_args(
            {
                "handler_version": __version__,
                "mode": connection_args.get("mode"),
                "versions": {
                    "1": {
                        "code": code_path.name,
                        "requirements": requirements_path.name,
                        "type": self.normalize_byom_type(connection_args.get("type")).name.lower(),
                    }
                },
            }
        )

        model_proxy = self._get_model_proxy()
        try:
            info = model_proxy.check(connection_args.get("mode"))
            self.engine_storage.json_set("methods", info["methods"])

        except Exception as e:
            # the venv is unusable if the check failed; clean up before re-raising
            if hasattr(model_proxy, "remove_venv"):
                model_proxy.remove_venv()
            raise e

    def update_engine(self, connection_args: dict) -> None:
        """Add new version of engine

        Args:
            connection_args (dict): paths to code and requirements
        """
        code_path = Path(connection_args["code"])
        requirements_path = Path(connection_args["modules"])

        engine_connection_args = self.engine_storage.get_connection_args()
        # engines created before versioning existed have no metadata: assume
        # their files are version 1 with default names
        if isinstance(engine_connection_args, dict) is False or "handler_version" not in engine_connection_args:
            engine_connection_args = {
                "handler_version": __version__,
                "versions": {
                    "1": {
                        "code": "code.py",
                        "requirements": "requirements.txt",
                        "type": self._default_byom_type.name.lower(),
                    }
                },
            }
        new_version = str(max([int(x) for x in engine_connection_args["versions"].keys()]) + 1)

        engine_connection_args["versions"][new_version] = {
            "code": code_path.name,
            "requirements": requirements_path.name,
            "type": self.normalize_byom_type(connection_args.get("type")).name.lower(),
        }

        self.engine_storage.fileStorage.file_set(f"code_{new_version}", code_path.read_bytes())

        self.engine_storage.fileStorage.file_set(f"modules_{new_version}", requirements_path.read_bytes())
        self.engine_storage.fileStorage.push()

        self.engine_storage.update_connection_args(engine_connection_args)

        model_proxy = self._get_model_proxy(new_version)
        try:
            # NOTE(review): create_engine stores check()['methods'] while this
            # stores the whole check() result — confirm which shape is correct.
            methods = model_proxy.check()
            self.engine_storage.json_set("methods", methods)

        except Exception as e:
            if hasattr(model_proxy, "remove_venv"):
                model_proxy.remove_venv()
            raise e

    def function_list(self):
        """Return the engine's exposed functions recorded at check() time."""
        return self.engine_storage.json_get("methods")

    def function_call(self, name, args):
        """Call a module-level function of the engine's user code."""
        mp = self._get_model_proxy()
        return mp.func_call(name, args)

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Finetune from a completed base model, tracking status on the predictor record.

        Raises:
            Exception: if the base model is not in status 'complete'
        """
        # Bug fix: `args` defaults to None but is read unconditionally below.
        args = args or {}
        using_args = args.get("using", {})
        engine_version = using_args.get("engine_version")

        model_storage = self.model_storage
        # TODO: should probably refactor at some point, as a bit of the logic is shared with lightwood's finetune logic
        try:
            base_predictor_id = args["base_model_id"]
            base_predictor_record = db.Predictor.query.get(base_predictor_id)
            if base_predictor_record.status != PREDICTOR_STATUS.COMPLETE:
                raise Exception("Base model must be in status 'complete'")

            predictor_id = model_storage.predictor_id
            predictor_record = db.Predictor.query.get(predictor_id)

            predictor_record.data = {
                "training_log": "training"
            }  # TODO move to ModelStorage (don't work w/ db directly)
            predictor_record.training_start_at = datetime.now()
            predictor_record.status = PREDICTOR_STATUS.FINETUNING  # TODO: parallel execution block
            db.session.commit()

            model_proxy = self._get_model_proxy(engine_version)
            model_state = self.base_model_storage.file_get("model")
            model_state = model_proxy.finetune(df, model_state, args=args.get("using", {}))

            # region hack to speedup file saving
            with profiler.Context("finetune-byom-write-file"):
                dest_abs_path = model_storage.fileStorage.folder_path / "model"
                with open(dest_abs_path, "wb") as fd:
                    fd.write(model_state)
                model_storage.fileStorage.push(compression_level=0)
            # endregion

            predictor_record.update_status = "up_to_date"
            predictor_record.status = PREDICTOR_STATUS.COMPLETE
            predictor_record.training_stop_at = datetime.now()
            db.session.commit()

        except Exception as e:
            logger.error("Unexpected error during BYOM finetune:", exc_info=True)
            # re-fetch with a row lock to record the failure
            predictor_id = model_storage.predictor_id
            predictor_record = db.Predictor.query.with_for_update().get(predictor_id)
            error_message = format_exception_error(e)
            predictor_record.data = {"error": error_message}
            predictor_record.status = PREDICTOR_STATUS.ERROR
            db.session.commit()
            raise

        finally:
            # make sure a stop timestamp is always recorded
            if predictor_record.training_stop_at is None:
                predictor_record.training_stop_at = datetime.now()
                db.session.commit()

382 

383 

class ModelWrapperUnsafe:
    """Model wrapper that executes learn/predict in the current process.

    The user's module is imported directly ('inhouse' mode), so it must be
    trusted. Model state is round-tripped through the model instance's
    ``__dict__`` using pickle.
    """

    def __init__(self, code, modules_str, engine_id, engine_version: int):
        # modules_str/engine_id/engine_version are unused here but kept for
        # signature parity with ModelWrapperSafe.
        self.module = import_string(code)

        model_instance = None
        model_class = find_model_class(self.module)
        if model_class is not None:
            model_instance = model_class()

        self.model_instance = model_instance

    def train(self, df, target, args):
        """Train the user model and return its pickled state."""
        self.model_instance.train(df, target, args)
        return pickle.dumps(self.model_instance.__dict__, protocol=5)

    def predict(self, df, model_state, args):
        """Restore model state and predict.

        Falls back to single-argument ``predict(df)`` for user models that do
        not accept args.
        """
        model_state = pickle.loads(model_state)
        self.model_instance.__dict__ = model_state
        try:
            result = self.model_instance.predict(df, args)
        except Exception:
            # user model may implement predict(df) without args
            result = self.model_instance.predict(df)
        return result

    def finetune(self, df, model_state, args):
        """Restore model state, finetune, and return the new pickled state."""
        self.model_instance.__dict__ = pickle.loads(model_state)

        # Bug fix: `call_args` was built but never used; the call always passed
        # `args` (even when empty/None), breaking user models whose finetune()
        # only accepts a dataframe.
        call_args = [df]
        if args:
            call_args.append(args)

        self.model_instance.finetune(*call_args)

        return pickle.dumps(self.model_instance.__dict__, protocol=5)

    def describe(self, model_state, attribute: Optional[str] = None) -> pd.DataFrame:
        """Proxy describe() to the user model; empty frame if not implemented."""
        if hasattr(self.model_instance, "describe"):
            model_state = pickle.loads(model_state)
            self.model_instance.__dict__ = model_state
            return self.model_instance.describe(attribute)
        return pd.DataFrame()

    def func_call(self, func_name, args):
        """Call a module-level function of the user code."""
        func = getattr(self.module, func_name)
        return func(*args)

    def check(self, mode: str = None):
        """Validate the user module and return its available methods."""
        methods = check_module(self.module, mode)
        return methods

435 

436 

class ModelWrapperSafe:
    """Model wrapper that executes learn/predict in a virtualenv.

    The engine's code runs in a subprocess (`proc_wrapper.py`) under the
    venv's python; inputs/outputs are pickled over stdin/stdout.
    """

    def __init__(self, code, modules_str, engine_id, engine_version: int):
        self.code = code
        modules = self.parse_requirements(modules_str)

        self.config = Config()
        self.is_cloud = Config().get("cloud", False)

        self.env_path = None
        self.env_storage_path = None
        self.prepare_env(modules, engine_id, engine_version)

    def prepare_env(self, modules, engine_id, engine_version: int):
        """Create (or reuse) the venv for this engine version and install requirements.

        On cloud, venvs live in a local temp dir and are mirrored to storage
        as a tar archive (restored on other machines).
        """
        pip_cmd = None
        try:
            import virtualenv

            base_path = self.config.get("byom", {}).get("venv_path")
            if base_path is None:
                # create in root path
                base_path = Path(self.config.paths["root"]) / "venvs"
            else:
                base_path = Path(base_path)
            base_path.mkdir(parents=True, exist_ok=True)

            env_folder_name = f"env_{engine_id}"
            if isinstance(engine_version, int) and engine_version > 1:
                env_folder_name = f"{env_folder_name}_{engine_version}"

            self.env_storage_path = base_path / env_folder_name
            if self.is_cloud:
                base_env_path = Path(tempfile.gettempdir()) / "mindsdb" / "venv"
                base_env_path.mkdir(parents=True, exist_ok=True)
                self.env_path = base_env_path / env_folder_name
                tar_path = self.env_storage_path.with_suffix(".tar")
                # restore a previously archived venv if present locally missing
                if self.env_path.exists() is False and tar_path.exists() is True:
                    with tarfile.open(tar_path) as tar:
                        safe_extract(tar, path=base_env_path)
            else:
                self.env_path = self.env_storage_path

            if sys.platform in ("win32", "cygwin"):
                executable_folder_name = "Scripts"
            else:
                executable_folder_name = "bin"

            pip_cmd = self.env_path / executable_folder_name / "pip"
            self.python_path = self.env_path / executable_folder_name / "python"

            if self.env_path.exists():
                # already exists. it means requirements are already installed
                return

            # create
            logger.info(f"Creating new environment: {self.env_path}")
            virtualenv.cli_run(["-p", sys.executable, str(self.env_path)])
            logger.info(f"Created new environment: {self.env_path}")

            if len(modules) > 0:
                self.install_modules(modules, pip_cmd=pip_cmd)
        except Exception:
            # DANGER !!! VENV MUST BE CREATED
            logger.info("Can't create virtual environment. venv module should be installed")

            if self.is_cloud:
                raise

            # fall back to the current interpreter
            self.python_path = Path(sys.executable)
            # Bug fix: `pip_cmd` was unbound here when the failure happened
            # before it was assigned (e.g. `import virtualenv` failed);
            # fall back to the current interpreter's pip.
            if pip_cmd is None:
                pip_cmd = [sys.executable, "-m", "pip"]

            # try to install modules everytime
            self.install_modules(modules, pip_cmd=pip_cmd)

        # fastest way to copy files if destination is NFS
        if self.is_cloud and self.env_storage_path != self.env_path:
            old_cwd = os.getcwd()
            os.chdir(str(base_env_path))
            tar_path = self.env_path.with_suffix(".tar")
            with tarfile.open(name=str(tar_path), mode="w") as tar:
                tar.add(str(self.env_path.name))
            os.chdir(old_cwd)
            subprocess.run(
                ["cp", "-R", "--no-preserve=mode,ownership", str(tar_path), str(base_path / tar_path.name)],
                check=True,
                shell=False,
            )
            tar_path.unlink()

    def remove_venv(self):
        """Delete the venv directory (and its tar archive on cloud)."""
        if self.env_path is not None and self.env_path.exists():
            shutil.rmtree(str(self.env_path))

        # Bug fix: guard against a missing archive / unset storage path
        # instead of raising from cleanup.
        if self.is_cloud and self.env_storage_path is not None:
            tar_path = self.env_storage_path.with_suffix(".tar")
            tar_path.unlink(missing_ok=True)

    def parse_requirements(self, requirements):
        """Extract pip requirement lines from the requirements file bytes.

        Each non-empty line must look like a plain requirement (name plus
        optional version constraints) — anything else raises, which guards
        against pip-option/shell injection. Pinned pandas/numpy are appended
        when absent, plus pyarrow for dataframe serialization.
        """
        # get requirements from string
        # they should be located at the top of the file, before code

        pattern = "^[\w\\[\\]-]+[=!<>\s]*[\d\.]*[,=!<>\s]*[\d\.]*$"  # noqa
        modules = []
        for line in requirements.split(b"\n"):
            line = line.decode().strip()
            if line:
                if re.match(pattern, line):
                    modules.append(line)
                else:
                    raise Exception(f"Wrong requirement: {line}")

        is_pandas = any([m.lower().startswith("pandas") for m in modules])
        if not is_pandas:
            modules.append("pandas>=2.0.0,<2.1.0")
            modules.append("numpy<2.0.0")

        # for dataframe serialization
        modules.append("pyarrow==19.0.0")
        return modules

    def install_modules(self, modules, pip_cmd):
        """Install requirements with pip.

        `pip_cmd` may be a path to a pip executable or an argv prefix list
        such as ``[python, "-m", "pip"]``.
        """
        pip_argv = list(pip_cmd) if isinstance(pip_cmd, (list, tuple)) else [str(pip_cmd)]
        for module in modules:
            logger.debug(f"BYOM install module: {module}")
            p = subprocess.Popen(pip_argv + ["install", module], stderr=subprocess.PIPE)
            # communicate() drains stderr, avoiding a pipe-buffer deadlock that
            # wait() + a full PIPE could cause
            _, err = p.communicate()
            if p.returncode != 0:
                raise Exception(f"Problem with installing module {module}: {err}")

    def _run_command(self, params):
        """Run proc_wrapper.py in the venv's python with pickled `params` on
        stdin and return the decoded stdout result.

        Raises:
            RuntimeError: with the subprocess's stderr if the output can not
                be decoded (i.e. the child crashed).
        """
        logger.debug(f"BYOM run command: {params.get('method')}")
        params_enc = encode(params)

        wrapper_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "proc_wrapper.py")
        p = subprocess.Popen(
            [str(self.python_path), wrapper_path],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        # communicate() streams stdin/stdout/stderr concurrently, avoiding the
        # deadlock possible with sequential write-all-then-read-all on pipes
        ret_enc, err = p.communicate(params_enc)

        try:
            ret = decode(ret_enc)
        except (pickle.UnpicklingError, EOFError):
            raise RuntimeError(err)
        return ret

    def check(self, mode: str = None):
        """Validate the user code in the venv and return its methods info."""
        params = {
            "method": BYOM_METHOD.CHECK.value,
            "code": self.code,
            "mode": mode,
        }
        return self._run_command(params)

    def train(self, df, target, args):
        """Train in the venv; returns the pickled model state."""
        params = {
            "method": BYOM_METHOD.TRAIN.value,
            "code": self.code,
            "df": None,
            "to_predict": target,
            "args": args,
        }
        if df is not None:
            params["df"] = pd_encode(df)

        model_state = self._run_command(params)
        return model_state

    def predict(self, df, model_state, args):
        """Predict in the venv; returns a decoded DataFrame."""
        params = {
            "method": BYOM_METHOD.PREDICT.value,
            "code": self.code,
            "model_state": model_state,
            "df": pd_encode(df),
            "args": args,
        }
        pred_df = self._run_command(params)
        return pd_decode(pred_df)

    def finetune(self, df, model_state, args):
        """Finetune in the venv; returns the new pickled model state."""
        params = {
            "method": BYOM_METHOD.FINETUNE.value,
            "code": self.code,
            "model_state": model_state,
            "df": pd_encode(df),
            "args": args,
        }

        model_state = self._run_command(params)
        return model_state

    def describe(self, model_state, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe in the venv; returns a decoded DataFrame."""
        params = {
            "method": BYOM_METHOD.DESCRIBE.value,
            "code": self.code,
            "model_state": model_state,
            "attribute": attribute,
        }
        enc_df = self._run_command(params)
        df = pd_decode(enc_df)
        return df

    def func_call(self, func_name, args):
        """Call a module-level function of the user code in the venv."""
        params = {
            "method": BYOM_METHOD.FUNC_CALL.value,
            "code": self.code,
            "func_name": func_name,
            "args": args,
        }
        result = self._run_command(params)
        return result