Coverage for mindsdb / integrations / handlers / huggingface_handler / huggingface_handler.py: 0%

215 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Dict, Optional 

2 

3import pandas as pd 

4import transformers 

5from huggingface_hub import HfApi 

6 

7from mindsdb.integrations.handlers.huggingface_handler.settings import FINETUNE_MAP 

8from mindsdb.integrations.libs.base import BaseMLEngine 

9from mindsdb.utilities import log 

10 

11logger = log.getLogger(__name__) 

12 

13 

class HuggingFaceHandler(BaseMLEngine):
    """MindsDB ML engine that wraps Hugging Face `transformers` pipelines."""

    name = "huggingface"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate CREATE MODEL arguments before any model is downloaded.

        Checks that the referenced hub model is PyTorch-based, that its
        pipeline task is supported, and that exactly the required/allowed
        parameters for that task are present.

        Args:
            target: name of the target column (unused here, kept for API).
            args: CREATE MODEL arguments, possibly nested under "using".

        Raises:
            Exception: on any invalid, missing or unexpected parameter.
        """
        # Guard: the default is None, and `"using" in None` raises TypeError.
        if args is None:
            args = {}
        if "using" in args:
            args = args["using"]

        # "model_name" must be present before we can query the hub; checking
        # here gives a friendly error instead of a raw KeyError below.
        if "model_name" not in args:
            raise Exception('Parameter "model_name" is required')

        hf_api = HfApi()

        # check model is pytorch based
        metadata = hf_api.model_info(args["model_name"])
        if "pytorch" not in metadata.tags:
            raise Exception(
                "Currently only PyTorch models are supported (https://huggingface.co/models?library=pytorch&sort=downloads). To request another library, please contact us on our community slack (https://mindsdb.com/joincommunity)."
            )

        # check model task
        supported_tasks = [
            "text-classification",
            "text-generation",
            "zero-shot-classification",
            "translation",
            "summarization",
            "text2text-generation",
            "fill-mask",
        ]

        if metadata.pipeline_tag not in supported_tasks:
            # Single clean message (the old backslash-continuation embedded a
            # run of indentation whitespace into the user-facing text).
            raise Exception(
                f"Not supported task for model: {metadata.pipeline_tag}. Should be one of {', '.join(supported_tasks)}"
            )

        if "task" not in args:
            args["task"] = metadata.pipeline_tag
        elif args["task"] != metadata.pipeline_tag:
            raise Exception(f"Task mismatch for model: {args['task']}!={metadata.pipeline_tag}")

        input_keys = list(args.keys())

        # task, model_name, input_column is essential
        for key in ["task", "model_name", "input_column"]:
            if key not in args:
                raise Exception(f'Parameter "{key}" is required')
            input_keys.remove(key)

        # check task-specific required inputs
        if args["task"] == "zero-shot-classification":
            key = "candidate_labels"
            if key not in args:
                raise Exception('"candidate_labels" is required for zero-shot-classification')
            input_keys.remove(key)

        if args["task"] == "translation":
            for key in ["lang_input", "lang_output"]:
                if key not in args:
                    raise Exception(f"{key} is required for translation")
                input_keys.remove(key)

        if args["task"] == "summarization":
            for key in ["min_output_length", "max_output_length"]:
                if key not in args:
                    raise Exception(f"{key} is required for summarization")
                input_keys.remove(key)

        # optional keys
        for key in ["labels", "max_length", "truncation_policy"]:
            if key in input_keys:
                input_keys.remove(key)

        # Anything left over was not consumed by any rule above.
        if len(input_keys) > 0:
            raise Exception(f"Not expected parameters: {', '.join(input_keys)}")

90 

91 def create(self, target, args=None, **kwargs): 

92 # TODO change BaseMLEngine api? 

93 if "using" in args: 

94 args = args["using"] 

95 

96 args["target"] = target 

97 

98 model_name = args["model_name"] 

99 hf_model_storage_path = self.engine_storage.folder_get(model_name) # real 

100 

101 if args["task"] == "translation": 

102 args["task_proper"] = f"translation_{args['lang_input']}_to_{args['lang_output']}" 

103 else: 

104 args["task_proper"] = args["task"] 

105 

106 logger.debug(f"Checking file system for {model_name}...") 

107 

108 #### 

109 # Check if pipeline has already been downloaded 

110 try: 

111 pipeline = transformers.pipeline( 

112 task=args["task_proper"], model=hf_model_storage_path, tokenizer=hf_model_storage_path 

113 ) 

114 logger.debug("Model already downloaded!") 

115 #### 

116 # Otherwise download it 

117 except (ValueError, OSError): 

118 try: 

119 logger.debug(f"Downloading {model_name}...") 

120 pipeline = transformers.pipeline(task=args["task_proper"], model=model_name) 

121 

122 pipeline.save_pretrained(hf_model_storage_path) 

123 

124 logger.debug(f"Saved to {hf_model_storage_path}") 

125 except Exception: 

126 raise Exception( 

127 "Error while downloading and setting up the model. Please try a different model. We're working on expanding the list of supported models, so we would appreciate it if you let us know about this in our community slack (https://mindsdb.com/joincommunity)." 

128 ) # noqa 

129 #### 

130 

131 if "max_length" in args: 

132 pass 

133 elif "max_position_embeddings" in pipeline.model.config.to_dict().keys(): 

134 args["max_length"] = pipeline.model.config.max_position_embeddings 

135 elif "max_length" in pipeline.model.config.to_dict().keys(): 

136 args["max_length"] = pipeline.model.config.max_length 

137 else: 

138 logger.debug("No max_length found!") 

139 

140 labels_default = pipeline.model.config.id2label 

141 labels_map = {} 

142 if "labels" in args: 

143 for num in labels_default.keys(): 

144 labels_map[labels_default[num]] = args["labels"][num] 

145 args["labels_map"] = labels_map 

146 else: 

147 for num in labels_default.keys(): 

148 labels_map[labels_default[num]] = labels_default[num] 

149 args["labels_map"] = labels_map 

150 

151 # store and persist in model folder 

152 self.model_storage.json_set("args", args) 

153 

154 # persist changes to handler folder 

155 self.engine_storage.folder_sync(model_name) 

156 

157 # todo move infer tasks to a seperate file 

158 def predict_text_classification(self, pipeline, item, args): 

159 top_k = args.get("top_k", 1000) 

160 

161 result = pipeline([item], top_k=top_k, truncation=True, max_length=args["max_length"])[0] 

162 

163 final = {} 

164 explain = {} 

165 if type(result) == dict: 

166 result = [result] 

167 final[args["target"]] = args["labels_map"][result[0]["label"]] 

168 for elem in result: 

169 if args["labels_map"]: 

170 explain[args["labels_map"][elem["label"]]] = elem["score"] 

171 else: 

172 explain[elem["label"]] = elem["score"] 

173 final[f"{args['target']}_explain"] = explain 

174 return final 

175 

176 def predict_text_generation(self, pipeline, item, args): 

177 result = pipeline([item], max_length=args["max_length"])[0] 

178 

179 final = {} 

180 final[args["target"]] = result["generated_text"] 

181 

182 return final 

183 

184 def predict_zero_shot(self, pipeline, item, args): 

185 top_k = args.get("top_k", 1000) 

186 

187 result = pipeline( 

188 [item], 

189 candidate_labels=args["candidate_labels"], 

190 truncation=True, 

191 top_k=top_k, 

192 max_length=args["max_length"], 

193 )[0] 

194 

195 final = {} 

196 final[args["target"]] = result["labels"][0] 

197 

198 explain = dict(zip(result["labels"], result["scores"])) 

199 final[f"{args['target']}_explain"] = explain 

200 

201 return final 

202 

203 def predict_translation(self, pipeline, item, args): 

204 result = pipeline([item], max_length=args["max_length"])[0] 

205 

206 final = {} 

207 final[args["target"]] = result["translation_text"] 

208 

209 return final 

210 

211 def predict_summarization(self, pipeline, item, args): 

212 result = pipeline( 

213 [item], 

214 min_length=args["min_output_length"], 

215 max_length=args["max_output_length"], 

216 )[0] 

217 

218 final = {} 

219 final[args["target"]] = result["summary_text"] 

220 

221 return final 

222 

223 def predict_text2text(self, pipeline, item, args): 

224 result = pipeline([item], max_length=args["max_length"])[0] 

225 

226 final = {} 

227 final[args["target"]] = result["generated_text"] 

228 

229 return final 

230 

231 def predict_fill_mask(self, pipeline, item, args): 

232 result = pipeline([item])[0] 

233 

234 final = {} 

235 final[args["target"]] = result[0]["sequence"] 

236 explain = {elem["sequence"]: elem["score"] for elem in result} 

237 final[f"{args['target']}_explain"] = explain 

238 

239 return final 

240 

241 def predict(self, df, args=None): 

242 fnc_list = { 

243 "text-classification": self.predict_text_classification, 

244 "text-generation": self.predict_text_generation, 

245 "zero-shot-classification": self.predict_zero_shot, 

246 "translation": self.predict_translation, 

247 "summarization": self.predict_summarization, 

248 "fill-mask": self.predict_fill_mask, 

249 } 

250 

251 # get stuff from model folder 

252 args = self.model_storage.json_get("args") 

253 

254 task = args["task"] 

255 

256 if task not in fnc_list: 

257 raise RuntimeError(f"Unknown task: {task}") 

258 

259 fnc = fnc_list[task] 

260 

261 try: 

262 # load from model storage (finetuned models will use this) 

263 hf_model_storage_path = self.model_storage.folder_get(args["model_name"]) 

264 pipeline = transformers.pipeline( 

265 task=args["task_proper"], 

266 model=hf_model_storage_path, 

267 tokenizer=hf_model_storage_path, 

268 ) 

269 except (ValueError, OSError): 

270 # load from engine storage (i.e. 'common' models) 

271 hf_model_storage_path = self.engine_storage.folder_get(args["model_name"]) 

272 pipeline = transformers.pipeline( 

273 task=args["task_proper"], 

274 model=hf_model_storage_path, 

275 tokenizer=hf_model_storage_path, 

276 ) 

277 

278 input_column = args["input_column"] 

279 if input_column not in df.columns: 

280 raise RuntimeError(f'Column "{input_column}" not found in input data') 

281 input_list = df[input_column] 

282 

283 max_tokens = pipeline.tokenizer.model_max_length 

284 

285 results = [] 

286 for item in input_list: 

287 if max_tokens is not None: 

288 tokens = pipeline.tokenizer.encode(item) 

289 if len(tokens) > max_tokens: 

290 truncation_policy = args.get("truncation_policy", "strict") 

291 if truncation_policy == "strict": 

292 results.append({"error": f"Tokens count exceed model limit: {len(tokens)} > {max_tokens}"}) 

293 continue 

294 elif truncation_policy == "left": 

295 tokens = tokens[-max_tokens + 1 : -1] # cut 2 empty tokens from left and right 

296 else: 

297 tokens = tokens[1 : max_tokens - 1] # cut 2 empty tokens from left and right 

298 

299 item = pipeline.tokenizer.decode(tokens) 

300 

301 item = str(item) 

302 try: 

303 result = fnc(pipeline, item, args) 

304 except Exception as e: 

305 msg = str(e).strip() 

306 if msg == "": 

307 msg = e.__class__.__name__ 

308 result = {"error": msg} 

309 results.append(result) 

310 

311 pred_df = pd.DataFrame(results) 

312 

313 return pred_df 

314 

315 def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: 

316 args = self.model_storage.json_get("args") 

317 if attribute == "args": 

318 return pd.DataFrame(args.items(), columns=["key", "value"]) 

319 elif attribute == "metadata": 

320 hf_api = HfApi() 

321 metadata = hf_api.model_info(args["model_name"]) 

322 data = metadata.__dict__ 

323 return pd.DataFrame(list(data.items()), columns=["key", "value"]) 

324 else: 

325 tables = ["args", "metadata"] 

326 return pd.DataFrame(tables, columns=["tables"]) 

327 

328 def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: 

329 finetune_args = args if args else {} 

330 args = self.base_model_storage.json_get("args") 

331 args.update(finetune_args) 

332 

333 model_name = args["model_name"] 

334 model_folder = self.model_storage.folder_get(model_name) 

335 args["model_folder"] = model_folder 

336 model_folder_name = model_folder.split("/")[-1] 

337 task = args["task"] 

338 

339 if task not in FINETUNE_MAP: 

340 raise KeyError( 

341 f"{task} is not currently supported, please choose a supported task - {', '.join(FINETUNE_MAP)}" 

342 ) 

343 

344 tokenizer, trainer = FINETUNE_MAP[task](df, args) 

345 

346 try: 

347 trainer.train() 

348 trainer.save_model( 

349 model_folder 

350 ) # TODO: save entire pipeline instead https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.Pipeline.save_pretrained 

351 tokenizer.save_pretrained(model_folder) 

352 

353 # persist changes 

354 self.model_storage.json_set("args", args) 

355 self.model_storage.folder_sync(model_folder_name) 

356 

357 except Exception as e: 

358 err_str = f"Finetune failed with error: {str(e)}" 

359 logger.debug(err_str) 

360 raise Exception(err_str)