Coverage for mindsdb / utilities / ml_task_queue / consumer.py: 0%
159 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import time
3import signal
4import tempfile
5import threading
6from pathlib import Path
7from functools import wraps
8from collections.abc import Callable
10import psutil
11from walrus import Database
12from pandas import DataFrame
13from redis.exceptions import ConnectionError as RedisConnectionError
15from mindsdb.utilities.config import Config
16from mindsdb.utilities.context import context as ctx
17from mindsdb.integrations.libs.process_cache import process_cache
18from mindsdb.utilities.ml_task_queue.utils import RedisKey, StatusNotifier, to_bytes, from_bytes
19from mindsdb.utilities.ml_task_queue.base import BaseRedisQueue
20from mindsdb.utilities.fs import clean_unlinked_process_marks
21from mindsdb.utilities.functions import mark_process
22from mindsdb.utilities.ml_task_queue.const import (
23 ML_TASK_TYPE,
24 ML_TASK_STATUS,
25 TASKS_STREAM_NAME,
26 TASKS_STREAM_CONSUMER_NAME,
27 TASKS_STREAM_CONSUMER_GROUP_NAME,
28)
29from mindsdb.utilities import log
30from mindsdb.utilities.sentry import sentry_sdk # noqa: F401
32logger = log.getLogger(__name__)
def _save_thread_link(func: Callable) -> Callable:
    """Decorator for MLTaskConsumer methods.

    Registers the thread in which ``func`` executes in the instance's
    ``_listen_message_threads`` list for the duration of the call, so
    ``MLTaskConsumer.stop()`` can join all in-flight listener threads.

    Args:
        func (Callable): bound-style method taking ``self`` as first argument

    Returns:
        Callable: wrapped method that forwards ``func``'s return value
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # NOTE: the original annotated this as `-> None`, but it returns
        # whatever `func` returns — the annotation was wrong and is removed.
        current_thread = threading.current_thread()
        self._listen_message_threads.append(current_thread)
        try:
            result = func(self, *args, **kwargs)
        finally:
            # Always unregister, even if `func` raised.
            self._listen_message_threads.remove(current_thread)
        return result

    return wrapper
class MLTaskConsumer(BaseRedisQueue):
    """Listener of the ML tasks queue and task executor.

    Each new message is waited for and executed in a separate thread.

    Attributes:
        _ready_event (Event): set if ready to start a new queue-listen thread
        _stop_event (Event): set if all threads/processes need to stop
        cpu_stat (list[float]): CPU usage statistic; each value is a 0-100 float
            representing CPU usage in %
        _collect_cpu_stat_thread (Thread): thread that collects the CPU usage statistic
        _listen_message_threads (list[Thread]): threads in which queue messages are
            being listened for / processed
        db (Redis): database object
        cache: redis cache abstraction
        consumer_group: redis consumer group object
    """

    def __init__(self) -> None:
        """Connect to Redis, start the CPU statistics thread and prepare the consumer group.

        NOTE(review): presumably raises if Redis is unreachable for 60s — see
        `wait_redis_ping`; confirm against BaseRedisQueue.
        """
        self._ready_event = threading.Event()
        self._ready_event.set()

        self._stop_event = threading.Event()
        self._stop_event.clear()

        process_cache.init()

        # region collect cpu usage statistic
        # Sliding window of the last 10 one-second samples; pre-filled with zeros.
        self.cpu_stat = [0] * 10
        self._collect_cpu_stat_thread = threading.Thread(
            target=self._collect_cpu_stat, name="MLTaskConsumer._collect_cpu_stat"
        )
        self._collect_cpu_stat_thread.start()
        # endregion

        self._listen_message_threads = []

        # region connect to redis
        config = Config().get("ml_task_queue", {})
        self.db = Database(
            host=config.get("host", "localhost"),
            port=config.get("port", 6379),
            db=config.get("db", 0),
            username=config.get("username"),
            password=config.get("password"),
            protocol=3,
        )
        self.wait_redis_ping(60)

        self.db.Stream(TASKS_STREAM_NAME)
        self.cache = self.db.cache()
        self.consumer_group = self.db.consumer_group(TASKS_STREAM_CONSUMER_GROUP_NAME, [TASKS_STREAM_NAME])
        self.consumer_group.create()
        self.consumer_group.consumer(TASKS_STREAM_CONSUMER_NAME)
        # endregion

    def _collect_cpu_stat(self) -> None:
        """Collect the CPU usage statistic. Executed in a thread.

        Keeps a rolling window of the last 10 samples in ``self.cpu_stat``,
        one sample per second, until ``_stop_event`` is set.
        """
        while self._stop_event.is_set() is False:
            self.cpu_stat = self.cpu_stat[1:]
            self.cpu_stat.append(psutil.cpu_percent())
            time.sleep(1)

    def get_avg_cpu_usage(self) -> float:
        """Get average CPU usage for the last period (10s by default).

        Returns:
            float: 0-100 value, average CPU usage
        """
        return sum(self.cpu_stat) / len(self.cpu_stat)

    def wait_free_resources(self) -> None:
        """Sleep in the thread until there are free resources. Checks:
        - avg CPU usage is less than 60%
        - current CPU usage is less than 60%
        - current tasks count is less than (N CPU cores) / 8 (cloud only)
        """
        config = Config()
        is_cloud = config.get("cloud", False)
        processes_dir = Path(tempfile.gettempdir()).joinpath("mindsdb/processes/learn/")
        while True:
            # Wait until both the 10s average and the last 3 samples drop below 60%.
            while self.get_avg_cpu_usage() > 60 or max(self.cpu_stat[-3:]) > 60:
                time.sleep(1)
            if is_cloud and processes_dir.is_dir():
                clean_unlinked_process_marks()
                # One learn process mark per file; throttle to cpu_count/8 concurrent tasks.
                while (len(list(processes_dir.iterdir())) * 8) >= os.cpu_count():
                    time.sleep(1)
                    clean_unlinked_process_marks()
            # Re-check CPU: the mark-file wait above may have taken long enough
            # for the CPU condition to regress.
            if (self.get_avg_cpu_usage() > 60 or max(self.cpu_stat[-3:]) > 60) is False:
                return

    @_save_thread_link
    def _listen(self) -> None:
        """Listen to the message queue until a new message arrives, then execute the task.

        The task result (or exception) is published back through Redis using the
        task's ``redis_key``; status transitions are pushed via pub/sub and cache.
        """
        message = None
        while message is None:
            self.wait_free_resources()
            self.wait_redis_ping()
            if self._stop_event.is_set():
                return

            try:
                message = self.consumer_group.read(count=1, block=1000, consumer=TASKS_STREAM_CONSUMER_NAME)
            except RedisConnectionError:
                logger.exception("Can't connect to Redis:")
                self._stop_event.set()
                return
            except Exception:
                self._stop_event.set()
                raise

            # FIX: `read` may return None (or a dict without our stream) when the
            # `block` timeout elapses — the original called `.get` on it
            # unconditionally, which would raise AttributeError on None.
            if not message or not message.get(TASKS_STREAM_NAME):
                message = None

        try:
            message = message[TASKS_STREAM_NAME][0][0]
            message_id = message[0].decode()
            message_content = message[1]
            # Acknowledge and remove the message so no other consumer re-processes it.
            self.consumer_group.streams[TASKS_STREAM_NAME].ack(message_id)
            self.consumer_group.streams[TASKS_STREAM_NAME].delete(message_id)

            payload = from_bytes(message_content[b"payload"])
            task_type = ML_TASK_TYPE(message_content[b"task_type"])
            model_id = int(message_content[b"model_id"])
            company_id = message_content[b"company_id"]
            if len(company_id) == 0:
                company_id = None
            redis_key = RedisKey(message_content.get(b"redis_key"))

            # region read dataframe
            dataframe_bytes = self.cache.get(redis_key.dataframe)
            dataframe = None
            if dataframe_bytes is not None:
                dataframe = from_bytes(dataframe_bytes)
                self.cache.delete(redis_key.dataframe)
            # endregion

            ctx.load(payload["context"])
        finally:
            # Signal `run()` to spawn the next listener thread regardless of outcome.
            self._ready_event.set()

        # FIX: `status_notifier` was referenced in the `except` clause even when
        # `apply_async` raised before the notifier was created, which produced a
        # NameError that masked the original exception.
        status_notifier = None
        try:
            task = process_cache.apply_async(
                task_type=task_type, model_id=model_id, payload=payload, dataframe=dataframe
            )
            status_notifier = StatusNotifier(redis_key, ML_TASK_STATUS.PROCESSING, self.db, self.cache)
            status_notifier.start()
            result = task.result()
        except Exception as e:
            self.wait_redis_ping()
            if status_notifier is not None:
                status_notifier.stop()
            exception_bytes = to_bytes(e)
            self.cache.set(redis_key.exception, exception_bytes, 10)
            self.db.publish(redis_key.status, ML_TASK_STATUS.ERROR.value)
            self.cache.set(redis_key.status, ML_TASK_STATUS.ERROR.value, 180)
        else:
            self.wait_redis_ping()
            status_notifier.stop()
            if isinstance(result, DataFrame):
                dataframe_bytes = to_bytes(result)
                self.cache.set(redis_key.dataframe, dataframe_bytes, 10)
            self.db.publish(redis_key.status, ML_TASK_STATUS.COMPLETE.value)
            self.cache.set(redis_key.status, ML_TASK_STATUS.COMPLETE.value, 180)

    def run(self) -> None:
        """Start a new listen thread each time ``_ready_event`` is set."""
        self._ready_event.set()
        while self._stop_event.is_set() is False:
            self._ready_event.wait(timeout=1)
            if self._ready_event.is_set() is False:
                continue
            self._ready_event.clear()
            threading.Thread(target=self._listen, name="MLTaskConsumer._listen").start()
        self.stop()

    def stop(self) -> None:
        """Stop all executing threads and wait for them to finish."""
        self._stop_event.set()
        for thread in (*self._listen_message_threads, self._collect_cpu_stat_thread):
            try:
                if thread.is_alive():
                    thread.join()
            except Exception:
                # Best-effort shutdown: a thread may finish (and be removed from the
                # list) between the liveness check and the join.
                pass
@mark_process(name="internal", custom_mark="ml_task_consumer")
def start(verbose: bool) -> None:
    """Create the task queue consumer and start listening to the queue.

    Installs a SIGTERM handler that stops the consumer gracefully.

    Args:
        verbose (bool): unused here; kept for signature compatibility with
            other process entry points — TODO confirm against callers
    """
    consumer = MLTaskConsumer()
    signal.signal(signal.SIGTERM, lambda _x, _y: consumer.stop())
    try:
        consumer.run()
    except Exception:
        consumer.stop()
        # FIX: the original passed `flush=True` to logger.error/info — `flush`
        # is not a valid logging kwarg and raised TypeError, masking the real
        # exception. `logger.exception` also records the traceback.
        logger.exception("Got exception:")
        raise
    finally:
        logger.info("Consumer process stopped")