Coverage for mindsdb / integrations / handlers / palm_handler / palm_handler.py: 0%
229 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import textwrap
2from pydantic import BaseModel, Extra
4import google.generativeai as palm
5import numpy as np
6import pandas as pd
8from mindsdb.utilities.hooks import before_palm_query, after_palm_query
9from mindsdb.utilities import log
10from mindsdb.integrations.libs.base import BaseMLEngine
11from mindsdb.integrations.libs.llm.utils import get_completed_prompts
13from mindsdb.integrations.utilities.handler_utils import get_api_key
# PaLM model identifiers this handler recognises; conversational modes are
# restricted to this set (see PalmHandler.predict).
CHAT_MODELS = (
    "models/chat-bison-001",
    "models/embedding-gecko-001",
    "models/text-bison-001",
)

# Module-level logger (mindsdb logging wrapper).
logger = log.getLogger(__name__)
class PalmHandlerArgs(BaseModel):
    # Validated bag of arguments accepted by the PaLM handler via the USING
    # clause (CREATE) and predict-time parameters. Any unexpected key is
    # rejected by the `extra = Extra.forbid` config below.
    target: str = None  # output column name; filled in by create()
    model_name: str = "models/chat-bison-001"  # PaLM model identifier
    mode: str = "default"  # one of the handler's supported_modes
    predict_params: dict = None
    input_text: str = None
    ft_api_info: dict = None
    ft_result_stats: dict = None
    runtime: str = None
    max_output_tokens: int = 64
    temperature: float = 0.0  # clamped to [0, 1] at predict time
    api_key: str = None
    palm_api_key: str = None

    # Column / prompt configuration — mutually exclusive groups, enforced in
    # PalmHandler.create_validation.
    question_column: str = None
    answer_column: str = None
    context_column: str = None
    prompt_template: str = None
    prompt: str = None
    user_column: str = None
    assistant_column: str = None

    class Config:
        # for all args that are not expected, raise an error
        extra = Extra.forbid
class PalmHandler(BaseMLEngine):
    """
    ML engine integrating Google PaLM (google.generativeai) for text
    completion, chat and embedding generation.
    """

    name = "palm"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        # Default model used when the user does not specify one.
        self.model_name = "models/chat-bison-001"
        # BUG FIX: the original assigned "default" to `self.model_name` a
        # second time (clobbering the default model); the trailing comment
        # makes clear the *operation mode* was intended.
        self.mode = "default"  # can also be 'conversational' or 'conversational-full'
        self.supported_modes = [
            "default",
            "conversational",
            "conversational-full",
            "embedding",
        ]
        self.rate_limit = 60  # requests per minute
        self.max_batch_size = 20
        self.default_max_output_tokens = 64
        self.chat_completion_models = CHAT_MODELS
72 @staticmethod
73 def create_validation(target, args=None, **kwargs):
74 if "using" not in args:
75 raise Exception(
76 "palm engine requires a USING clause! Refer to its documentation for more details."
77 )
78 else:
79 args = args["using"]
81 if (
82 len(set(args.keys()) & {"question_column", "prompt_template", "prompt"})
83 == 0
84 ):
85 raise Exception(
86 "One of `question_column` or `prompt_template` is required for this engine."
87 )
89 # TODO: add example_column for conversational mode
90 keys_collection = [
91 ["prompt_template"],
92 ["question_column", "context_column"],
93 ["prompt", "user_column", "assistant_column"],
94 ]
95 for keys in keys_collection:
96 if keys[0] in args and any(
97 x[0] in args for x in keys_collection if x != keys
98 ):
99 raise Exception(
100 textwrap.dedent(
101 """\
102 Please provide one of
103 1) a `prompt_template`
104 2) a `question_column` and an optional `context_column`
105 3) a `prompt' and 'user_column' and 'assistant_column`
106 """
107 )
108 )
110 def create(self, target, args=None, **kwargs):
111 args = args["using"]
112 args_model = PalmHandlerArgs(**args)
114 args_model.target = target
115 api_key = get_api_key("palm", args["using"], self.engine_storage, strict=False)
117 # Set palm api key
118 palm.configure(api_key=api_key)
120 available_models = [m.name for m in palm.list_models()]
122 if not args_model.model_name:
123 args_model.model_name = self.model_name
124 elif args_model.model_name not in available_models:
125 raise Exception(f"Invalid model name. Please use one of {available_models}")
127 if not args_model.mode:
128 args_model.mode = self.model_name
129 elif args_model.mode not in self.supported_modes:
130 raise Exception(
131 f"Invalid operation mode. Please use one of {self.supported_modes}"
132 )
134 self.model_storage.json_set("args", args_model.model_dump())
136 def predict(self, df, args=None):
137 """
138 If there is a prompt template, we use it. Otherwise, we use the concatenation of `context_column` (optional) and `question_column` to ask for a completion.
139 """ # noqa
140 # TODO: support for edits, embeddings and moderation
142 pred_args = args["predict_params"] if args else {}
143 args_model = PalmHandlerArgs(**self.model_storage.json_get("args"))
144 df = df.reset_index(drop=True)
146 if pred_args.get("mode"):
147 if pred_args["mode"] in self.supported_modes:
148 args_model.mode = pred_args["mode"]
149 else:
150 raise Exception(
151 f"Invalid operation mode. Please use one of {self.supported_modes}."
152 ) # noqa
154 if pred_args.get("prompt_template", False):
155 base_template = pred_args[
156 "prompt_template"
157 ] # override with predict-time template if available
158 elif args_model.prompt_template:
159 base_template = args_model.prompt_template
160 else:
161 base_template = None
163 # Embedding Mode
164 if args_model.mode == "embedding":
165 api_args = {
166 "model": pred_args.get("model_name", "models/embedding-gecko-001")
167 }
168 model_name = "models/embedding-gecko-001"
169 if args_model.question_column:
170 prompts = list(df[args_model.question_column].apply(lambda x: str(x)))
171 empty_prompt_ids = np.where(
172 df[[args_model.question_column]].isna().all(axis=1).values
173 )[0]
174 else:
175 raise Exception("Embedding mode needs a question_column")
177 # Chat or normal completion mode
178 else:
179 if (
180 args_model.question_column
181 and args_model.question_column not in df.columns
182 ):
183 raise Exception(
184 f"This model expects a question to answer in the '{args_model.question_column}' column."
185 )
187 if (
188 args_model.context_column
189 and args_model.context_column not in df.columns
190 ):
191 raise Exception(
192 f"This model expects context in the '{args_model.context_column}' column."
193 )
195 # api argument validation
196 model_name = args_model.model_name
197 api_args = {
198 "max_output_tokens": pred_args.get(
199 "max_output_tokens",
200 args_model.max_output_tokens,
201 ),
202 "temperature": min(
203 1.0,
204 max(0.0, pred_args.get("temperature", args_model.temperature)),
205 ),
206 "top_p": pred_args.get("top_p", None),
207 "candidate_count": pred_args.get("candidate_count", None),
208 "stop_sequences": pred_args.get("stop_sequences", None),
209 }
211 if (
212 args_model.mode != "default"
213 and model_name not in self.chat_completion_models
214 ):
215 raise Exception(
216 f"Conversational modes are only available for the following models: {', '.join(self.chat_completion_models)}"
217 ) # noqa
219 if args_model.prompt_template:
220 prompts, empty_prompt_ids = get_completed_prompts(
221 base_template, df
222 )
223 if len(prompts) == 0:
224 raise Exception("No prompts found")
226 elif args_model.context_column:
227 empty_prompt_ids = np.where(
228 df[[args_model.context_column, args_model.question_column]]
229 .isna()
230 .all(axis=1)
231 .values
232 )[0]
233 contexts = list(df[args_model.context_column].apply(lambda x: str(x)))
234 questions = list(df[args_model.question_column].apply(lambda x: str(x)))
235 prompts = [
236 f"Context: {c}\nQuestion: {q}\nAnswer: "
237 for c, q in zip(contexts, questions)
238 ]
239 api_args["context"] = "".join(contexts)
241 elif args_model.prompt:
242 empty_prompt_ids = []
243 prompts = list(df[args_model.user_column])
244 if len(prompts) == 0:
245 raise Exception("No prompts found")
246 else:
247 empty_prompt_ids = np.where(
248 df[[args_model.question_column]].isna().all(axis=1).values
249 )[0]
250 prompts = list(df[args_model.question_column].apply(lambda x: str(x)))
252 # remove prompts without signal from completion queue
253 prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]
255 api_key = get_api_key("palm", args["using"], self.engine_storage, strict=False)
256 api_args = {
257 k: v for k, v in api_args.items() if v is not None
258 } # filter out non-specified api args
259 completion = self._completion(
260 model_name, prompts, api_key, api_args, args_model, df
261 )
263 # add null completion for empty prompts
264 for i in sorted(empty_prompt_ids):
265 completion.insert(i, None)
267 pred_df = pd.DataFrame(completion, columns=[args_model.target])
269 return pred_df
271 def _completion(self, model_name, prompts, api_key, api_args, args_model, df):
272 """
273 Handles completion for an arbitrary amount of rows.
274 Additionally, single completion calls are done with exponential backoff to guarantee all prompts are processed,
275 because even with previous checks the tokens-per-minute limit may apply.
276 """
278 def _submit_completion(model_name, prompts, api_key, api_args, args_model, df):
279 kwargs = {
280 "model": model_name,
281 }
283 # configure the PaLM SDK with the provided API KEY
284 palm.configure(api_key=api_key)
286 if model_name == "models/embedding-gecko-001":
287 prompts = "".join(prompts)
288 return _submit_embedding_completion(kwargs, prompts, api_args)
289 elif model_name == args_model.model_name:
290 return _submit_chat_completion(
291 kwargs,
292 prompts,
293 api_args,
294 df,
295 mode=args_model.mode,
296 )
297 else:
298 prompts = "".join(prompts)
299 return _submit_normal_completion(kwargs, prompts, api_args)
301 def _log_api_call(params, response):
302 after_palm_query(params, response)
304 params2 = params.copy()
305 params2.pop("palm_api_key", None)
306 params2.pop("user", None)
307 logger.debug(f">>>palm call: {params2}:\n{response}")
309 def _submit_normal_completion(kwargs, prompts, api_args):
310 def _tidy(comp):
311 tidy_comps = []
312 if comp.candidates and len(comp.candidates) == 0:
313 return ["No completions found"]
314 for c in comp.candidates:
315 if "output" in c:
316 tidy_comps.append(c["output"].strip("\n").strip(""))
317 return tidy_comps
319 kwargs["prompt"] = prompts
320 kwargs = {**kwargs, **api_args}
322 before_palm_query(kwargs)
324 # call the palm sdk with text-bison-001 model
325 resp = _tidy(palm.generate_text(**kwargs))
326 _log_api_call(kwargs, resp)
327 return resp
329 def _submit_embedding_completion(kwargs, prompts, api_args):
330 def _tidy(comp):
331 tidy_comps = []
332 if "embedding" not in comp:
333 return [f"No completion found, err {comp}"]
334 for c in comp["embedding"]:
335 tidy_comps.append([c])
336 return tidy_comps
338 kwargs = {}
339 kwargs["model"] = api_args["model"]
340 kwargs["text"] = prompts
342 before_palm_query(kwargs)
344 # call the palm sdk with embedding-gecko-001 model
345 resp = _tidy(palm.generate_embeddings(**kwargs))
346 _log_api_call(kwargs, resp)
347 return resp
349 def _submit_chat_completion(
350 kwargs, prompts, api_args, df, mode="conversational"
351 ):
352 def _tidy(comp):
353 tidy_comps = []
354 if comp.candidates and len(comp.candidates) == 0:
355 return ["No completions found"]
357 for c in comp.candidates:
358 if "content" in c:
359 tidy_comps.append(c["content"].strip("\n").strip(""))
360 if "output" in c:
361 tidy_comps.append(c["output"].strip("\n").strip(""))
362 return tidy_comps
364 completions = []
365 if mode != "conversational":
366 initial_prompt = {
367 "author": "system",
368 "content": "You are a helpful assistant. Your task is to continue the chat.",
369 } # noqa
370 else:
371 # get prompt from model
372 prompt = "".join(prompts)
373 initial_prompt = {"author": "system", "content": prompt} # noqa
374 kwargs["messages"] = [initial_prompt]
376 last_completion_content = None
378 for pidx in range(len(prompts)):
379 if mode == "conversational":
380 kwargs["messages"].append(
381 {"author": "user", "content": prompts[pidx]}
382 )
384 if mode == "conversational-full" or (
385 mode == "conversational" and pidx == len(prompts) - 1
386 ):
387 pkwargs = {**kwargs, **api_args}
388 pkwargs["candidate_count"] = 3
389 pkwargs.pop("max_output_tokens")
390 before_palm_query(kwargs)
392 # call the palm sdk with chat-bison-001 model
393 resp = _tidy(palm.chat(**pkwargs))
395 _log_api_call(pkwargs, resp)
397 completions.extend(resp)
398 elif mode == "default":
399 pkwargs = {**kwargs, **api_args}
401 pkwargs["model"] = "models/text-bison-001"
402 pkwargs["prompt"] = prompts[pidx]
403 before_palm_query(kwargs)
404 if pkwargs["prompt"] == "":
405 return ["No prompt provided"]
407 # call the palm sdk with text-bison-001 model
408 resp = _tidy(palm.generate_text(**pkwargs))
409 _log_api_call(pkwargs, resp)
411 completions.extend(resp)
412 else:
413 # in "normal" conversational mode, we request completions only for the last row
414 last_completion_content = None
415 if args_model.answer_column in df.columns:
416 # insert completion if provided, which saves redundant API calls
417 completions.extend([df.iloc[pidx][args_model.answer_column]])
418 else:
419 completions.extend([""])
421 if args_model.answer_column in df.columns:
422 kwargs["messages"].append(
423 {
424 "author": "assistant",
425 "content": df.iloc[pidx][args_model.answer_column],
426 }
427 )
428 elif last_completion_content:
429 # interleave assistant responses with user input
430 kwargs["messages"].append(
431 {"author": "assistant", "content": last_completion_content[0]}
432 )
434 return completions
436 try:
437 completion = _submit_completion(
438 model_name, prompts, api_key, api_args, args_model, df
439 )
440 return completion
441 except Exception as e:
442 completion = []
443 logger.exception(e)
444 completion.extend({"error": str(e)})
446 return completion