Coverage for mindsdb / utilities / cache.py: 24%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1""" 

2How to use it: 

3 

4 from mindsdb.utilities.cache import get_cache, dataframe_checksum, json_checksum 

5 

6 # namespace of cache 

7 cache = get_cache('predict') 

8 

9 key = dataframe_checksum(df) # or json_checksum, depends on object type 

10 df_predict = cache(key) 

11 

12 if df_predict is None: 

13 # no cache, save it 

14 df_predict = predictor.predict(df) 

15 cache.set(key, df_predict) 

16 

17 

18 

19Configuration: 

20 

21- max_size — maximum size of the cache, in count of records; default is 500 

22- serializer, module for serialization, default is dill 

23 

24It can be set via: 

25- get_cache function: 

26 cache = get_cache('predict', max_size=2) 

27- using specific cache class: 

28 cache = FileCache('predict', max_size=2) 

29- using mindsdb config file: 

30 "cache": { 

31 "type": "redis", 

32 "max_size": 2 

33 } 

34 

35Cache engines: 

36 

37Can be specified in mindsdb config json. Possible values: 

38- local - for FileCache, default 

39- redis - for RedisCache 

40By default a local redis server is used. You can specify the connection: 

41 "cache": { 

42 "type": "redis", 

43 "connection": { 

44 "host": "127.0.0.1", 

45 "port": 6379 

46 } 

47 } 

48 

49How to test: 

50 

51 env PYTHONPATH=./ pytest tests/unit/test_cache.py 

52 

53""" 

54 

55import os 

56import time 

57from abc import ABC 

58from pathlib import Path 

59import re 

60import hashlib 

61import typing as t 

62 

63import pandas as pd 

64import walrus 

65 

66from mindsdb.utilities.config import Config 

67from mindsdb.utilities.json_encoder import CustomJSONEncoder 

68from mindsdb.interfaces.storage.fs import FileLock 

69from mindsdb.utilities.context import context as ctx 

70 

71_CACHE_MAX_SIZE = 500 

72 

73 

def dataframe_checksum(df: pd.DataFrame):
    """Return a sha256 hex digest of a dataframe's values.

    Column names are deliberately excluded from the checksum: they are
    temporarily replaced with positional integers, so two dataframes with the
    same data but different column names produce the same key.

    :param df: dataframe to hash; its columns are restored before returning
    :return: hex digest string
    """
    original_columns = df.columns
    df.columns = list(range(len(df.columns)))
    try:
        # NOTE(review): str(df.values) uses numpy's repr, which truncates very
        # large arrays with '...' — collisions are possible for big frames.
        # Kept as-is to preserve existing cache keys; confirm before changing.
        result = hashlib.sha256(
            str(df.values).encode()
        ).hexdigest()
    finally:
        # BUGFIX: restore the caller's columns even if hashing raises,
        # so an exception no longer leaves the dataframe renamed
        df.columns = original_columns
    return result

82 

83 

def json_checksum(obj: t.Union[dict, list]):
    """Return a sha256 hex digest of a JSON-serializable object.

    The object is first encoded with the project's CustomJSONEncoder so that
    equal structures always serialize to the same string.
    """
    return str_checksum(CustomJSONEncoder().encode(obj))

87 

88 

def str_checksum(obj: str):
    """Return the sha256 hex digest of a string."""
    digest = hashlib.sha256()
    digest.update(obj.encode())
    return digest.hexdigest()

92 

93 

class BaseCache(ABC):
    """Common behavior for cache backends: size limit and (de)serialization.

    :param max_size: maximum number of cached records; falls back to the
        "cache.max_size" config value, then to _CACHE_MAX_SIZE
    :param serializer: module-like object exposing dumps/loads; falls back to
        the "cache.serializer" config value ('pickle' selects pickle,
        anything else selects dill)
    """

    def __init__(self, max_size=None, serializer=None):
        self.config = Config()
        if max_size is None:
            max_size = self.config["cache"].get("max_size", _CACHE_MAX_SIZE)
        self.max_size = max_size
        if serializer is None:
            serializer_module = self.config["cache"].get('serializer')
            if serializer_module == 'pickle':
                import pickle as serializer
            else:
                import dill as serializer
        # BUGFIX: previously self.serializer was only assigned inside the
        # `serializer is None` branch, so a caller-supplied serializer was
        # silently ignored and the attribute was left unset
        self.serializer = serializer

    # default implementations; subclasses may override with storage-specific
    # dataframe handling (see FileCache.set_df / get_df)

    def set_df(self, name, df):
        """Store a dataframe under the given key (defaults to set())."""
        return self.set(name, df)

    def get_df(self, name):
        """Load a dataframe by key (defaults to get())."""
        return self.get(name)

    def serialize(self, value):
        """Turn a value into bytes using the configured serializer."""
        return self.serializer.dumps(value)

    def deserialize(self, value):
        """Restore a value from bytes using the configured serializer."""
        return self.serializer.loads(value)

121 

122 

class FileCache(BaseCache):
    """Cache backend that keeps each record as a file on local disk.

    Records live under <path>/<category>[/<company_id>]/<sanitized key>.
    Once the number of files exceeds max_size (plus a small buffer), the
    oldest files by mtime are evicted.
    """

    def __init__(self, category, path=None, **kwargs):
        super().__init__(**kwargs)

        if path is None:
            path = self.config['paths']['cache']

        cache_path = Path(path) / category

        # keep caches of different companies separated
        company_id = ctx.company_id
        if company_id is not None:
            cache_path = cache_path / str(company_id)
        cache_path.mkdir(parents=True, exist_ok=True)

        self.path = cache_path

    def clear_old_cache(self):
        """Evict oldest files once the count exceeds max_size.

        A small buffer avoids running the eviction scan on every write.
        """
        if self.max_size is None:
            return

        with FileLock(self.path):
            # buffer to delete, to not run delete on every adding
            buffer_size = 5

            cur_count = len(os.listdir(self.path))

            if cur_count > self.max_size + buffer_size:
                try:
                    files = sorted(Path(self.path).iterdir(), key=os.path.getmtime)
                    for file in files[:cur_count - self.max_size]:
                        self.delete_file(file)
                except FileNotFoundError:
                    # a file disappeared between listing and deletion
                    # (e.g. another process evicted it) — best effort, ignore
                    pass

    def file_path(self, name):
        """Map a cache key to a file path.

        The key is sanitized to avoid file names with backticks and slashes.
        """
        sanitized_name = re.sub(r'[^\w\-.]', '_', name)
        return self.path / sanitized_name

    def set_df(self, name, df):
        """Store a dataframe via DataFrame.to_pickle (faster than generic set)."""
        path = self.file_path(name)
        # BUGFIX: take the same lock get_df() takes, so a concurrent reader
        # cannot observe a partially written file
        with FileLock(self.path):
            df.to_pickle(path)
        self.clear_old_cache()

    def set(self, name, value):
        """Serialize and store an arbitrary value under the given key."""
        path = self.file_path(name)
        value = self.serialize(value)

        # BUGFIX: take the same lock get() takes, so a concurrent reader
        # cannot observe a partially written file
        with FileLock(self.path):
            with open(path, 'wb') as fd:
                fd.write(value)
        self.clear_old_cache()

    def get_df(self, name):
        """Load a dataframe stored with set_df, or None on cache miss."""
        path = self.file_path(name)
        with FileLock(self.path):
            if not os.path.exists(path):
                return None
            value = pd.read_pickle(path)
            return value

    def get(self, name):
        """Load and deserialize a value stored with set, or None on cache miss."""
        path = self.file_path(name)

        with FileLock(self.path):
            if not os.path.exists(path):
                return None
            with open(path, 'rb') as fd:
                value = fd.read()
            value = self.deserialize(value)
            return value

    def delete(self, name):
        """Remove a record by cache key."""
        path = self.file_path(name)
        self.delete_file(path)

    def delete_file(self, path):
        """Remove a cache file by path; raises FileNotFoundError if missing."""
        os.unlink(path)

200 

201 

class RedisCache(BaseCache):
    """Cache backend backed by a redis server (via the walrus client).

    Each value is stored under the key "<category>_<name>". A redis hash
    named after the category maps every stored key to its last-modified
    timestamp in milliseconds, which drives eviction of the oldest entries.
    """

    def __init__(self, category, connection_info=None, **kwargs):
        super().__init__(**kwargs)

        self.category = category

        if connection_info is None:
            # if no params are given, the local redis server is used
            connection_info = self.config["cache"].get("connection", {})
        self.client = walrus.Database(**connection_info)

    def clear_old_cache(self, key_added):
        """Evict the oldest keys once the count exceeds max_size.

        :param key_added: key that was just written; kept for interface
            compatibility, the eviction logic does not need it
        """
        if self.max_size is None:
            return

        # buffer to delete, to not run delete on every adding
        buffer_size = 5

        cur_count = self.client.hlen(self.category)

        # remove oldest entries
        if cur_count > self.max_size + buffer_size:
            keys = list(self.client.hgetall(self.category).items())
            # BUGFIX: redis returns the stored timestamps as bytes; sort them
            # numerically — lexicographic byte order misranks timestamps with
            # different digit counts (b'999' would sort after b'1000')
            keys.sort(key=lambda item: int(item[1]))

            for key, _ in keys[:cur_count - self.max_size]:
                self.delete_key(key)

    def redis_key(self, name):
        """Namespace a cache key with the category."""
        return f'{self.category}_{name}'

    def set(self, name, value):
        """Serialize and store a value, then record its modify time."""
        key = self.redis_key(name)
        value = self.serialize(value)

        self.client.set(key, value)
        # the category hash tracks all keys with their modify time
        self.client.hset(self.category, key, int(time.time() * 1000))

        self.clear_old_cache(key)

    def get(self, name):
        """Load and deserialize a value, or None on cache miss."""
        key = self.redis_key(name)
        value = self.client.get(key)
        if value is None:
            # no value in cache
            return None
        return self.deserialize(value)

    def delete(self, name):
        """Remove a record by cache key."""
        key = self.redis_key(name)

        self.delete_key(key)

    def delete_key(self, key):
        """Remove a raw redis key and its timestamp entry."""
        self.client.delete(key)
        self.client.hdel(self.category, key)

265 

266 

class NoCache:
    """Null-object cache used when caching is disabled.

    Matches the cache interface but never stores anything: every get()
    is a miss and set() discards its value.
    """

    def __init__(self, *args, **kwargs):
        # accept and ignore whatever a real cache constructor would take
        pass

    def get(self, name):
        # always a cache miss
        return None

    def set(self, name, value):
        # intentionally a no-op
        return None

279 

280 

def get_cache(category, **kwargs):
    """Create a cache for *category* using the engine from the mindsdb config.

    "cache.type" selects the backend: 'redis' -> RedisCache, 'none' ->
    NoCache, anything else (including the default 'local') -> FileCache.

    :param category: cache namespace
    :param kwargs: passed through to the backend constructor
    :return: cache instance
    """
    # read the type once instead of re-fetching the config section per branch
    cache_type = Config().get('cache')['type']
    if cache_type == 'redis':
        return RedisCache(category, **kwargs)
    if cache_type == 'none':
        return NoCache(category, **kwargs)
    return FileCache(category, **kwargs)