Coverage for mindsdb / integrations / handlers / google_gemini_handler / google_gemini_handler.py: 0%
171 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2from typing import Dict, Optional
4from PIL import Image
5import requests
6import numpy as np
7from io import BytesIO
8import json
9import textwrap
10import google.generativeai as genai
11import pandas as pd
12from mindsdb.integrations.libs.base import BaseMLEngine
13from mindsdb.utilities import log
14from mindsdb.utilities.config import Config
15from mindsdb.integrations.libs.llm.utils import get_completed_prompts
16import concurrent.futures
18logger = log.getLogger(__name__)
class GoogleGeminiHandler(BaseMLEngine):
    """
    Integration with the Google generative AI Python Library
    """

    name = "google_gemini"

    def __init__(self, *args, **kwargs):
        """Initialize the handler with its default model names and mode."""
        super().__init__(*args, **kwargs)
        # Flag this engine as generative and start in the default text mode.
        self.generative = True
        self.mode = "default"
        # Fallback models used when `model_name` is absent from the USING args.
        self.default_model = "gemini-pro"
        self.default_embedding_model = "models/embedding-001"
35 # Similiar to openai handler
36 @staticmethod
37 def create_validation(target, args=None, **kwargs):
38 if "using" not in args:
39 raise Exception(
40 "Gemini engine requires a USING clause! Refer to its documentation for more details."
41 )
42 else:
43 args = args["using"]
45 if (
46 len(
47 set(args.keys())
48 & {
49 "img_url",
50 "input_text",
51 "question_column",
52 "prompt_template",
53 "json_struct",
54 "prompt",
55 }
56 )
57 == 0
58 ):
59 raise Exception(
60 "One of `question_column`, `prompt_template` or `json_struct` is required for this engine."
61 )
63 keys_collection = [
64 ["prompt_template"],
65 ["question_column", "context_column"],
66 ["prompt", "user_column", "assistant_column"],
67 ["json_struct", "input_text"],
68 ["img_url", "ctx_column"],
69 ]
70 for keys in keys_collection:
71 if keys[0] in args and any(
72 x[0] in args for x in keys_collection if x != keys
73 ):
74 raise Exception(
75 textwrap.dedent(
76 """\
77 Please provide one of
78 1) a `prompt_template`
79 2) a `question_column` and an optional `context_column`
80 3) a `json_struct`
81 4) a `prompt' and 'user_column' and 'assistant_column`
82 5) a `img_url` and optional `ctx_column` for mode=`vision`
83 """
84 )
85 )
87 # for all args that are not expected, raise an error
88 known_args = set()
89 # flatten of keys_collection
90 for keys in keys_collection:
91 known_args = known_args.union(set(keys))
93 # TODO: need a systematic way to maintain a list of known args
94 known_args = known_args.union(
95 {
96 "target",
97 "model_name",
98 "mode",
99 "title_column",
100 "predict_params",
101 "type",
102 "max_tokens",
103 "temperature",
104 "api_key",
105 }
106 )
108 unknown_args = set(args.keys()) - known_args
109 if unknown_args:
110 # return a list of unknown args as a string
111 raise Exception(
112 f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}"
113 )
115 def create(self, target, args=None, **kwargs):
116 args = args["using"]
117 args["target"] = target
118 self.model_storage.json_set("args", args)
120 def predict(
121 self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None
122 ) -> pd.DataFrame:
123 pred_args = args["predict_params"] if args else {}
124 args = self.model_storage.json_get("args")
125 df = df.reset_index(drop=True)
127 # same as opeani handler for getting prompt template and mode
128 if pred_args.get("prompt_template", False):
129 base_template = pred_args[
130 "prompt_template"
131 ] # override with predict-time template if available
132 elif args.get("prompt_template", False):
133 base_template = args["prompt_template"]
134 else:
135 base_template = None
137 # Embedding Mode
138 if args.get("mode") == "embedding":
139 args["type"] = pred_args.get("type", "query")
140 return self.embedding_worker(args, df)
142 elif args.get("mode") == "vision":
143 return self.vision_worker(args, df)
145 elif args.get("mode") == "conversational":
146 # Enable chat mode using
147 # https://ai.google.dev/tutorials/python_quickstart#chat_conversations
148 # OR
149 # https://github.com/google/generative-ai-python?tab=readme-ov-file#developers-who-use-the-palm-api
150 pass
152 else:
153 if args.get("prompt_template", False):
154 prompts, empty_prompt_ids = get_completed_prompts(base_template, df)
156 # Disclaimer: The following code has been adapted from the OpenAI handler.
157 elif args.get("context_column", False):
158 empty_prompt_ids = np.where(
159 df[[args["context_column"], args["question_column"]]]
160 .isna()
161 .all(axis=1)
162 .values
163 )[0]
164 contexts = list(df[args["context_column"]].apply(lambda x: str(x)))
165 questions = list(df[args["question_column"]].apply(lambda x: str(x)))
166 prompts = [
167 f"Give only answer for: \nContext: {c}\nQuestion: {q}\nAnswer: "
168 for c, q in zip(contexts, questions)
169 ]
171 # Disclaimer: The following code has been adapted from the OpenAI handler.
172 elif args.get("json_struct", False):
173 empty_prompt_ids = np.where(
174 df[[args["input_text"]]].isna().all(axis=1).values
175 )[0]
176 prompts = []
177 for i in df.index:
178 if "json_struct" in df.columns:
179 if isinstance(df["json_struct"][i], str):
180 df["json_struct"][i] = json.loads(df["json_struct"][i])
181 json_struct = ""
182 for ind, val in enumerate(df["json_struct"][i].values()):
183 json_struct = json_struct + f"{ind}. {val}\n"
184 else:
185 json_struct = ""
186 for ind, val in enumerate(args["json_struct"].values()):
187 json_struct = json_struct + f"{ind + 1}. {val}\n"
189 p = textwrap.dedent(
190 f"""\
191 Using text starting after 'The text is:', give exactly {len(args['json_struct'])} answers to the questions:
192 {{{{json_struct}}}}
194 Answers should be in the same order as the questions.
195 Answer should be in form of one JSON Object eg. {"{'key':'value',..}"} where key=question and value=answer.
196 If there is no answer to the question in the text, put a -.
197 Answers should be as short as possible, ideally 1-2 words (unless otherwise specified).
199 The text is:
200 {{{{{args['input_text']}}}}}
201 """
202 )
203 p = p.replace("{{json_struct}}", json_struct)
204 for column in df.columns:
205 if column == "json_struct":
206 continue
207 p = p.replace(f"{{{{{column}}}}}", str(df[column][i]))
208 prompts.append(p)
209 elif "prompt" in args:
210 empty_prompt_ids = []
211 prompts = list(df[args["user_column"]])
212 else:
213 empty_prompt_ids = np.where(
214 df[[args["question_column"]]].isna().all(axis=1).values
215 )[0]
216 prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
218 # remove prompts without signal from completion queue
219 prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]
221 api_key = self._get_google_gemini_api_key(args)
222 genai.configure(api_key=api_key)
224 # called gemini model withinputs
225 model = genai.GenerativeModel(args.get("model_name", self.default_model))
226 results = []
227 for m in prompts:
228 results.append(model.generate_content(m).text)
230 pred_df = pd.DataFrame(results, columns=[args["target"]])
231 return pred_df
233 def _get_google_gemini_api_key(self, args, strict=True):
234 """
235 API_KEY preference order:
236 1. provided at model creation
237 2. provided at engine creation
238 3. GOOGLE_GENAI_API_KEY env variable
239 4. google_gemini.api_key setting in config.json
240 """
242 if "api_key" in args:
243 return args["api_key"]
244 # 2
245 connection_args = self.engine_storage.get_connection_args()
246 if "api_key" in connection_args:
247 return connection_args["api_key"]
248 # 3
249 api_key = os.getenv("GOOGLE_GENAI_API_KEY")
250 if api_key is not None:
251 return api_key
252 # 4
253 config = Config()
254 google_gemini_config = config.get("google_gemini", {})
255 if "api_key" in google_gemini_config:
256 return google_gemini_config["api_key"]
258 if strict:
259 raise Exception(
260 'Missing API key "api_key". Either re-create this ML_ENGINE specifying the `api_key` parameter,\
261 or re-create this model and pass the API key with `USING` syntax.'
262 )
264 def embedding_worker(self, args: Dict, df: pd.DataFrame):
265 if args.get("question_column"):
266 prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
267 if args.get("title_column", None):
268 titles = list(df[args["title_column"]].apply(lambda x: str(x)))
269 else:
270 titles = None
272 api_key = self._get_google_gemini_api_key(args)
273 genai.configure(api_key=api_key)
274 model_name = args.get("model_name", self.default_embedding_model)
275 task_type = args.get("type")
276 task_type = f"retrieval_{task_type}"
278 if task_type == "retrieval_query":
279 results = [
280 str(
281 genai.embed_content(
282 model=model_name, content=query, task_type=task_type
283 )["embedding"]
284 )
285 for query in prompts
286 ]
287 elif titles:
288 results = [
289 str(
290 genai.embed_content(
291 model=model_name,
292 content=doc,
293 task_type=task_type,
294 title=title,
295 )["embedding"]
296 )
297 for title, doc in zip(titles, prompts)
298 ]
299 else:
300 results = [
301 str(
302 genai.embed_content(
303 model=model_name, content=doc, task_type=task_type
304 )["embedding"]
305 )
306 for doc in prompts
307 ]
309 pred_df = pd.DataFrame(results, columns=[args["target"]])
310 return pred_df
311 else:
312 raise Exception("Embedding mode needs a question_column")
314 def vision_worker(self, args: Dict, df: pd.DataFrame):
315 def get_img(url):
316 # URL Validation
317 response = requests.get(url)
318 if response.status_code == 200 and response.headers.get(
319 "content-type", ""
320 ).startswith("image/"):
321 return Image.open(BytesIO(response.content))
322 else:
323 raise Exception(f"{url} is not vaild image URL..")
325 if args.get("img_url"):
326 urls = list(df[args["img_url"]].apply(lambda x: str(x)))
328 else:
329 raise Exception("Vision mode needs a img_url")
331 prompts = None
332 if args.get("ctx_column"):
333 prompts = list(df[args["ctx_column"]].apply(lambda x: str(x)))
335 api_key = self._get_google_gemini_api_key(args)
336 genai.configure(api_key=api_key)
337 model = genai.GenerativeModel("gemini-pro-vision")
338 with concurrent.futures.ThreadPoolExecutor() as executor:
339 # Download images concurrently using ThreadPoolExecutor
340 imgs = list(executor.map(get_img, urls))
341 # imgs = [Image.open(BytesIO(requests.get(url).content)) for url in urls]
342 if prompts:
343 results = [
344 model.generate_content([img, text]).text
345 for img, text in zip(imgs, prompts)
346 ]
347 else:
348 results = [model.generate_content(img).text for img in imgs]
350 pred_df = pd.DataFrame(results, columns=[args["target"]])
352 return pred_df
354 # Disclaimer: The following code has been adapted from the OpenAI handler.
355 def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
357 args = self.model_storage.json_get("args")
359 if attribute == "args":
360 return pd.DataFrame(args.items(), columns=["key", "value"])
361 elif attribute == "metadata":
362 api_key = self._get_google_gemini_api_key(args)
363 genai.configure(api_key=api_key)
364 model_name = args.get("model_name", self.default_model)
366 meta = genai.get_model(f"models/{model_name}").__dict__
367 return pd.DataFrame(meta.items(), columns=["key", "value"])
368 else:
369 tables = ["args", "metadata"]
370 return pd.DataFrame(tables, columns=["tables"])