Coverage for mindsdb / integrations / handlers / google_gemini_handler / google_gemini_handler.py: 0%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2from typing import Dict, Optional 

3 

4from PIL import Image 

5import requests 

6import numpy as np 

7from io import BytesIO 

8import json 

9import textwrap 

10import google.generativeai as genai 

11import pandas as pd 

12from mindsdb.integrations.libs.base import BaseMLEngine 

13from mindsdb.utilities import log 

14from mindsdb.utilities.config import Config 

15from mindsdb.integrations.libs.llm.utils import get_completed_prompts 

16import concurrent.futures 

17 

18logger = log.getLogger(__name__) 

19 

20 

class GoogleGeminiHandler(BaseMLEngine):
    """
    Integration with the Google generative AI Python Library
    """

    name = "google_gemini"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        self.mode = "default"
        # Fallback model names, used when the USING clause does not name one.
        self.default_model = "gemini-pro"
        self.default_embedding_model = "models/embedding-001"

35 # Similiar to openai handler 

36 @staticmethod 

37 def create_validation(target, args=None, **kwargs): 

38 if "using" not in args: 

39 raise Exception( 

40 "Gemini engine requires a USING clause! Refer to its documentation for more details." 

41 ) 

42 else: 

43 args = args["using"] 

44 

45 if ( 

46 len( 

47 set(args.keys()) 

48 & { 

49 "img_url", 

50 "input_text", 

51 "question_column", 

52 "prompt_template", 

53 "json_struct", 

54 "prompt", 

55 } 

56 ) 

57 == 0 

58 ): 

59 raise Exception( 

60 "One of `question_column`, `prompt_template` or `json_struct` is required for this engine." 

61 ) 

62 

63 keys_collection = [ 

64 ["prompt_template"], 

65 ["question_column", "context_column"], 

66 ["prompt", "user_column", "assistant_column"], 

67 ["json_struct", "input_text"], 

68 ["img_url", "ctx_column"], 

69 ] 

70 for keys in keys_collection: 

71 if keys[0] in args and any( 

72 x[0] in args for x in keys_collection if x != keys 

73 ): 

74 raise Exception( 

75 textwrap.dedent( 

76 """\ 

77 Please provide one of 

78 1) a `prompt_template` 

79 2) a `question_column` and an optional `context_column` 

80 3) a `json_struct` 

81 4) a `prompt' and 'user_column' and 'assistant_column` 

82 5) a `img_url` and optional `ctx_column` for mode=`vision` 

83 """ 

84 ) 

85 ) 

86 

87 # for all args that are not expected, raise an error 

88 known_args = set() 

89 # flatten of keys_collection 

90 for keys in keys_collection: 

91 known_args = known_args.union(set(keys)) 

92 

93 # TODO: need a systematic way to maintain a list of known args 

94 known_args = known_args.union( 

95 { 

96 "target", 

97 "model_name", 

98 "mode", 

99 "title_column", 

100 "predict_params", 

101 "type", 

102 "max_tokens", 

103 "temperature", 

104 "api_key", 

105 } 

106 ) 

107 

108 unknown_args = set(args.keys()) - known_args 

109 if unknown_args: 

110 # return a list of unknown args as a string 

111 raise Exception( 

112 f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}" 

113 ) 

114 

115 def create(self, target, args=None, **kwargs): 

116 args = args["using"] 

117 args["target"] = target 

118 self.model_storage.json_set("args", args) 

119 

120 def predict( 

121 self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None 

122 ) -> pd.DataFrame: 

123 pred_args = args["predict_params"] if args else {} 

124 args = self.model_storage.json_get("args") 

125 df = df.reset_index(drop=True) 

126 

127 # same as opeani handler for getting prompt template and mode 

128 if pred_args.get("prompt_template", False): 

129 base_template = pred_args[ 

130 "prompt_template" 

131 ] # override with predict-time template if available 

132 elif args.get("prompt_template", False): 

133 base_template = args["prompt_template"] 

134 else: 

135 base_template = None 

136 

137 # Embedding Mode 

138 if args.get("mode") == "embedding": 

139 args["type"] = pred_args.get("type", "query") 

140 return self.embedding_worker(args, df) 

141 

142 elif args.get("mode") == "vision": 

143 return self.vision_worker(args, df) 

144 

145 elif args.get("mode") == "conversational": 

146 # Enable chat mode using 

147 # https://ai.google.dev/tutorials/python_quickstart#chat_conversations 

148 # OR 

149 # https://github.com/google/generative-ai-python?tab=readme-ov-file#developers-who-use-the-palm-api 

150 pass 

151 

152 else: 

153 if args.get("prompt_template", False): 

154 prompts, empty_prompt_ids = get_completed_prompts(base_template, df) 

155 

156 # Disclaimer: The following code has been adapted from the OpenAI handler. 

157 elif args.get("context_column", False): 

158 empty_prompt_ids = np.where( 

159 df[[args["context_column"], args["question_column"]]] 

160 .isna() 

161 .all(axis=1) 

162 .values 

163 )[0] 

164 contexts = list(df[args["context_column"]].apply(lambda x: str(x))) 

165 questions = list(df[args["question_column"]].apply(lambda x: str(x))) 

166 prompts = [ 

167 f"Give only answer for: \nContext: {c}\nQuestion: {q}\nAnswer: " 

168 for c, q in zip(contexts, questions) 

169 ] 

170 

171 # Disclaimer: The following code has been adapted from the OpenAI handler. 

172 elif args.get("json_struct", False): 

173 empty_prompt_ids = np.where( 

174 df[[args["input_text"]]].isna().all(axis=1).values 

175 )[0] 

176 prompts = [] 

177 for i in df.index: 

178 if "json_struct" in df.columns: 

179 if isinstance(df["json_struct"][i], str): 

180 df["json_struct"][i] = json.loads(df["json_struct"][i]) 

181 json_struct = "" 

182 for ind, val in enumerate(df["json_struct"][i].values()): 

183 json_struct = json_struct + f"{ind}. {val}\n" 

184 else: 

185 json_struct = "" 

186 for ind, val in enumerate(args["json_struct"].values()): 

187 json_struct = json_struct + f"{ind + 1}. {val}\n" 

188 

189 p = textwrap.dedent( 

190 f"""\ 

191 Using text starting after 'The text is:', give exactly {len(args['json_struct'])} answers to the questions: 

192 {{{{json_struct}}}} 

193 

194 Answers should be in the same order as the questions. 

195 Answer should be in form of one JSON Object eg. {"{'key':'value',..}"} where key=question and value=answer. 

196 If there is no answer to the question in the text, put a -. 

197 Answers should be as short as possible, ideally 1-2 words (unless otherwise specified). 

198 

199 The text is: 

200 {{{{{args['input_text']}}}}} 

201 """ 

202 ) 

203 p = p.replace("{{json_struct}}", json_struct) 

204 for column in df.columns: 

205 if column == "json_struct": 

206 continue 

207 p = p.replace(f"{{{{{column}}}}}", str(df[column][i])) 

208 prompts.append(p) 

209 elif "prompt" in args: 

210 empty_prompt_ids = [] 

211 prompts = list(df[args["user_column"]]) 

212 else: 

213 empty_prompt_ids = np.where( 

214 df[[args["question_column"]]].isna().all(axis=1).values 

215 )[0] 

216 prompts = list(df[args["question_column"]].apply(lambda x: str(x))) 

217 

218 # remove prompts without signal from completion queue 

219 prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids] 

220 

221 api_key = self._get_google_gemini_api_key(args) 

222 genai.configure(api_key=api_key) 

223 

224 # called gemini model withinputs 

225 model = genai.GenerativeModel(args.get("model_name", self.default_model)) 

226 results = [] 

227 for m in prompts: 

228 results.append(model.generate_content(m).text) 

229 

230 pred_df = pd.DataFrame(results, columns=[args["target"]]) 

231 return pred_df 

232 

233 def _get_google_gemini_api_key(self, args, strict=True): 

234 """ 

235 API_KEY preference order: 

236 1. provided at model creation 

237 2. provided at engine creation 

238 3. GOOGLE_GENAI_API_KEY env variable 

239 4. google_gemini.api_key setting in config.json 

240 """ 

241 

242 if "api_key" in args: 

243 return args["api_key"] 

244 # 2 

245 connection_args = self.engine_storage.get_connection_args() 

246 if "api_key" in connection_args: 

247 return connection_args["api_key"] 

248 # 3 

249 api_key = os.getenv("GOOGLE_GENAI_API_KEY") 

250 if api_key is not None: 

251 return api_key 

252 # 4 

253 config = Config() 

254 google_gemini_config = config.get("google_gemini", {}) 

255 if "api_key" in google_gemini_config: 

256 return google_gemini_config["api_key"] 

257 

258 if strict: 

259 raise Exception( 

260 'Missing API key "api_key". Either re-create this ML_ENGINE specifying the `api_key` parameter,\ 

261 or re-create this model and pass the API key with `USING` syntax.' 

262 ) 

263 

264 def embedding_worker(self, args: Dict, df: pd.DataFrame): 

265 if args.get("question_column"): 

266 prompts = list(df[args["question_column"]].apply(lambda x: str(x))) 

267 if args.get("title_column", None): 

268 titles = list(df[args["title_column"]].apply(lambda x: str(x))) 

269 else: 

270 titles = None 

271 

272 api_key = self._get_google_gemini_api_key(args) 

273 genai.configure(api_key=api_key) 

274 model_name = args.get("model_name", self.default_embedding_model) 

275 task_type = args.get("type") 

276 task_type = f"retrieval_{task_type}" 

277 

278 if task_type == "retrieval_query": 

279 results = [ 

280 str( 

281 genai.embed_content( 

282 model=model_name, content=query, task_type=task_type 

283 )["embedding"] 

284 ) 

285 for query in prompts 

286 ] 

287 elif titles: 

288 results = [ 

289 str( 

290 genai.embed_content( 

291 model=model_name, 

292 content=doc, 

293 task_type=task_type, 

294 title=title, 

295 )["embedding"] 

296 ) 

297 for title, doc in zip(titles, prompts) 

298 ] 

299 else: 

300 results = [ 

301 str( 

302 genai.embed_content( 

303 model=model_name, content=doc, task_type=task_type 

304 )["embedding"] 

305 ) 

306 for doc in prompts 

307 ] 

308 

309 pred_df = pd.DataFrame(results, columns=[args["target"]]) 

310 return pred_df 

311 else: 

312 raise Exception("Embedding mode needs a question_column") 

313 

314 def vision_worker(self, args: Dict, df: pd.DataFrame): 

315 def get_img(url): 

316 # URL Validation 

317 response = requests.get(url) 

318 if response.status_code == 200 and response.headers.get( 

319 "content-type", "" 

320 ).startswith("image/"): 

321 return Image.open(BytesIO(response.content)) 

322 else: 

323 raise Exception(f"{url} is not vaild image URL..") 

324 

325 if args.get("img_url"): 

326 urls = list(df[args["img_url"]].apply(lambda x: str(x))) 

327 

328 else: 

329 raise Exception("Vision mode needs a img_url") 

330 

331 prompts = None 

332 if args.get("ctx_column"): 

333 prompts = list(df[args["ctx_column"]].apply(lambda x: str(x))) 

334 

335 api_key = self._get_google_gemini_api_key(args) 

336 genai.configure(api_key=api_key) 

337 model = genai.GenerativeModel("gemini-pro-vision") 

338 with concurrent.futures.ThreadPoolExecutor() as executor: 

339 # Download images concurrently using ThreadPoolExecutor 

340 imgs = list(executor.map(get_img, urls)) 

341 # imgs = [Image.open(BytesIO(requests.get(url).content)) for url in urls] 

342 if prompts: 

343 results = [ 

344 model.generate_content([img, text]).text 

345 for img, text in zip(imgs, prompts) 

346 ] 

347 else: 

348 results = [model.generate_content(img).text for img in imgs] 

349 

350 pred_df = pd.DataFrame(results, columns=[args["target"]]) 

351 

352 return pred_df 

353 

354 # Disclaimer: The following code has been adapted from the OpenAI handler. 

355 def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: 

356 

357 args = self.model_storage.json_get("args") 

358 

359 if attribute == "args": 

360 return pd.DataFrame(args.items(), columns=["key", "value"]) 

361 elif attribute == "metadata": 

362 api_key = self._get_google_gemini_api_key(args) 

363 genai.configure(api_key=api_key) 

364 model_name = args.get("model_name", self.default_model) 

365 

366 meta = genai.get_model(f"models/{model_name}").__dict__ 

367 return pd.DataFrame(meta.items(), columns=["key", "value"]) 

368 else: 

369 tables = ["args", "metadata"] 

370 return pd.DataFrame(tables, columns=["tables"])