Coverage for mindsdb / __main__.py: 0%

336 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

import gc

# GC is switched off for the duration of the (heavy) import phase below and
# re-enabled once module setup is complete — presumably to speed up startup.
gc.disable()
import os
import sys
import time
import atexit
import signal
import psutil
import asyncio
import threading
import shutil
from enum import Enum
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple, List

from sqlalchemy import func
from sqlalchemy.orm.attributes import flag_modified

from mindsdb.utilities import log

# Logger is created before the remaining mindsdb imports so startup progress
# can be reported as early as possible.
logger = log.getLogger("mindsdb")
logger.debug("Starting MindsDB...")

from mindsdb.__about__ import __version__ as mindsdb_version
from mindsdb.utilities.config import config
from mindsdb.utilities.starters import (
    start_http,
    start_mysql,
    start_ml_task_queue,
    start_scheduler,
    start_tasks,
    start_litellm,
)
from mindsdb.utilities.ps import is_pid_listen_port, get_child_pids
import mindsdb.interfaces.storage.db as db
from mindsdb.utilities.fs import clean_process_marks, clean_unlinked_process_marks, create_pid_file, delete_pid_file
from mindsdb.utilities.context import context as ctx
from mindsdb.utilities.auth import register_oauth_client, get_aws_meta_data
from mindsdb.utilities.sentry import sentry_sdk  # noqa: F401
from mindsdb.utilities.api_status import set_api_status

# Prefer torch's multiprocessing wrapper when torch is installed; otherwise
# fall back to the stdlib module under the same alias.
try:
    import torch.multiprocessing as mp
except Exception:
    import multiprocessing as mp
try:
    mp.set_start_method("spawn")
except RuntimeError:
    logger.info("Torch multiprocessing context already set, ignoring...")

gc.enable()

# Set to tell background threads (e.g. the process-mark cleaner) to stop.
_stop_event = threading.Event()

class TrunkProcessEnum(Enum):
    """Names of the top-level ("trunk") subprocesses MindsDB can launch."""

    HTTP = "http"
    MYSQL = "mysql"
    JOBS = "jobs"
    TASKS = "tasks"
    ML_TASK_QUEUE = "ml_task_queue"
    LITELLM = "litellm"

    @classmethod
    def _missing_(cls, value):
        # An unknown subprocess name is a fatal configuration error: report
        # it and terminate instead of returning a placeholder member.
        message = f'"{value}" is not a valid name of subprocess'
        logger.error(message)
        sys.exit(1)

70 

@dataclass
class TrunkProcessData:
    """Bookkeeping for one managed subprocess: how to launch it and the
    restart policy applied when it dies unexpectedly."""

    name: str
    entrypoint: Callable
    need_to_run: bool = False
    port: Optional[int] = None
    process: Optional[mp.Process] = None
    started: bool = False
    args: Optional[Tuple] = None
    restart_on_failure: bool = False
    max_restart_count: int = 3
    max_restart_interval_seconds: int = 60

    _restart_count: int = 0
    _restarts_time: List[int] = field(default_factory=list)

    def request_restart_attempt(self) -> bool:
        """Record a restart attempt and report whether it is allowed.

        A ``max_restart_count`` of 0 means restarts are unlimited; a
        ``max_restart_interval_seconds`` of 0 means recorded attempts are
        never pruned by age.

        Returns:
            bool: `True` if the number of restarts in the interval does not exceed
        """
        if self.max_restart_count == 0:
            return True
        now = int(time.time())
        self._restarts_time.append(now)
        if self.max_restart_interval_seconds > 0:
            # Keep only timestamps inside the sliding window.
            cutoff = now - self.max_restart_interval_seconds
            self._restarts_time = [t for t in self._restarts_time if t >= cutoff]
        return len(self._restarts_time) <= self.max_restart_count

    @property
    def should_restart(self) -> bool:
        """Whether this dead process should be brought back up.

        On Linux/macOS only a SIGKILL exit code (e.g. the OOM killer)
        triggers a restart; on other platforms any failure does, provided a
        finite restart budget is configured. Cloud mode never restarts.

        Returns:
            bool: `True` if the process need to be restarted on failure
        """
        if config.is_cloud:
            return False
        if sys.platform in ("linux", "darwin"):
            return self.restart_on_failure and self.process.exitcode == -signal.SIGKILL.value
        if self.max_restart_count == 0:
            # Without the exit-code check, an unlimited budget would allow
            # endless restart loops — refuse it on these platforms.
            logger.warning("In the current OS, it is not possible to use `max_restart_count=0`")
            return False
        return self.restart_on_failure

125 

126 

def close_api_gracefully(trunc_processes_struct):
    """Terminate every tracked subprocess (and its children), then the pid file.

    Args:
        trunc_processes_struct (dict): mapping of TrunkProcessEnum to
            TrunkProcessData; only entries with a non-None ``process`` are
            touched.
    """
    # Tell background threads (process-mark cleaner, resources logger) to stop.
    _stop_event.set()

    delete_pid_file()

    try:
        for trunc_processes_data in trunc_processes_struct.values():
            process = trunc_processes_data.process
            if process is None:
                # This API was never started.
                continue
            try:
                childs = get_child_pids(process.pid)
                for p in childs:
                    try:
                        os.kill(p, signal.SIGTERM)
                    except Exception:
                        # NOTE(review): this fallback calls `p.kill()` while the
                        # line above treats `p` as a plain pid int — confirm
                        # what get_child_pids actually returns (ints vs
                        # psutil.Process objects); one of the two paths cannot
                        # work for a given return type.
                        p.kill()
                sys.stdout.flush()
                # Terminate the parent after its children, then wait for it.
                process.terminate()
                process.join()
                sys.stdout.flush()
            except psutil.NoSuchProcess:
                # Already gone — nothing to clean up.
                pass
    except KeyboardInterrupt:
        sys.exit(0)

152 

153 

def clean_mindsdb_tmp_dir():
    """Remove every entry (files and subdirectories) of the MindsDB tmp dir.

    Registered with atexit so temporary artifacts do not outlive the process.
    """
    tmp_path = config["paths"]["tmp"]
    for entry in tmp_path.iterdir():
        if entry.is_dir():
            shutil.rmtree(entry)
            continue
        entry.unlink()

162 

163 

def set_error_model_status_by_pids(unexisting_pids: List[int]):
    """Mark models as failed when their training process no longer exists.

    Models keep the pid of their training process in the 'training_metadata'
    field; if that pid is gone the model can never finish, so its status is
    set to "error".
    Note: only for local usage.

    Args:
        unexisting_pids (List[int]): list of 'pids' that do not exist.
    """
    # Only models that are still "in flight" are candidates.
    unfinished_records = db.session.query(db.Predictor).filter(
        db.Predictor.deleted_at.is_(None),
        db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]),
    ).all()
    for record in unfinished_records:
        training_meta = record.training_metadata or {}
        if training_meta.get("process_id") not in unexisting_pids:
            continue
        record.status = db.PREDICTOR_STATUS.ERROR
        if not isinstance(record.data, dict):
            record.data = {}
        # Keep any pre-existing error message.
        record.data.setdefault("error", "The training process was terminated for unknown reasons")
        # JSON column mutated in place — tell SQLAlchemy explicitly.
        flag_modified(record, "data")
    db.session.commit()

190 

191 

def set_error_model_status_for_unfinished():
    """Set error status on every model whose status is not 'complete'/'error'.
    Note: only for local usage.
    """
    unfinished_records = db.session.query(db.Predictor).filter(
        db.Predictor.deleted_at.is_(None),
        db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]),
    ).all()
    for record in unfinished_records:
        record.status = db.PREDICTOR_STATUS.ERROR
        if not isinstance(record.data, dict):
            record.data = {}
        # Keep any pre-existing error message.
        record.data.setdefault("error", "Unknown error")
        # JSON column mutated in place — tell SQLAlchemy explicitly.
        flag_modified(record, "data")
    db.session.commit()

212 

213 

def do_clean_process_marks():
    """Periodically delete 'process marks' whose owning process is gone.

    Runs every 5 seconds until ``_stop_event`` is set; in local (non-cloud)
    mode it also flags models whose training process disappeared.
    """
    while not _stop_event.wait(timeout=5):
        dead_pids = clean_unlinked_process_marks()
        if not config.is_cloud and len(dead_pids) > 0:
            set_error_model_status_by_pids(dead_pids)

220 

221 

def create_permanent_integrations():
    """
    Create permanent integrations, for now only the 'files' integration.
    NOTE: this is intentional to avoid importing integration_controller
    """
    integration_name = "files"
    already_exists = (
        db.session.query(db.Integration).filter_by(name=integration_name, company_id=None).first() is not None
    )
    if already_exists:
        return
    db.session.add(
        db.Integration(
            name=integration_name,
            data={},
            engine=integration_name,
            company_id=None,
        )
    )
    try:
        db.session.commit()
    except Exception:
        # Best-effort: log and roll back rather than abort startup.
        logger.exception(f"Failed to create permanent integration '{integration_name}' in the internal database.")
        db.session.rollback()

243 

244 

def validate_default_project() -> None:
    """Handle 'default_project' config option.
    The project named by 'default_project' must exist and be marked with
    'is_default' metadata; if that cannot be ensured, terminate with an error.
    Note: this can be done using 'project_controller', but we want to save init time and used RAM.
    """
    new_default_project_name = config.get("default_project")
    logger.debug(f"Checking if default project {new_default_project_name} exists")
    # NOTE(review): falls back to the string "0" when no company is set —
    # presumably this matches how company_id is stored; confirm column type.
    filter_company_id = ctx.company_id if ctx.company_id is not None else "0"

    current_default_project: db.Project | None = db.Project.query.filter(
        db.Project.company_id == filter_company_id,
        db.Project.metadata_["is_default"].as_boolean() == True,  # noqa
    ).first()

    if current_default_project is None:
        # Legacy: If the default project does not exist, mark the new one as default.
        existing_project = db.Project.query.filter(
            db.Project.company_id == filter_company_id,
            func.lower(db.Project.name) == func.lower(new_default_project_name),
        ).first()
        if existing_project is None:
            # Nothing to promote — refuse to start.
            logger.critical(f"A project with the name '{new_default_project_name}' does not exist")
            sys.exit(1)

        existing_project.metadata_ = {"is_default": True}
        flag_modified(existing_project, "metadata_")
        db.session.commit()
    elif current_default_project.name != new_default_project_name:
        # If the default project exists, but the name is different, update the name.
        existing_project = db.Project.query.filter(
            db.Project.company_id == filter_company_id,
            func.lower(db.Project.name) == func.lower(new_default_project_name),
        ).first()
        if existing_project is not None:
            # Renaming would collide with another project — refuse to start.
            logger.critical(f"A project with the name '{new_default_project_name}' already exists")
            sys.exit(1)
        current_default_project.name = new_default_project_name
        db.session.commit()

284 

285 

def start_process(trunc_process_data: TrunkProcessData) -> None:
    """Start a process.

    Args:
        trunc_process_data (TrunkProcessData): The data of the process to start.

    Raises:
        Exception: re-raised after shutting down other APIs if the
            subprocess cannot be created/started.
    """
    # Always use the 'spawn' context, matching the start method configured
    # at import time.
    mp_ctx = mp.get_context("spawn")
    logger.info(f"{trunc_process_data.name} API: starting...")
    try:
        trunc_process_data.process = mp_ctx.Process(
            target=trunc_process_data.entrypoint,
            args=trunc_process_data.args,
            name=trunc_process_data.name,
        )
        trunc_process_data.process.start()
    except Exception as e:
        logger.exception(f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:")
        # NOTE(review): relies on the module-level `trunc_processes_struct`
        # created in the __main__ section; calling this function before that
        # global exists would raise NameError here.
        close_api_gracefully(trunc_processes_struct)
        raise e

305 

306 

if __name__ == "__main__":
    mp.freeze_support()
    # warn if less than 1Gb of free RAM
    if psutil.virtual_memory().available < (1 << 30):
        logger.warning(
            "The system is running low on memory. " + "This may impact the stability and performance of the program."
        )

    ctx.set_default()

    # ---- CHECK SYSTEM ----
    if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 10):
        print(
            """
     MindsDB requires Python >= 3.10 to run

     Once you have supported Python version installed you can start mindsdb as follows:

     1. create and activate venv:
     python3 -m venv venv
     source venv/bin/activate

     2. install MindsDB:
     pip3 install mindsdb

     3. Run MindsDB
     python3 -m mindsdb

     More instructions in https://docs.mindsdb.com
        """
        )
        exit(1)

    if config.cmd_args.version:
        print(f"MindsDB {mindsdb_version}")
        sys.exit(0)

    # Maintenance-only modes: refresh the GUI bundle and/or warm the
    # tokenizer, then exit without starting any APIs.
    if config.cmd_args.update_gui or config.cmd_args.load_tokenizer:
        if config.cmd_args.update_gui:
            from mindsdb.api.http.initialize import initialize_static

            logger.info("Updating the GUI version")
            initialize_static()

        if config.cmd_args.load_tokenizer:
            try:
                from langchain_core.language_models import get_tokenizer

                get_tokenizer()
                logger.info("Tokenizer successfully loaded")
            except ImportError:
                # langchain_core is optional — absence is not fatal here.
                logger.info("Failed to load tokenizer due to an import error")
            except Exception:
                logger.info("Failed to load tokenizer: ", exc_info=True)

        sys.exit(0)

    config.raise_warnings(logger=logger)
    os.environ["MINDSDB_RUNTIME"] = "1"

    # Pick a pyarrow memory pool unless the user already chose one.
    if os.environ.get("ARROW_DEFAULT_MEMORY_POOL") is None:
        try:
            """It seems like snowflake handler have memory issue that related to pyarrow. Memory usage keep growing with
            requests. This is related to 'memory pool' that is 'mimalloc' by default: it is fastest but use a lot of ram
            """
            import pyarrow as pa

            try:
                pa.jemalloc_memory_pool()
                os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "jemalloc"
            except NotImplementedError:
                # jemalloc not built into this pyarrow wheel; use the plain
                # system allocator instead.
                pa.system_memory_pool()
                os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "system"
        except Exception:
            # pyarrow missing or misbehaving — keep its default pool.
            pass

    db.init()

    environment = config["environment"]
    if environment == "aws_marketplace":
        try:
            register_oauth_client()
        except Exception:
            logger.exception("Something went wrong during client register:")
    elif environment != "local":
        # Best-effort: cache AWS instance metadata when running on AWS.
        try:
            aws_meta_data = get_aws_meta_data()
            config.update({"aws_meta_data": aws_meta_data})
        except Exception:
            pass

    # Environment variable takes precedence over the --api CLI option.
    apis = os.getenv("MINDSDB_APIS") or config.cmd_args.api

    if apis is None:  # If "--api" option is not specified, start the default APIs
        api_arr = [TrunkProcessEnum.HTTP, TrunkProcessEnum.MYSQL]
    elif apis == "":  # If "--api=" (blank) is specified, don't start any APIs
        api_arr = []
    else:  # The user has provided a list of APIs to start
        api_arr = [TrunkProcessEnum(name) for name in apis.split(",")]

    logger.info(f"Version: {mindsdb_version}")
    logger.info(f"Configuration file: {config.config_path or 'absent'}")
    logger.info(f"Storage path: {config.paths['root']}")
    log.log_system_info(logger)
    logger.debug(f"User config: {config.user_config}")
    logger.debug(f"System config: {config.auto_config}")
    logger.debug(f"Env config: {config.env_config}")

    is_cloud = config.is_cloud
    unexisting_pids = clean_unlinked_process_marks()
    if not is_cloud:
        # Local-only startup housekeeping: migrate the internal DB and
        # reconcile model statuses with reality.
        try:
            from mindsdb.migrations import migrate

            migrate.migrate_to_head()
        except Exception:
            logger.exception("Failed to apply database migrations. This may prevent MindsDB from operating correctly:")

        validate_default_project()

        if len(unexisting_pids) > 0:
            set_error_model_status_by_pids(unexisting_pids)
        set_error_model_status_for_unfinished()
        create_permanent_integrations()

    clean_process_marks()

    # Get config values for APIs
    http_api_config = config.get("api", {}).get("http", {})
    mysql_api_config = config.get("api", {}).get("mysql", {})
    litellm_api_config = config.get("api", {}).get("litellm", {})
    # Static table of every subprocess MindsDB may run; `need_to_run` is
    # flipped on below for the ones actually requested.
    trunc_processes_struct = {
        TrunkProcessEnum.HTTP: TrunkProcessData(
            name=TrunkProcessEnum.HTTP.value,
            entrypoint=start_http,
            port=http_api_config["port"],
            args=(config.cmd_args.verbose,),
            restart_on_failure=http_api_config.get("restart_on_failure", False),
            max_restart_count=http_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=http_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
        TrunkProcessEnum.MYSQL: TrunkProcessData(
            name=TrunkProcessEnum.MYSQL.value,
            entrypoint=start_mysql,
            port=mysql_api_config["port"],
            args=(config.cmd_args.verbose,),
            restart_on_failure=mysql_api_config.get("restart_on_failure", False),
            max_restart_count=mysql_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=mysql_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
        TrunkProcessEnum.JOBS: TrunkProcessData(
            name=TrunkProcessEnum.JOBS.value, entrypoint=start_scheduler, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.TASKS: TrunkProcessData(
            name=TrunkProcessEnum.TASKS.value, entrypoint=start_tasks, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.ML_TASK_QUEUE: TrunkProcessData(
            name=TrunkProcessEnum.ML_TASK_QUEUE.value, entrypoint=start_ml_task_queue, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.LITELLM: TrunkProcessData(
            name=TrunkProcessEnum.LITELLM.value,
            entrypoint=start_litellm,
            port=litellm_api_config.get("port", 8000),
            args=(config.cmd_args.verbose,),
            restart_on_failure=litellm_api_config.get("restart_on_failure", False),
            max_restart_count=litellm_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=litellm_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
    }

    # Enable the explicitly requested APIs.
    for api_enum in api_arr:
        if api_enum in trunc_processes_struct:
            trunc_processes_struct[api_enum].need_to_run = True
        else:
            logger.error(f"ERROR: {api_enum} API is not a valid api in config")

    # Jobs/tasks schedulers run unless disabled in config; the ML task queue
    # consumer only runs when requested on the command line.
    if config["jobs"]["disable"] is False:
        trunc_processes_struct[TrunkProcessEnum.JOBS].need_to_run = True

    if config["tasks"]["disable"] is False:
        trunc_processes_struct[TrunkProcessEnum.TASKS].need_to_run = True

    if config.cmd_args.ml_task_queue_consumer is True:
        trunc_processes_struct[TrunkProcessEnum.ML_TASK_QUEUE].need_to_run = True

    create_pid_file(config)

499 

500 for trunc_process_data in trunc_processes_struct.values(): 

501 if trunc_process_data.started is True or trunc_process_data.need_to_run is False: 

502 continue 

503 start_process(trunc_process_data) 

504 # Set status for APIs without ports (they don't go through wait_api_start) 

505 if trunc_process_data.port is None: 

506 set_api_status(trunc_process_data.name, True) 

507 

508 atexit.register(close_api_gracefully, trunc_processes_struct=trunc_processes_struct) 

509 atexit.register(clean_mindsdb_tmp_dir) 

510 

511 async def wait_api_start(api_name, pid, port): 

512 timeout = 60 

513 start_time = time.time() 

514 started = is_pid_listen_port(pid, port) 

515 while (time.time() - start_time) < timeout and started is False: 

516 await asyncio.sleep(0.5) 

517 started = is_pid_listen_port(pid, port) 

518 

519 set_api_status(api_name, started) 

520 

521 return api_name, port, started 

522 

523 async def wait_apis_start(): 

524 futures = [ 

525 wait_api_start( 

526 trunc_process_data.name, 

527 trunc_process_data.process.pid, 

528 trunc_process_data.port, 

529 ) 

530 for trunc_process_data in trunc_processes_struct.values() 

531 if trunc_process_data.port is not None and trunc_process_data.need_to_run is True 

532 ] 

533 for future in asyncio.as_completed(futures): 

534 api_name, port, started = await future 

535 if started: 

536 logger.info(f"{api_name} API: started on {port}") 

537 else: 

538 logger.error(f"ERROR: {api_name} API cant start on {port}") 

539 

540 async def join_process(trunc_process_data: TrunkProcessData): 

541 finish = False 

542 while not finish: 

543 process = trunc_process_data.process 

544 try: 

545 while process.is_alive(): 

546 process.join(1) 

547 await asyncio.sleep(0) 

548 except KeyboardInterrupt: 

549 logger.info("Got keyboard interrupt, stopping APIs") 

550 close_api_gracefully(trunc_processes_struct) 

551 finally: 

552 if trunc_process_data.should_restart: 

553 if trunc_process_data.request_restart_attempt(): 

554 logger.warning(f"{trunc_process_data.name} API: stopped unexpectedly, restarting") 

555 trunc_process_data.process = None 

556 if trunc_process_data.name == TrunkProcessEnum.HTTP.value: 

557 # do not open GUI on HTTP API restart 

558 trunc_process_data.args = (config.cmd_args.verbose, None, True) 

559 start_process(trunc_process_data) 

560 api_name, port, started = await wait_api_start( 

561 trunc_process_data.name, 

562 trunc_process_data.process.pid, 

563 trunc_process_data.port, 

564 ) 

565 if started: 

566 logger.info(f"{api_name} API: started on {port}") 

567 else: 

568 logger.error(f"ERROR: {api_name} API cant start on {port}") 

569 else: 

570 finish = True 

571 logger.error( 

572 f'The "{trunc_process_data.name}" process could not restart after failure. ' 

573 "There will be no further attempts to restart." 

574 ) 

575 else: 

576 finish = True 

577 logger.info(f"{trunc_process_data.name} API: stopped") 

578 

579 async def gather_apis(): 

580 await asyncio.gather( 

581 *[ 

582 join_process(trunc_process_data) 

583 for trunc_process_data in trunc_processes_struct.values() 

584 if trunc_process_data.need_to_run is True 

585 ], 

586 return_exceptions=False, 

587 ) 

588 

589 ioloop = asyncio.new_event_loop() 

590 ioloop.run_until_complete(wait_apis_start()) 

591 

592 threading.Thread(target=do_clean_process_marks, name="clean_process_marks").start() 

593 if config["logging"]["resources_log"]["enabled"] is True: 

594 threading.Thread( 

595 target=log.resources_log_thread, 

596 args=(_stop_event, config["logging"]["resources_log"]["interval"]), 

597 name="resources_log", 

598 ).start() 

599 

600 ioloop.run_until_complete(gather_apis()) 

601 ioloop.close()