Coverage for mindsdb / utilities / ml_task_queue / consumer.py: 0%

159 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import time 

3import signal 

4import tempfile 

5import threading 

6from pathlib import Path 

7from functools import wraps 

8from collections.abc import Callable 

9 

10import psutil 

11from walrus import Database 

12from pandas import DataFrame 

13from redis.exceptions import ConnectionError as RedisConnectionError 

14 

15from mindsdb.utilities.config import Config 

16from mindsdb.utilities.context import context as ctx 

17from mindsdb.integrations.libs.process_cache import process_cache 

18from mindsdb.utilities.ml_task_queue.utils import RedisKey, StatusNotifier, to_bytes, from_bytes 

19from mindsdb.utilities.ml_task_queue.base import BaseRedisQueue 

20from mindsdb.utilities.fs import clean_unlinked_process_marks 

21from mindsdb.utilities.functions import mark_process 

22from mindsdb.utilities.ml_task_queue.const import ( 

23 ML_TASK_TYPE, 

24 ML_TASK_STATUS, 

25 TASKS_STREAM_NAME, 

26 TASKS_STREAM_CONSUMER_NAME, 

27 TASKS_STREAM_CONSUMER_GROUP_NAME, 

28) 

29from mindsdb.utilities import log 

30from mindsdb.utilities.sentry import sentry_sdk # noqa: F401 

31 

32logger = log.getLogger(__name__) 

33 

34 

35def _save_thread_link(func: Callable) -> Callable: 

36 """Decorator for MLTaskConsumer. 

37 Save thread in which func is executed to a list. 

38 """ 

39 

40 @wraps(func) 

41 def wrapper(self, *args, **kwargs) -> None: 

42 current_thread = threading.current_thread() 

43 self._listen_message_threads.append(current_thread) 

44 try: 

45 result = func(self, *args, **kwargs) 

46 finally: 

47 self._listen_message_threads.remove(current_thread) 

48 return result 

49 

50 return wrapper 

51 

52 

class MLTaskConsumer(BaseRedisQueue):
    """Listener of the ML tasks queue and task executioner.

    Each new message is awaited and executed in a separate thread.

    Attributes:
        _ready_event (Event): set when ready to start a new queue-listen thread
        _stop_event (Event): set when all threads/processes must stop
        cpu_stat (list[float]): CPU usage statistics; each value is a 0-100 float (CPU usage in %)
        _collect_cpu_stat_thread (Thread): thread that collects the CPU usage statistics
        _listen_message_threads (list[Thread]): threads in which queue messages are listened to/processed
        db (Redis): database object
        cache: redis cache abstraction
        consumer_group: redis consumer group object
    """

    def __init__(self) -> None:
        # Ready immediately: run() may spawn the first listener thread at once.
        self._ready_event = threading.Event()
        self._ready_event.set()

        self._stop_event = threading.Event()
        self._stop_event.clear()

        process_cache.init()

        # region collect cpu usage statistic
        # Rolling window of the last 10 one-second samples.
        self.cpu_stat = [0] * 10
        self._collect_cpu_stat_thread = threading.Thread(
            target=self._collect_cpu_stat, name="MLTaskConsumer._collect_cpu_stat"
        )
        self._collect_cpu_stat_thread.start()
        # endregion

        self._listen_message_threads = []

        # region connect to redis
        config = Config().get("ml_task_queue", {})
        self.db = Database(
            host=config.get("host", "localhost"),
            port=config.get("port", 6379),
            db=config.get("db", 0),
            username=config.get("username"),
            password=config.get("password"),
            protocol=3,
        )
        # Block up to 60s waiting for the Redis server to answer PING.
        self.wait_redis_ping(60)

        self.db.Stream(TASKS_STREAM_NAME)
        self.cache = self.db.cache()
        self.consumer_group = self.db.consumer_group(TASKS_STREAM_CONSUMER_GROUP_NAME, [TASKS_STREAM_NAME])
        self.consumer_group.create()
        self.consumer_group.consumer(TASKS_STREAM_CONSUMER_NAME)
        # endregion

    def _collect_cpu_stat(self) -> None:
        """Collect CPU usage statistics once per second. Executed in a thread."""
        while self._stop_event.is_set() is False:
            # Slide the window: drop the oldest sample, append a fresh one.
            self.cpu_stat = self.cpu_stat[1:]
            self.cpu_stat.append(psutil.cpu_percent())
            time.sleep(1)

    def get_avg_cpu_usage(self) -> float:
        """Get average CPU usage for the last period (10s by default).

        Returns:
            float: 0-100 value, average CPU usage
        """
        return sum(self.cpu_stat) / len(self.cpu_stat)

    def wait_free_resources(self) -> None:
        """Sleep in the calling thread until there are free resources. Checks:
        - avg CPU usage is less than 60%
        - current CPU usage is less than 60%
        - current tasks count is less than (N CPU cores) / 8
        """
        config = Config()
        is_cloud = config.get("cloud", False)
        processes_dir = Path(tempfile.gettempdir()).joinpath("mindsdb/processes/learn/")
        while True:
            # Wait out both sustained (10s average) and recent (last 3 samples) CPU spikes.
            while self.get_avg_cpu_usage() > 60 or max(self.cpu_stat[-3:]) > 60:
                time.sleep(1)
            if is_cloud and processes_dir.is_dir():
                clean_unlinked_process_marks()
                # Throttle: allow at most cpu_count/8 concurrent learn processes.
                while (len(list(processes_dir.iterdir())) * 8) >= os.cpu_count():
                    time.sleep(1)
                    clean_unlinked_process_marks()
            # Re-check CPU after the (possibly long) process-count wait; only
            # return when both conditions hold simultaneously.
            if (self.get_avg_cpu_usage() > 60 or max(self.cpu_stat[-3:]) > 60) is False:
                return

    @_save_thread_link
    def _listen(self) -> None:
        """Listen to the message queue until a new message arrives, then execute the task."""
        message = None
        while message is None:
            self.wait_free_resources()
            self.wait_redis_ping()
            if self._stop_event.is_set():
                return

            try:
                message = self.consumer_group.read(count=1, block=1000, consumer=TASKS_STREAM_CONSUMER_NAME)
            except RedisConnectionError:
                logger.exception("Can't connect to Redis:")
                self._stop_event.set()
                return
            except Exception:
                self._stop_event.set()
                raise

            # Empty read (block timeout): keep polling.
            if message.get(TASKS_STREAM_NAME) is None or len(message.get(TASKS_STREAM_NAME)) == 0:
                message = None

        try:
            message = message[TASKS_STREAM_NAME][0][0]
            message_id = message[0].decode()
            message_content = message[1]
            # Ack and delete so the message is not redelivered to the group.
            self.consumer_group.streams[TASKS_STREAM_NAME].ack(message_id)
            self.consumer_group.streams[TASKS_STREAM_NAME].delete(message_id)

            payload = from_bytes(message_content[b"payload"])
            task_type = ML_TASK_TYPE(message_content[b"task_type"])
            model_id = int(message_content[b"model_id"])
            company_id = message_content[b"company_id"]
            if len(company_id) == 0:
                company_id = None
            redis_key = RedisKey(message_content.get(b"redis_key"))

            # region read dataframe
            dataframe_bytes = self.cache.get(redis_key.dataframe)
            dataframe = None
            if dataframe_bytes is not None:
                dataframe = from_bytes(dataframe_bytes)
                self.cache.delete(redis_key.dataframe)
            # endregion

            ctx.load(payload["context"])
        finally:
            # Signal run() to spawn the next listener thread regardless of
            # whether message parsing succeeded.
            self._ready_event.set()

        # Pre-bind so the except handler can tell whether the notifier was
        # created. Previously, if apply_async() (or the StatusNotifier
        # constructor) raised, `status_notifier.stop()` raised NameError,
        # masking the real exception and skipping the ERROR status publish.
        status_notifier = None
        try:
            task = process_cache.apply_async(
                task_type=task_type, model_id=model_id, payload=payload, dataframe=dataframe
            )
            status_notifier = StatusNotifier(redis_key, ML_TASK_STATUS.PROCESSING, self.db, self.cache)
            status_notifier.start()
            result = task.result()
        except Exception as e:
            self.wait_redis_ping()
            if status_notifier is not None:
                status_notifier.stop()
            # Ship the exception back to the producer via the cache, then
            # publish/record the ERROR status.
            exception_bytes = to_bytes(e)
            self.cache.set(redis_key.exception, exception_bytes, 10)
            self.db.publish(redis_key.status, ML_TASK_STATUS.ERROR.value)
            self.cache.set(redis_key.status, ML_TASK_STATUS.ERROR.value, 180)
        else:
            self.wait_redis_ping()
            status_notifier.stop()
            if isinstance(result, DataFrame):
                dataframe_bytes = to_bytes(result)
                self.cache.set(redis_key.dataframe, dataframe_bytes, 10)
            self.db.publish(redis_key.status, ML_TASK_STATUS.COMPLETE.value)
            self.cache.set(redis_key.status, ML_TASK_STATUS.COMPLETE.value, 180)

    def run(self) -> None:
        """Start a new listen thread each time _ready_event is set."""
        self._ready_event.set()
        while self._stop_event.is_set() is False:
            self._ready_event.wait(timeout=1)
            if self._ready_event.is_set() is False:
                # Timed out: loop to re-check the stop flag.
                continue
            self._ready_event.clear()
            threading.Thread(target=self._listen, name="MLTaskConsumer._listen").start()
        self.stop()

    def stop(self) -> None:
        """Stop all executing threads and wait for them to finish."""
        self._stop_event.set()
        for thread in (*self._listen_message_threads, self._collect_cpu_stat_thread):
            try:
                if thread.is_alive():
                    thread.join()
            except Exception:
                # Best effort: a thread may finish (and deregister) between checks.
                pass

234 

235 

@mark_process(name="internal", custom_mark="ml_task_consumer")
def start(verbose: bool) -> None:
    """Create a task queue consumer and start listening to the queue.

    Blocks until the consumer stops (SIGTERM or fatal error).

    Args:
        verbose: unused; kept for signature compatibility with other process
            entry points.
    """
    consumer = MLTaskConsumer()
    # Stop gracefully on SIGTERM so running threads can be joined.
    signal.signal(signal.SIGTERM, lambda _x, _y: consumer.stop())
    try:
        consumer.run()
    except Exception as e:
        consumer.stop()
        # Fixed: logging methods do not accept `flush=` (that is a print()
        # kwarg); passing it raised TypeError here and masked the original
        # exception before the re-raise.
        logger.error(f"Got exception: {e}")
        raise
    finally:
        logger.info("Consumer process stopped")