Coverage for mindsdb / integrations / handlers / palm_handler / palm_handler.py: 0%

229 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import textwrap 

2from pydantic import BaseModel, Extra 

3 

4import google.generativeai as palm 

5import numpy as np 

6import pandas as pd 

7 

8from mindsdb.utilities.hooks import before_palm_query, after_palm_query 

9from mindsdb.utilities import log 

10from mindsdb.integrations.libs.base import BaseMLEngine 

11from mindsdb.integrations.libs.llm.utils import get_completed_prompts 

12 

13from mindsdb.integrations.utilities.handler_utils import get_api_key 

14 

# Model identifiers this handler knows about. Despite the name, the tuple also
# lists the text-completion and embedding models exposed by the PaLM API, since
# they are all dispatched through the same handler.
CHAT_MODELS = (
    "models/chat-bison-001",
    "models/embedding-gecko-001",
    "models/text-bison-001",
)

# Module-level logger, following the project-wide log utility convention.
logger = log.getLogger(__name__)

22 

23 

class PalmHandlerArgs(BaseModel):
    """Validated bag of arguments accepted by the PaLM handler's USING clause.

    Unknown keys are rejected (see ``Config.extra``) so typos in the USING
    clause fail loudly instead of being silently ignored.
    """

    # NOTE(review): fields are annotated as plain `str`/`dict` with `None`
    # defaults; pydantic does not validate defaults, so None passes through —
    # confirm before tightening these to Optional[...].
    target: str = None  # prediction target column; filled in by create()
    model_name: str = "models/chat-bison-001"
    mode: str = "default"  # one of PalmHandler.supported_modes
    predict_params: dict = None
    input_text: str = None
    ft_api_info: dict = None
    ft_result_stats: dict = None
    runtime: str = None
    max_output_tokens: int = 64
    temperature: float = 0.0  # clamped to [0, 1] at predict time
    api_key: str = None
    palm_api_key: str = None

    # Prompting configuration — exactly one of the three prompting styles
    # (prompt_template / question_column / prompt+user+assistant columns)
    # is expected; enforced in PalmHandler.create_validation.
    question_column: str = None
    answer_column: str = None
    context_column: str = None
    prompt_template: str = None
    prompt: str = None
    user_column: str = None
    assistant_column: str = None

    class Config:
        # for all args that are not expected, raise an error
        extra = Extra.forbid

49 

50 

class PalmHandler(BaseMLEngine):
    """MindsDB ML engine backed by Google's PaLM API (google.generativeai)."""

    name = "palm"

    def __init__(self, *args, **kwargs):
        """Initialize engine defaults (model, mode, rate limits).

        Args are passed straight through to :class:`BaseMLEngine`.
        """
        super().__init__(*args, **kwargs)
        self.generative = True
        # Default model used when the USING clause does not name one.
        self.model_name = "models/chat-bison-001"
        # BUG FIX: the original code re-assigned self.model_name to "default",
        # clobbering the model identifier above. The accompanying comment
        # ("can also be 'conversational'...") shows this was meant to be the
        # operation *mode*, which now lives in its own attribute.
        self.mode = "default"  # can also be 'conversational' or 'conversational-full'
        self.supported_modes = [
            "default",
            "conversational",
            "conversational-full",
            "embedding",
        ]
        self.rate_limit = 60  # requests per minute
        self.max_batch_size = 20
        self.default_max_output_tokens = 64
        self.chat_completion_models = CHAT_MODELS

71 

72 @staticmethod 

73 def create_validation(target, args=None, **kwargs): 

74 if "using" not in args: 

75 raise Exception( 

76 "palm engine requires a USING clause! Refer to its documentation for more details." 

77 ) 

78 else: 

79 args = args["using"] 

80 

81 if ( 

82 len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) 

83 == 0 

84 ): 

85 raise Exception( 

86 "One of `question_column` or `prompt_template` is required for this engine." 

87 ) 

88 

89 # TODO: add example_column for conversational mode 

90 keys_collection = [ 

91 ["prompt_template"], 

92 ["question_column", "context_column"], 

93 ["prompt", "user_column", "assistant_column"], 

94 ] 

95 for keys in keys_collection: 

96 if keys[0] in args and any( 

97 x[0] in args for x in keys_collection if x != keys 

98 ): 

99 raise Exception( 

100 textwrap.dedent( 

101 """\ 

102 Please provide one of 

103 1) a `prompt_template` 

104 2) a `question_column` and an optional `context_column` 

105 3) a `prompt' and 'user_column' and 'assistant_column` 

106 """ 

107 ) 

108 ) 

109 

110 def create(self, target, args=None, **kwargs): 

111 args = args["using"] 

112 args_model = PalmHandlerArgs(**args) 

113 

114 args_model.target = target 

115 api_key = get_api_key("palm", args["using"], self.engine_storage, strict=False) 

116 

117 # Set palm api key 

118 palm.configure(api_key=api_key) 

119 

120 available_models = [m.name for m in palm.list_models()] 

121 

122 if not args_model.model_name: 

123 args_model.model_name = self.model_name 

124 elif args_model.model_name not in available_models: 

125 raise Exception(f"Invalid model name. Please use one of {available_models}") 

126 

127 if not args_model.mode: 

128 args_model.mode = self.model_name 

129 elif args_model.mode not in self.supported_modes: 

130 raise Exception( 

131 f"Invalid operation mode. Please use one of {self.supported_modes}" 

132 ) 

133 

134 self.model_storage.json_set("args", args_model.model_dump()) 

135 

    def predict(self, df, args=None):
        """
        Produce one completion (or embedding) per row of `df`.

        If there is a prompt template, we use it. Otherwise, we use the concatenation of `context_column` (optional) and `question_column` to ask for a completion.

        Rows whose prompt inputs are all NaN are skipped for the API call and
        get a null completion in the returned dataframe.
        """  # noqa
        # TODO: support for edits, embeddings and moderation

        pred_args = args["predict_params"] if args else {}
        # Rehydrate the arguments persisted by create().
        args_model = PalmHandlerArgs(**self.model_storage.json_get("args"))
        df = df.reset_index(drop=True)

        # Predict-time mode override, validated against the supported set.
        if pred_args.get("mode"):
            if pred_args["mode"] in self.supported_modes:
                args_model.mode = pred_args["mode"]
            else:
                raise Exception(
                    f"Invalid operation mode. Please use one of {self.supported_modes}."
                )  # noqa

        if pred_args.get("prompt_template", False):
            base_template = pred_args[
                "prompt_template"
            ]  # override with predict-time template if available
        elif args_model.prompt_template:
            base_template = args_model.prompt_template
        else:
            base_template = None

        # Embedding Mode
        if args_model.mode == "embedding":
            api_args = {
                "model": pred_args.get("model_name", "models/embedding-gecko-001")
            }
            model_name = "models/embedding-gecko-001"
            if args_model.question_column:
                prompts = list(df[args_model.question_column].apply(lambda x: str(x)))
                # Indices of rows with no usable input — excluded from the call.
                empty_prompt_ids = np.where(
                    df[[args_model.question_column]].isna().all(axis=1).values
                )[0]
            else:
                raise Exception("Embedding mode needs a question_column")

        # Chat or normal completion mode
        else:
            if (
                args_model.question_column
                and args_model.question_column not in df.columns
            ):
                raise Exception(
                    f"This model expects a question to answer in the '{args_model.question_column}' column."
                )

            if (
                args_model.context_column
                and args_model.context_column not in df.columns
            ):
                raise Exception(
                    f"This model expects context in the '{args_model.context_column}' column."
                )

            # api argument validation
            model_name = args_model.model_name
            api_args = {
                "max_output_tokens": pred_args.get(
                    "max_output_tokens",
                    args_model.max_output_tokens,
                ),
                # Temperature is clamped into the API's valid [0, 1] range.
                "temperature": min(
                    1.0,
                    max(0.0, pred_args.get("temperature", args_model.temperature)),
                ),
                "top_p": pred_args.get("top_p", None),
                "candidate_count": pred_args.get("candidate_count", None),
                "stop_sequences": pred_args.get("stop_sequences", None),
            }

            if (
                args_model.mode != "default"
                and model_name not in self.chat_completion_models
            ):
                raise Exception(
                    f"Conversational modes are only available for the following models: {', '.join(self.chat_completion_models)}"
                )  # noqa

            # Build the prompt list according to the configured prompting style.
            if args_model.prompt_template:
                prompts, empty_prompt_ids = get_completed_prompts(
                    base_template, df
                )
                if len(prompts) == 0:
                    raise Exception("No prompts found")

            elif args_model.context_column:
                empty_prompt_ids = np.where(
                    df[[args_model.context_column, args_model.question_column]]
                    .isna()
                    .all(axis=1)
                    .values
                )[0]
                contexts = list(df[args_model.context_column].apply(lambda x: str(x)))
                questions = list(df[args_model.question_column].apply(lambda x: str(x)))
                prompts = [
                    f"Context: {c}\nQuestion: {q}\nAnswer: "
                    for c, q in zip(contexts, questions)
                ]
                api_args["context"] = "".join(contexts)

            elif args_model.prompt:
                empty_prompt_ids = []
                prompts = list(df[args_model.user_column])
                if len(prompts) == 0:
                    raise Exception("No prompts found")
            else:
                empty_prompt_ids = np.where(
                    df[[args_model.question_column]].isna().all(axis=1).values
                )[0]
                prompts = list(df[args_model.question_column].apply(lambda x: str(x)))

        # remove prompts without signal from completion queue
        prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]

        # NOTE(review): this indexes args["using"], but `args` here is the
        # predict-time args dict and may be None or lack a "using" key —
        # confirm against callers (create() passes the unwrapped dict instead).
        api_key = get_api_key("palm", args["using"], self.engine_storage, strict=False)
        api_args = {
            k: v for k, v in api_args.items() if v is not None
        }  # filter out non-specified api args
        completion = self._completion(
            model_name, prompts, api_key, api_args, args_model, df
        )

        # add null completion for empty prompts
        for i in sorted(empty_prompt_ids):
            completion.insert(i, None)

        pred_df = pd.DataFrame(completion, columns=[args_model.target])

        return pred_df

270 

271 def _completion(self, model_name, prompts, api_key, api_args, args_model, df): 

272 """ 

273 Handles completion for an arbitrary amount of rows. 

274 Additionally, single completion calls are done with exponential backoff to guarantee all prompts are processed, 

275 because even with previous checks the tokens-per-minute limit may apply. 

276 """ 

277 

278 def _submit_completion(model_name, prompts, api_key, api_args, args_model, df): 

279 kwargs = { 

280 "model": model_name, 

281 } 

282 

283 # configure the PaLM SDK with the provided API KEY 

284 palm.configure(api_key=api_key) 

285 

286 if model_name == "models/embedding-gecko-001": 

287 prompts = "".join(prompts) 

288 return _submit_embedding_completion(kwargs, prompts, api_args) 

289 elif model_name == args_model.model_name: 

290 return _submit_chat_completion( 

291 kwargs, 

292 prompts, 

293 api_args, 

294 df, 

295 mode=args_model.mode, 

296 ) 

297 else: 

298 prompts = "".join(prompts) 

299 return _submit_normal_completion(kwargs, prompts, api_args) 

300 

301 def _log_api_call(params, response): 

302 after_palm_query(params, response) 

303 

304 params2 = params.copy() 

305 params2.pop("palm_api_key", None) 

306 params2.pop("user", None) 

307 logger.debug(f">>>palm call: {params2}:\n{response}") 

308 

309 def _submit_normal_completion(kwargs, prompts, api_args): 

310 def _tidy(comp): 

311 tidy_comps = [] 

312 if comp.candidates and len(comp.candidates) == 0: 

313 return ["No completions found"] 

314 for c in comp.candidates: 

315 if "output" in c: 

316 tidy_comps.append(c["output"].strip("\n").strip("")) 

317 return tidy_comps 

318 

319 kwargs["prompt"] = prompts 

320 kwargs = {**kwargs, **api_args} 

321 

322 before_palm_query(kwargs) 

323 

324 # call the palm sdk with text-bison-001 model 

325 resp = _tidy(palm.generate_text(**kwargs)) 

326 _log_api_call(kwargs, resp) 

327 return resp 

328 

329 def _submit_embedding_completion(kwargs, prompts, api_args): 

330 def _tidy(comp): 

331 tidy_comps = [] 

332 if "embedding" not in comp: 

333 return [f"No completion found, err {comp}"] 

334 for c in comp["embedding"]: 

335 tidy_comps.append([c]) 

336 return tidy_comps 

337 

338 kwargs = {} 

339 kwargs["model"] = api_args["model"] 

340 kwargs["text"] = prompts 

341 

342 before_palm_query(kwargs) 

343 

344 # call the palm sdk with embedding-gecko-001 model 

345 resp = _tidy(palm.generate_embeddings(**kwargs)) 

346 _log_api_call(kwargs, resp) 

347 return resp 

348 

349 def _submit_chat_completion( 

350 kwargs, prompts, api_args, df, mode="conversational" 

351 ): 

352 def _tidy(comp): 

353 tidy_comps = [] 

354 if comp.candidates and len(comp.candidates) == 0: 

355 return ["No completions found"] 

356 

357 for c in comp.candidates: 

358 if "content" in c: 

359 tidy_comps.append(c["content"].strip("\n").strip("")) 

360 if "output" in c: 

361 tidy_comps.append(c["output"].strip("\n").strip("")) 

362 return tidy_comps 

363 

364 completions = [] 

365 if mode != "conversational": 

366 initial_prompt = { 

367 "author": "system", 

368 "content": "You are a helpful assistant. Your task is to continue the chat.", 

369 } # noqa 

370 else: 

371 # get prompt from model 

372 prompt = "".join(prompts) 

373 initial_prompt = {"author": "system", "content": prompt} # noqa 

374 kwargs["messages"] = [initial_prompt] 

375 

376 last_completion_content = None 

377 

378 for pidx in range(len(prompts)): 

379 if mode == "conversational": 

380 kwargs["messages"].append( 

381 {"author": "user", "content": prompts[pidx]} 

382 ) 

383 

384 if mode == "conversational-full" or ( 

385 mode == "conversational" and pidx == len(prompts) - 1 

386 ): 

387 pkwargs = {**kwargs, **api_args} 

388 pkwargs["candidate_count"] = 3 

389 pkwargs.pop("max_output_tokens") 

390 before_palm_query(kwargs) 

391 

392 # call the palm sdk with chat-bison-001 model 

393 resp = _tidy(palm.chat(**pkwargs)) 

394 

395 _log_api_call(pkwargs, resp) 

396 

397 completions.extend(resp) 

398 elif mode == "default": 

399 pkwargs = {**kwargs, **api_args} 

400 

401 pkwargs["model"] = "models/text-bison-001" 

402 pkwargs["prompt"] = prompts[pidx] 

403 before_palm_query(kwargs) 

404 if pkwargs["prompt"] == "": 

405 return ["No prompt provided"] 

406 

407 # call the palm sdk with text-bison-001 model 

408 resp = _tidy(palm.generate_text(**pkwargs)) 

409 _log_api_call(pkwargs, resp) 

410 

411 completions.extend(resp) 

412 else: 

413 # in "normal" conversational mode, we request completions only for the last row 

414 last_completion_content = None 

415 if args_model.answer_column in df.columns: 

416 # insert completion if provided, which saves redundant API calls 

417 completions.extend([df.iloc[pidx][args_model.answer_column]]) 

418 else: 

419 completions.extend([""]) 

420 

421 if args_model.answer_column in df.columns: 

422 kwargs["messages"].append( 

423 { 

424 "author": "assistant", 

425 "content": df.iloc[pidx][args_model.answer_column], 

426 } 

427 ) 

428 elif last_completion_content: 

429 # interleave assistant responses with user input 

430 kwargs["messages"].append( 

431 {"author": "assistant", "content": last_completion_content[0]} 

432 ) 

433 

434 return completions 

435 

436 try: 

437 completion = _submit_completion( 

438 model_name, prompts, api_key, api_args, args_model, df 

439 ) 

440 return completion 

441 except Exception as e: 

442 completion = [] 

443 logger.exception(e) 

444 completion.extend({"error": str(e)}) 

445 

446 return completion