Coverage for mindsdb / interfaces / storage / model_fs.py: 31%
183 statements
« prev ^ index » next — generated by coverage.py v7.13.1 at 2026-01-21 00:36 +0000
1import os
2import re
3import json
4import io
5import zipfile
6from typing import Union
8import mindsdb.interfaces.storage.db as db
10from .fs import RESOURCE_GROUP, FileStorageFactory, SERVICE_FILES_NAMES
11from .json import get_json_storage, get_encrypted_json_storage
14JSON_STORAGE_FILE = "json_storage.json"
class ModelStorage:
    """
    This class deals with all model-related storage requirements, from setting status to storing artifacts.

    File artifacts go through a synced FileStorage; json metadata goes through
    the (optionally encrypted) json storage keyed by predictor id.
    """

    def __init__(self, predictor_id):
        # sync=True: file operations are mirrored to the remote storage backend
        storageFactory = FileStorageFactory(resource_group=RESOURCE_GROUP.PREDICTOR, sync=True)
        self.fileStorage = storageFactory(predictor_id)
        self.predictor_id = predictor_id

    # -- fields --

    def _get_model_record(self, model_id: int, check_exists: bool = False) -> "Union[db.Predictor, None]":
        """Get model record by id

        Args:
            model_id (int): model id
            check_exists (bool): true if need to check that model exists

        Returns:
            Union[db.Predictor, None]: model record

        Raises:
            KeyError: if `check_exists` is True and model does not exists
        """
        # NOTE(fix): previously this ignored `model_id` and always queried
        # `self.predictor_id`. Every caller in this class passes
        # `self.predictor_id`, so honoring the argument is behavior-identical
        # while making the signature truthful.
        model_record = db.Predictor.query.get(model_id)
        if check_exists is True and model_record is None:
            raise KeyError("Model does not exists")
        return model_record

    def get_info(self):
        """Return a summary dict (status, to_predict, data, learn_args) for this model."""
        rec = self._get_model_record(self.predictor_id)
        return dict(status=rec.status, to_predict=rec.to_predict, data=rec.data, learn_args=rec.learn_args)

    def status_set(self, status, status_info=None):
        """Set model status (and optionally its `data` payload) and commit."""
        rec = self._get_model_record(self.predictor_id)
        rec.status = status
        if status_info is not None:
            rec.data = status_info
        db.session.commit()

    def training_state_set(self, current_state_num=None, total_states=None, state_name=None):
        """Update training-progress fields; only non-None arguments are applied."""
        rec = self._get_model_record(self.predictor_id)
        if current_state_num is not None:
            rec.training_phase_current = current_state_num
        if total_states is not None:
            rec.training_phase_total = total_states
        if state_name is not None:
            rec.training_phase_name = state_name
        db.session.commit()

    def training_state_get(self):
        """Return [current_phase_num, total_phases, phase_name] for this model."""
        rec = self._get_model_record(self.predictor_id)
        return [rec.training_phase_current, rec.training_phase_total, rec.training_phase_name]

    def columns_get(self):
        """Return the stored column -> dtype mapping."""
        rec = self._get_model_record(self.predictor_id)
        return rec.dtype_dict

    def columns_set(self, columns):
        """Store the column -> dtype mapping and commit.

        Args:
            columns (dict): {name: dtype}
        """
        rec = self._get_model_record(self.predictor_id)
        rec.dtype_dict = columns
        db.session.commit()

    # files

    def file_get(self, name):
        """Return the content of the stored file `name`."""
        return self.fileStorage.file_get(name)

    def file_set(self, name, content):
        """Write `content` to the stored file `name`."""
        self.fileStorage.file_set(name, content)

    @staticmethod
    def _convert_name(name: str) -> str:
        """Normalize a folder name: lowercase, spaces and unsafe chars -> '_'.

        NOTE: the stray '^' chars inside the negated class are literals, so
        '^' survives sanitization. Kept byte-identical so existing stored
        folder paths stay stable; mirrors HandlerStorage.__convert_name.
        """
        name = name.lower().replace(" ", "_")
        return re.sub(r"([^a-z^A-Z^_\d]+)", "_", name)

    def folder_get(self, name):
        """Pull folder from remote storage and return its local path."""
        name = self._convert_name(name)
        self.fileStorage.pull_path(name)
        return str(self.fileStorage.get_path(name))

    def folder_sync(self, name):
        """Push local folder content to remote storage."""
        name = self._convert_name(name)
        self.fileStorage.push_path(name)

    def file_list(self): ...

    def file_del(self, name): ...

    # jsons

    def json_set(self, name, data):
        """Store `data` under `name` in plain json storage."""
        json_storage = get_json_storage(resource_id=self.predictor_id, resource_group=RESOURCE_GROUP.PREDICTOR)
        return json_storage.set(name, data)

    def encrypted_json_set(self, name: str, data: dict) -> None:
        """Store `data` under `name` in encrypted json storage."""
        json_storage = get_encrypted_json_storage(
            resource_id=self.predictor_id, resource_group=RESOURCE_GROUP.PREDICTOR
        )
        return json_storage.set(name, data)

    def json_get(self, name):
        """Fetch the value stored under `name` from plain json storage."""
        json_storage = get_json_storage(resource_id=self.predictor_id, resource_group=RESOURCE_GROUP.PREDICTOR)
        return json_storage.get(name)

    def encrypted_json_get(self, name: str) -> dict:
        """Fetch the value stored under `name` from encrypted json storage."""
        json_storage = get_encrypted_json_storage(
            resource_id=self.predictor_id, resource_group=RESOURCE_GROUP.PREDICTOR
        )
        return json_storage.get(name)

    def json_list(self): ...

    def json_del(self, name): ...

    def delete(self):
        """Remove all file artifacts and json records for this model."""
        self.fileStorage.delete()
        json_storage = get_json_storage(resource_id=self.predictor_id, resource_group=RESOURCE_GROUP.PREDICTOR)
        json_storage.clean()
class HandlerStorage:
    """
    This class deals with all handler-related storage requirements, from storing metadata to synchronizing folders
    across instances.
    """

    def __init__(self, integration_id: int, root_dir: str = None, is_temporal=False):
        args = {}
        if root_dir is not None:
            args["root_dir"] = root_dir
        # sync=False: caller controls pushing/pulling via file_set/folder_sync
        storageFactory = FileStorageFactory(resource_group=RESOURCE_GROUP.INTEGRATION, sync=False, **args)
        self.fileStorage = storageFactory(integration_id)
        self.integration_id = integration_id
        self.is_temporal = is_temporal
        # do not sync with remote storage

    def __convert_name(self, name):
        """Normalize a name: lowercase, spaces and unsafe chars -> '_'.

        NOTE: the stray '^' chars inside the negated class are literals, so
        '^' survives sanitization; kept byte-identical so existing stored
        folder paths stay stable.
        """
        name = name.lower().replace(" ", "_")
        return re.sub(r"([^a-z^A-Z^_\d]+)", "_", name)

    def is_empty(self):
        """check if storage directory is empty (service files are ignored)

        Returns:
            bool: true if dir is empty
        """
        for path in self.fileStorage.folder_path.iterdir():
            if path.is_file() and path.name in SERVICE_FILES_NAMES:
                continue
            return False
        return True

    def get_connection_args(self):
        """Return the integration's stored connection args.

        Raises:
            KeyError: if the integration record does not exist
        """
        rec = db.Integration.query.get(self.integration_id)
        if rec is None:
            # Consistent with update_connection_args; previously a missing
            # record surfaced as AttributeError on `rec.data`.
            raise KeyError("Can't find integration")
        return rec.data

    def update_connection_args(self, connection_args: dict) -> None:
        """update integration connection args

        Args:
            connection_args (dict): new connection args

        Raises:
            KeyError: if the integration record does not exist
        """
        rec = db.Integration.query.get(self.integration_id)
        if rec is None:
            raise KeyError("Can't find integration")
        rec.data = connection_args
        db.session.commit()

    # files

    def file_get(self, name):
        """Pull a file from remote storage and return its content."""
        self.fileStorage.pull_path(name)
        return self.fileStorage.file_get(name)

    def file_set(self, name, content):
        """Write a file locally and, unless temporal, push it to remote storage."""
        self.fileStorage.file_set(name, content)
        if not self.is_temporal:
            self.fileStorage.push_path(name)

    def file_list(self): ...

    def file_del(self, name): ...

    # folder

    def folder_get(self, name):
        """Copies folder from remote to local file system and returns its path

        :param name: name of the folder
        """
        name = self.__convert_name(name)
        self.fileStorage.pull_path(name)
        return str(self.fileStorage.get_path(name))

    def folder_sync(self, name):
        """Push the local folder to remote storage (no-op for temporal storage)."""
        if self.is_temporal:
            return
        name = self.__convert_name(name)
        self.fileStorage.push_path(name)

    # jsons

    def json_set(self, name, content):
        """Store `content` under `name` in plain json storage."""
        json_storage = get_json_storage(resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION)
        return json_storage.set(name, content)

    def encrypted_json_set(self, name: str, content: dict) -> None:
        """Store `content` under `name` in encrypted json storage."""
        json_storage = get_encrypted_json_storage(
            resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION
        )
        return json_storage.set(name, content)

    def json_get(self, name):
        """Fetch the value stored under `name` from plain json storage."""
        json_storage = get_json_storage(resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION)
        return json_storage.get(name)

    def encrypted_json_get(self, name: str) -> dict:
        """Fetch the value stored under `name` from encrypted json storage."""
        json_storage = get_encrypted_json_storage(
            resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION
        )
        return json_storage.get(name)

    def json_list(self): ...

    def json_del(self, name): ...

    def export_files(self) -> bytes:
        """Zip all non-service files plus the json-storage records.

        Returns:
            bytes: zip archive content, or None when there is nothing to export
        """
        json_storage = self.export_json_storage()

        if self.is_empty() and not json_storage:
            return None

        folder_path = self.folder_get("")

        zip_fd = io.BytesIO()

        with zipfile.ZipFile(zip_fd, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(folder_path):
                for file_name in files:
                    if file_name in SERVICE_FILES_NAMES:
                        continue
                    abs_path = os.path.join(root, file_name)
                    zipf.write(abs_path, os.path.relpath(abs_path, folder_path))

            # If JSON storage is not empty, add it to the zip file.
            if json_storage:
                json_str = json.dumps(json_storage)
                zipf.writestr(JSON_STORAGE_FILE, json_str)

        zip_fd.seek(0)
        return zip_fd.read()

    def import_files(self, content: bytes):
        """Unpack a zip produced by `export_files` into this storage.

        Args:
            content (bytes): zip archive content
        """
        folder_path = self.folder_get("")

        with zipfile.ZipFile(io.BytesIO(content), "r") as zip_ref:
            for name in zip_ref.namelist():
                # The JSON storage entry is routed to the json storage instead
                # of being extracted into the folder.
                if name == JSON_STORAGE_FILE:
                    json_storage = zip_ref.read(JSON_STORAGE_FILE)
                    self.import_json_storage(json_storage)
                else:
                    zip_ref.extract(name, folder_path)

        self.folder_sync("")

    def export_json_storage(self) -> list[dict]:
        """Return all json-storage records for this integration as json-serializable dicts."""
        json_storage = get_json_storage(resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION)

        records = []
        for record in json_storage.get_all_records():
            record_dict = record.to_dict()
            if record_dict.get("encrypted_content"):
                # bytes are not json-serializable; decode for the export payload
                record_dict["encrypted_content"] = record_dict["encrypted_content"].decode()
            records.append(record_dict)

        return records

    def import_json_storage(self, records: bytes) -> None:
        """Load records produced by `export_json_storage`.

        Args:
            records (bytes): utf-8 encoded json list of record dicts
        """
        json_storage = get_json_storage(resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION)

        encrypted_json_storage = get_encrypted_json_storage(
            resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION
        )

        records = json.loads(records.decode())

        for record in records:
            # Encrypted records are restored as-is (already ciphertext strings).
            if record["encrypted_content"]:
                encrypted_json_storage.set_str(record["name"], record["encrypted_content"])
            else:
                json_storage.set(record["name"], record["content"])

    def delete(self):
        """Remove all files and json records for this integration."""
        self.fileStorage.delete()
        json_storage = get_json_storage(resource_id=self.integration_id, resource_group=RESOURCE_GROUP.INTEGRATION)
        json_storage.clean()