Coverage for mindsdb/integrations/libs/process_cache.py: 60%
200 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import time
2import threading
3import traceback
4from typing import Optional, Callable
5from concurrent.futures import ProcessPoolExecutor, Future
7from pandas import DataFrame
9import mindsdb.interfaces.storage.db as db
10from mindsdb.utilities.config import Config
11from mindsdb.utilities.context import context as ctx
12from mindsdb.utilities.ml_task_queue.const import ML_TASK_TYPE
13from mindsdb.integrations.libs.ml_handler_process import (
14 learn_process,
15 update_process,
16 predict_process,
17 describe_process,
18 create_engine_process,
19 update_engine_process,
20 create_validation_process,
21 func_call_process,
22)
def init_ml_handler(module_path):
    """Worker-process initializer: prepare a fresh process for ML tasks.

    Imports the shared ML handler machinery, initializes the DB connection
    and pre-imports the concrete handler module so the first real task does
    not pay the import cost.

    Args:
        module_path (str): dotted path of the handler module to preload
    """
    import importlib  # noqa

    import mindsdb.integrations.libs.ml_handler_process  # noqa

    db.init()
    importlib.import_module(module_path)
def dummy_task():
    """No-op task; submitted once per pool to force the worker to spawn."""
    return None
def empty_callback(_task):
    """Done-callback that deliberately ignores its future."""
    return None
42class MLProcessException(Exception):
43 """Wrapper for exception to safely send it back to the main process.
45 If exception can not be pickled (pickle.loads(pickle.dumps(e))) then it may lead to termination of the ML process.
46 Also in this case, the error sent to the user will not be relevant. This wrapper should prevent it.
47 """
49 base_exception_bytes: bytes = None
51 def __init__(self, base_exception: Exception, message: str = None) -> None:
52 super().__init__(message)
53 traceback_text = "\n".join(traceback.format_exception(base_exception))
54 self.message = f"{base_exception.__class__.__name__}: {base_exception}\n{traceback_text}"
56 @property
57 def base_exception(self) -> Exception:
58 return RuntimeError(self.message)
class WarmProcess:
    """Wrapper for a worker process that persists for a long time.

    The process may be initialized with any handler requirements.
    The implementation is based on ProcessPoolExecutor because
    multiprocessing.pool produces daemon processes, which cannot be used
    for learning; that behaviour can only be changed via inheritance.
    """

    def __init__(self, initializer: Optional[Callable] = None, initargs: tuple = ()):
        """Create and initialize a new process.

        Args:
            initializer (Callable): same as ProcessPoolExecutor initializer
            initargs (tuple): same as ProcessPoolExecutor initargs
        """
        self.pool = ProcessPoolExecutor(1, initializer=initializer, initargs=initargs)
        self.last_usage_at = time.time()
        self._markers = set()
        # ProcessPoolExecutor does not spawn its worker until it receives a
        # task, so submit a dummy task to force initialization immediately.
        self.task = self.pool.submit(dummy_task)
        self._init_done = False
        self.task.add_done_callback(self._init_done_callback)

    def __del__(self):
        self.shutdown()

    def shutdown(self, wait: bool = False) -> None:
        """Same as ProcessPoolExecutor.shutdown.

        Args:
            wait (bool): if True, do not return until running futures finish
        """
        self.pool.shutdown(wait=wait)

    def _init_done_callback(self, _task):
        """Mark the warm-up task as finished."""
        self._init_done = True

    def _update_last_usage_at_callback(self, _task):
        """Record the completion time of the latest task."""
        self.last_usage_at = time.time()

    def ready(self) -> bool:
        """Report whether the process can accept a new task.

        Returns:
            bool
        """
        if not self._init_done:
            # Block until the warm-up task completes.
            self.task.result()
            self._init_done = True
        return self.task is None or self.task.done()

    def add_marker(self, marker: tuple):
        """Remember that this process handled a task for the given model.

        Args:
            marker (tuple): identifier of model
        """
        if marker is None:
            return
        self._markers.add(marker)

    def has_marker(self, marker: tuple) -> bool:
        """Check whether this process handled a task for the given model.

        Args:
            marker (tuple): identifier of model

        Returns:
            bool
        """
        return False if marker is None else marker in self._markers

    def is_marked(self) -> bool:
        """Check whether the process has any marker.

        Returns:
            bool
        """
        return len(self._markers) != 0

    def apply_async(self, func: Callable, *args: tuple, **kwargs: dict) -> Future:
        """Run a new task in the worker process.

        Args:
            func (Callable): function to run
            args (tuple): positional arguments for the function
            kwargs (dict): keyword arguments for the function

        Returns:
            Future
        """
        if not self.ready():
            raise Exception("Process task is not ready")
        future = self.pool.submit(func, *args, **kwargs)
        future.add_done_callback(self._update_last_usage_at_callback)
        self.task = future
        self.last_usage_at = time.time()
        return future
def warm_function(func, context: str, *args, **kwargs):
    """Entry point executed inside a warm process: restore ctx, then run the task.

    Args:
        func (Callable): the task function to execute
        context (str): serialized context restored via ctx.load
        *args: forwarded to func
        **kwargs: forwarded to func

    Returns:
        whatever func returns

    Raises:
        ImportError: re-raised as-is (includes ModuleNotFoundError) so the
            caller can detect missing handler dependencies
        MLProcessException: any other error, wrapped so it pickles safely
            when sent back to the main process
    """
    ctx.load(context)
    try:
        return func(*args, **kwargs)
    except Exception as e:
        # isinstance (instead of an exact type() match) also covers
        # ModuleNotFoundError and any other ImportError subclass.
        if isinstance(e, ImportError):
            raise
        raise MLProcessException(base_exception=e)
class ProcessCache:
    """simple cache for WarmProcess-es"""

    def __init__(self, ttl: int = 120):
        """Args:
        ttl (int) time to live for unused process
        """
        # {ml_engine_name: {"last_usage_at": float|None, "handler_module": str, "processes": [WarmProcess, ...]}}
        self.cache = {}
        self._init = False
        self._lock = threading.Lock()
        self._ttl = ttl
        # {handler_name: number of processes the cleaner keeps alive at all times}
        self._keep_alive = {}
        self._stop_event = threading.Event()
        self.cleaner_thread = None

    def __del__(self):
        self._stop_clean()

    def _start_clean(self) -> None:
        """start worker that close connections after ttl expired"""
        # idempotent: do nothing if the cleaner thread is already running
        if isinstance(self.cleaner_thread, threading.Thread) and self.cleaner_thread.is_alive():
            return
        self._stop_event.clear()
        self.cleaner_thread = threading.Thread(target=self._clean, name="ProcessCache.clean")
        self.cleaner_thread.daemon = True
        self.cleaner_thread.start()

    def _stop_clean(self) -> None:
        """stop clean worker"""
        self._stop_event.set()

    def init(self):
        """run processes for specified handlers"""
        from mindsdb.interfaces.database.integrations import integration_controller

        preload_handlers = {}
        config = Config()
        is_cloud = config.get("cloud", False)  # noqa

        # preloading is skipped entirely when the redis ML task queue is configured
        if config["ml_task_queue"]["type"] != "redis":
            if is_cloud:
                lightwood_handler = integration_controller.get_handler_module("lightwood")
                if lightwood_handler is not None and lightwood_handler.Handler is not None:
                    # NOTE(review): the 'else 1' branch is unreachable under the
                    # is_cloud guard above - confirm the intended nesting
                    preload_handlers[lightwood_handler.Handler] = 4 if is_cloud else 1

                huggingface_handler = integration_controller.get_handler_module("huggingface")
                if huggingface_handler is not None and huggingface_handler.Handler is not None:
                    preload_handlers[huggingface_handler.Handler] = 1

                openai_handler = integration_controller.get_handler_module("openai")
                if openai_handler is not None and openai_handler.Handler is not None:
                    preload_handlers[openai_handler.Handler] = 1

        with self._lock:
            # guard against double init; only the first call spawns processes
            if self._init is False:
                self._init = True
                for handler in preload_handlers:
                    self._keep_alive[handler.name] = preload_handlers[handler]
                    self.cache[handler.name] = {
                        "last_usage_at": time.time(),
                        "handler_module": handler.__module__,
                        "processes": [
                            WarmProcess(init_ml_handler, (handler.__module__,))
                            for _x in range(preload_handlers[handler])
                        ],
                    }

    def apply_async(
        self, task_type: ML_TASK_TYPE, model_id: Optional[int], payload: dict, dataframe: Optional[DataFrame] = None
    ) -> Future:
        """run new task. If possible - do it in existing process, if not - start new one.

        Args:
            task_type (ML_TASK_TYPE): type of the task (learn, predict, etc)
            model_id (int): id of the model
            payload (dict): any 'lightweight' data that needs to be send in the process
            dataframe (DataFrame): DataFrame to be send in the process

        Returns:
            Future
        """
        self._start_clean()
        handler_module_path = payload["handler_meta"]["module_path"]
        integration_id = payload["handler_meta"]["integration_id"]
        # map the task type to the target process function and its kwargs
        if task_type in (ML_TASK_TYPE.LEARN, ML_TASK_TYPE.FINETUNE):
            func = learn_process
            kwargs = {
                "data_integration_ref": payload["data_integration_ref"],
                "problem_definition": payload["problem_definition"],
                "fetch_data_query": payload["fetch_data_query"],
                "project_name": payload["project_name"],
                "model_id": model_id,
                "base_model_id": payload.get("base_model_id"),
                "set_active": payload["set_active"],
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.PREDICT:
            func = predict_process
            kwargs = {
                "predictor_record": payload["predictor_record"],
                "ml_engine_name": payload["handler_meta"]["engine"],
                "args": payload["args"],
                "dataframe": dataframe,
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.DESCRIBE:
            func = describe_process
            kwargs = {
                "attribute": payload.get("attribute"),
                "model_id": model_id,
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.CREATE_VALIDATION:
            func = create_validation_process
            kwargs = {
                "target": payload.get("target"),
                "args": payload.get("args"),
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.CREATE_ENGINE:
            func = create_engine_process
            kwargs = {
                "connection_args": payload["connection_args"],
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.UPDATE_ENGINE:
            func = update_engine_process
            kwargs = {
                "connection_args": payload["connection_args"],
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.UPDATE:
            func = update_process
            kwargs = {
                "args": payload["args"],
                "integration_id": integration_id,
                "model_id": model_id,
                "module_path": handler_module_path,
            }
        elif task_type == ML_TASK_TYPE.FUNC_CALL:
            func = func_call_process
            kwargs = {
                "name": payload["name"],
                "args": payload["args"],
                "integration_id": integration_id,
                "module_path": handler_module_path,
            }
        else:
            raise Exception(f"Unknown ML task type: {task_type}")

        ml_engine_name = payload["handler_meta"]["engine"]
        # marker ties a process to a (model, company) pair for affinity reuse
        model_marker = (model_id, payload["context"]["company_id"])
        with self._lock:
            if ml_engine_name not in self.cache:
                # first task for this engine: spawn a fresh warm process
                warm_process = WarmProcess(init_ml_handler, (handler_module_path,))
                self.cache[ml_engine_name] = {
                    "last_usage_at": None,
                    "handler_module": handler_module_path,
                    "processes": [warm_process],
                }
            else:
                warm_process = None
                # 1) prefer a ready process that already handled this model
                if model_marker is not None:
                    try:
                        warm_process = next(
                            p
                            for p in self.cache[ml_engine_name]["processes"]
                            if p.ready() and p.has_marker(model_marker)
                        )
                    except StopIteration:
                        pass
                # 2) fall back to any ready process for this engine
                if warm_process is None:
                    try:
                        warm_process = next(p for p in self.cache[ml_engine_name]["processes"] if p.ready())
                    except StopIteration:
                        pass
                # 3) all busy: grow the pool with a new process
                if warm_process is None:
                    warm_process = WarmProcess(init_ml_handler, (handler_module_path,))
                    self.cache[ml_engine_name]["processes"].append(warm_process)

            task = warm_process.apply_async(warm_function, func, payload["context"], **kwargs)
            self.cache[ml_engine_name]["last_usage_at"] = time.time()
            warm_process.add_marker(model_marker)
        return task

    def _clean(self) -> None:
        """worker that stop unused processes"""
        # wake every 10 seconds until _stop_event is set
        while self._stop_event.wait(timeout=10) is False:
            with self._lock:
                for handler_name in self.cache.keys():
                    processes = self.cache[handler_name]["processes"]
                    # never-used (unmarked) processes sort to the front
                    processes.sort(key=lambda x: x.is_marked())

                    expected_count = 0
                    if handler_name in self._keep_alive:
                        expected_count = self._keep_alive[handler_name]

                    # stop processes which was used, it needs to free memory
                    for i, process in enumerate(processes):
                        if (
                            process.ready()
                            and process.is_marked()
                            and (time.time() - process.last_usage_at) > self._ttl
                        ):
                            # pop-then-break keeps the enumerate safe: at most
                            # one process is removed per pass
                            processes.pop(i)
                            # del process
                            process.shutdown()
                            break

                    # refill up to the keep-alive minimum for preloaded handlers
                    while expected_count > len(processes):
                        processes.append(WarmProcess(init_ml_handler, (self.cache[handler_name]["handler_module"],)))

    def shutdown(self, wait: bool = True) -> None:
        """Call 'shutdown' for each process cache

        wait (bool): like ProcessPoolExecutor.shutdown wait arg.
        """
        with self._lock:
            for handler_name in self.cache:
                for process in self.cache[handler_name]["processes"]:
                    process.shutdown(wait=wait)
                self.cache[handler_name]["processes"] = []

    def remove_processes_for_handler(self, handler_name: str) -> None:
        """
        Remove all warm processes for a given handler.
        This is useful when the previous processes use an outdated instance of the handler.
        A good example is when the dependencies for a handler are installed after attempting to use the handler.

        Args:
            handler_name (str): name of the handler.
        """
        with self._lock:
            if handler_name in self.cache:
                for process in self.cache[handler_name]["processes"]:
                    process.shutdown()

                self.cache[handler_name]["processes"] = []
# Module-level singleton: the shared cache of warm ML worker processes.
process_cache = ProcessCache()