Coverage for mindsdb / integrations / handlers / huggingface_handler / huggingface_handler.py: 0%
215 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Dict, Optional
3import pandas as pd
4import transformers
5from huggingface_hub import HfApi
7from mindsdb.integrations.handlers.huggingface_handler.settings import FINETUNE_MAP
8from mindsdb.integrations.libs.base import BaseMLEngine
9from mindsdb.utilities import log
11logger = log.getLogger(__name__)
class HuggingFaceHandler(BaseMLEngine):
    """MindsDB ML engine wrapper around Hugging Face `transformers` pipelines.

    Supports CREATE / PREDICT / FINETUNE / DESCRIBE for a fixed set of text
    tasks (text-classification, text-generation, zero-shot-classification,
    translation, summarization, text2text-generation, fill-mask).  Model
    weights are cached in the engine storage folder; per-model arguments are
    persisted as JSON in the model storage.
    """

    name = "huggingface"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate user-supplied CREATE MODEL arguments before training.

        Checks that the referenced hub model is PyTorch-based, that its task
        is supported (and consistent with an explicitly given ``task``), that
        all task-specific required parameters are present, and that no
        unexpected parameters were passed.

        Args:
            target: name of the target column (unused here, kept for API).
            args: dict of arguments; may wrap the real args under "using".
            **kwargs: ignored, kept for BaseMLEngine API compatibility.

        Raises:
            Exception: on any validation failure.
        """
        if "using" in args:
            args = args["using"]

        # "model_name" is needed before we can query the hub, so check it up
        # front to give the friendly error instead of a bare KeyError.
        if "model_name" not in args:
            raise Exception('Parameter "model_name" is required')

        hf_api = HfApi()

        # check model is pytorch based
        metadata = hf_api.model_info(args["model_name"])
        if "pytorch" not in metadata.tags:
            raise Exception(
                "Currently only PyTorch models are supported (https://huggingface.co/models?library=pytorch&sort=downloads). To request another library, please contact us on our community slack (https://mindsdb.com/joincommunity)."
            )

        # check model task
        supported_tasks = [
            "text-classification",
            "text-generation",
            "zero-shot-classification",
            "translation",
            "summarization",
            "text2text-generation",
            "fill-mask",
        ]

        if metadata.pipeline_tag not in supported_tasks:
            raise Exception(
                f"Not supported task for model: {metadata.pipeline_tag}.\
                Should be one of {', '.join(supported_tasks)}"
            )

        if "task" not in args:
            # infer the task from hub metadata when not explicitly given
            args["task"] = metadata.pipeline_tag
        elif args["task"] != metadata.pipeline_tag:
            raise Exception(f"Task mismatch for model: {args['task']}!={metadata.pipeline_tag}")

        input_keys = list(args.keys())

        # task, model_name, input_column is essential
        for key in ["task", "model_name", "input_column"]:
            if key not in args:
                raise Exception(f'Parameter "{key}" is required')
            input_keys.remove(key)

        # check tasks input

        if args["task"] == "zero-shot-classification":
            key = "candidate_labels"
            if key not in args:
                raise Exception('"candidate_labels" is required for zero-shot-classification')
            input_keys.remove(key)

        if args["task"] == "translation":
            keys = ["lang_input", "lang_output"]
            for key in keys:
                if key not in args:
                    raise Exception(f"{key} is required for translation")
                input_keys.remove(key)

        if args["task"] == "summarization":
            keys = ["min_output_length", "max_output_length"]
            for key in keys:
                if key not in args:
                    raise Exception(f"{key} is required for summarization")
                input_keys.remove(key)

        # optional keys
        for key in ["labels", "max_length", "truncation_policy"]:
            if key in input_keys:
                input_keys.remove(key)

        if len(input_keys) > 0:
            raise Exception(f"Not expected parameters: {', '.join(input_keys)}")

    def create(self, target, args=None, **kwargs):
        """Materialize the model: reuse the cached copy or download it.

        Resolves the concrete pipeline task name, loads/downloads the
        pipeline, derives ``max_length`` and ``labels_map`` from the model
        config, then persists the args and syncs the weights folder.

        Raises:
            Exception: if the model cannot be downloaded / set up.
        """
        # TODO change BaseMLEngine api?
        if "using" in args:
            args = args["using"]

        args["target"] = target

        model_name = args["model_name"]
        hf_model_storage_path = self.engine_storage.folder_get(model_name)  # real

        if args["task"] == "translation":
            # translation pipelines need the language pair encoded in the task name
            args["task_proper"] = f"translation_{args['lang_input']}_to_{args['lang_output']}"
        else:
            args["task_proper"] = args["task"]

        logger.debug(f"Checking file system for {model_name}...")

        ####
        # Check if pipeline has already been downloaded
        try:
            pipeline = transformers.pipeline(
                task=args["task_proper"], model=hf_model_storage_path, tokenizer=hf_model_storage_path
            )
            logger.debug("Model already downloaded!")
        ####
        # Otherwise download it
        except (ValueError, OSError):
            try:
                logger.debug(f"Downloading {model_name}...")
                pipeline = transformers.pipeline(task=args["task_proper"], model=model_name)

                pipeline.save_pretrained(hf_model_storage_path)

                logger.debug(f"Saved to {hf_model_storage_path}")
            except Exception as e:
                # chain the original error so the root cause is not lost
                raise Exception(
                    "Error while downloading and setting up the model. Please try a different model. We're working on expanding the list of supported models, so we would appreciate it if you let us know about this in our community slack (https://mindsdb.com/joincommunity)."
                ) from e  # noqa
        ####

        # derive a max_length unless the user supplied one explicitly
        if "max_length" in args:
            pass
        elif "max_position_embeddings" in pipeline.model.config.to_dict().keys():
            args["max_length"] = pipeline.model.config.max_position_embeddings
        elif "max_length" in pipeline.model.config.to_dict().keys():
            args["max_length"] = pipeline.model.config.max_length
        else:
            logger.debug("No max_length found!")

        labels_default = pipeline.model.config.id2label
        labels_map = {}
        if "labels" in args:
            # user-provided display labels, positionally mapped onto model label ids
            for num in labels_default.keys():
                labels_map[labels_default[num]] = args["labels"][num]
        else:
            # identity map: keep the model's own label names
            for num in labels_default.keys():
                labels_map[labels_default[num]] = labels_default[num]
        args["labels_map"] = labels_map

        # store and persist in model folder
        self.model_storage.json_set("args", args)

        # persist changes to handler folder
        self.engine_storage.folder_sync(model_name)

    # todo move infer tasks to a seperate file
    def predict_text_classification(self, pipeline, item, args):
        """Classify one text item; returns target label plus per-label scores."""
        top_k = args.get("top_k", 1000)

        result = pipeline([item], top_k=top_k, truncation=True, max_length=args["max_length"])[0]

        final = {}
        explain = {}
        # with top_k == 1 the pipeline returns a single dict, not a list
        if isinstance(result, dict):
            result = [result]
        final[args["target"]] = args["labels_map"][result[0]["label"]]
        for elem in result:
            if args["labels_map"]:
                explain[args["labels_map"][elem["label"]]] = elem["score"]
            else:
                explain[elem["label"]] = elem["score"]
        final[f"{args['target']}_explain"] = explain
        return final

    def predict_text_generation(self, pipeline, item, args):
        """Generate a continuation for one text item."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["generated_text"]

        return final

    def predict_zero_shot(self, pipeline, item, args):
        """Zero-shot classify one item against the configured candidate labels."""
        top_k = args.get("top_k", 1000)

        result = pipeline(
            [item],
            candidate_labels=args["candidate_labels"],
            truncation=True,
            top_k=top_k,
            max_length=args["max_length"],
        )[0]

        final = {}
        # labels come back sorted by score, best first
        final[args["target"]] = result["labels"][0]

        explain = dict(zip(result["labels"], result["scores"]))
        final[f"{args['target']}_explain"] = explain

        return final

    def predict_translation(self, pipeline, item, args):
        """Translate one text item using the configured language pair."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["translation_text"]

        return final

    def predict_summarization(self, pipeline, item, args):
        """Summarize one text item within the configured output length bounds."""
        result = pipeline(
            [item],
            min_length=args["min_output_length"],
            max_length=args["max_output_length"],
        )[0]

        final = {}
        final[args["target"]] = result["summary_text"]

        return final

    def predict_text2text(self, pipeline, item, args):
        """Run a text2text-generation model on one text item."""
        result = pipeline([item], max_length=args["max_length"])[0]

        final = {}
        final[args["target"]] = result["generated_text"]

        return final

    def predict_fill_mask(self, pipeline, item, args):
        """Fill the masked token in one item; returns best fill plus scores."""
        result = pipeline([item])[0]

        final = {}
        final[args["target"]] = result[0]["sequence"]
        explain = {elem["sequence"]: elem["score"] for elem in result}
        final[f"{args['target']}_explain"] = explain

        return final

    def predict(self, df, args=None):
        """Run inference on ``df[input_column]`` and return a result DataFrame.

        Per-row failures (token-limit violations under the "strict"
        truncation policy, or pipeline errors) are reported in an "error"
        column instead of aborting the whole batch.

        Raises:
            RuntimeError: for an unknown task or a missing input column.
        """
        fnc_list = {
            "text-classification": self.predict_text_classification,
            "text-generation": self.predict_text_generation,
            "zero-shot-classification": self.predict_zero_shot,
            "translation": self.predict_translation,
            "summarization": self.predict_summarization,
            # was missing: created text2text models could never predict
            "text2text-generation": self.predict_text2text,
            "fill-mask": self.predict_fill_mask,
        }

        # get stuff from model folder
        args = self.model_storage.json_get("args")

        task = args["task"]

        if task not in fnc_list:
            raise RuntimeError(f"Unknown task: {task}")

        fnc = fnc_list[task]

        try:
            # load from model storage (finetuned models will use this)
            hf_model_storage_path = self.model_storage.folder_get(args["model_name"])
            pipeline = transformers.pipeline(
                task=args["task_proper"],
                model=hf_model_storage_path,
                tokenizer=hf_model_storage_path,
            )
        except (ValueError, OSError):
            # load from engine storage (i.e. 'common' models)
            hf_model_storage_path = self.engine_storage.folder_get(args["model_name"])
            pipeline = transformers.pipeline(
                task=args["task_proper"],
                model=hf_model_storage_path,
                tokenizer=hf_model_storage_path,
            )

        input_column = args["input_column"]
        if input_column not in df.columns:
            raise RuntimeError(f'Column "{input_column}" not found in input data')
        input_list = df[input_column]

        max_tokens = pipeline.tokenizer.model_max_length

        results = []
        for item in input_list:
            if max_tokens is not None:
                tokens = pipeline.tokenizer.encode(item)
                if len(tokens) > max_tokens:
                    truncation_policy = args.get("truncation_policy", "strict")
                    if truncation_policy == "strict":
                        results.append({"error": f"Tokens count exceed model limit: {len(tokens)} > {max_tokens}"})
                        continue
                    elif truncation_policy == "left":
                        tokens = tokens[-max_tokens + 1 : -1]  # cut 2 empty tokens from left and right
                    else:
                        tokens = tokens[1 : max_tokens - 1]  # cut 2 empty tokens from left and right

                    item = pipeline.tokenizer.decode(tokens)

            item = str(item)
            try:
                result = fnc(pipeline, item, args)
            except Exception as e:
                msg = str(e).strip()
                if msg == "":
                    msg = e.__class__.__name__
                result = {"error": msg}
            results.append(result)

        pred_df = pd.DataFrame(results)

        return pred_df

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Expose model args, hub metadata, or the list of available views."""
        args = self.model_storage.json_get("args")
        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])
        elif attribute == "metadata":
            hf_api = HfApi()
            metadata = hf_api.model_info(args["model_name"])
            data = metadata.__dict__
            return pd.DataFrame(list(data.items()), columns=["key", "value"])
        else:
            tables = ["args", "metadata"]
            return pd.DataFrame(tables, columns=["tables"])

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Finetune the base model on ``df`` and persist the result.

        Merges finetune args over the base model's stored args, dispatches to
        the task-specific trainer from FINETUNE_MAP, and syncs the trained
        weights into this model's storage folder.

        Raises:
            KeyError: if the task has no finetune support.
            Exception: if training or persistence fails (original cause chained).
        """
        finetune_args = args if args else {}
        args = self.base_model_storage.json_get("args")
        args.update(finetune_args)

        model_name = args["model_name"]
        model_folder = self.model_storage.folder_get(model_name)
        args["model_folder"] = model_folder
        model_folder_name = model_folder.split("/")[-1]
        task = args["task"]

        if task not in FINETUNE_MAP:
            raise KeyError(
                f"{task} is not currently supported, please choose a supported task - {', '.join(FINETUNE_MAP)}"
            )

        tokenizer, trainer = FINETUNE_MAP[task](df, args)

        try:
            trainer.train()
            trainer.save_model(
                model_folder
            )  # TODO: save entire pipeline instead https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.Pipeline.save_pretrained
            tokenizer.save_pretrained(model_folder)

            # persist changes
            self.model_storage.json_set("args", args)
            self.model_storage.folder_sync(model_folder_name)

        except Exception as e:
            err_str = f"Finetune failed with error: {str(e)}"
            logger.error(err_str)
            raise Exception(err_str) from e