Coverage for mindsdb / integrations / handlers / openai_handler / openai_handler.py: 61%

456 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import math 

3import json 

4import shutil 

5import tempfile 

6import datetime 

7import textwrap 

8import subprocess 

9from enum import Enum 

10import concurrent.futures 

11from typing import Text, Tuple, Dict, List, Optional, Any 

12import openai 

13from openai.types.fine_tuning import FineTuningJob 

14from openai import OpenAI, AzureOpenAI, NotFoundError, AuthenticationError 

15import numpy as np 

16import pandas as pd 

17 

18from mindsdb.utilities.hooks import before_openai_query, after_openai_query 

19from mindsdb.utilities import log 

20from mindsdb.integrations.libs.base import BaseMLEngine 

21from mindsdb.integrations.handlers.openai_handler.helpers import ( 

22 retry_with_exponential_backoff, 

23 truncate_msgs_for_token_limit, 

24 get_available_models, 

25 PendingFT, 

26) 

27from mindsdb.integrations.handlers.openai_handler.constants import ( 

28 CHAT_MODELS_PREFIXES, 

29 IMAGE_MODELS, 

30 FINETUNING_MODELS, 

31 OPENAI_API_BASE, 

32 DEFAULT_CHAT_MODEL, 

33 DEFAULT_EMBEDDING_MODEL, 

34 DEFAULT_IMAGE_MODEL, 

35) 

36from mindsdb.integrations.libs.llm.utils import get_completed_prompts 

37from mindsdb.integrations.utilities.handler_utils import get_api_key 

38 

# Module-level logger (MindsDB logging wrapper keyed by module name).
logger = log.getLogger(__name__)

40 

41 

class Mode(Enum):
    """Operation modes supported by the OpenAI handler.

    The enum value is the user-facing string accepted in ``mode`` arguments
    (note ``conversational_full`` maps to the hyphenated ``conversational-full``).
    """

    default = "default"
    conversational = "conversational"
    conversational_full = "conversational-full"
    image = "image"
    embedding = "embedding"
    legacy = "legacy"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum machinery on lookup miss: replace the generic
        # "is not a valid Mode" error with an actionable one listing the options.
        valid_names = [member.name for member in cls]
        raise ValueError(f"Invalid operation mode '{value}'. Please use one of: {valid_names}")

53 

54 

55class OpenAIHandler(BaseMLEngine): 

56 """ 

57 This handler handles connection and inference with the OpenAI API. 

58 """ 

59 

60 name = "openai" 

61 

62 def __init__(self, *args, **kwargs): 

63 super().__init__(*args, **kwargs) 

64 self.generative = True 

65 self.default_model = DEFAULT_CHAT_MODEL 

66 self.default_embedding_model = DEFAULT_EMBEDDING_MODEL 

67 self.default_image_model = DEFAULT_IMAGE_MODEL 

68 self.default_mode = Mode.default # can also be 'conversational' or 'conversational-full' 

69 self.rate_limit = 60 # requests per minute 

70 self.max_batch_size = 20 

71 self.default_max_tokens = 100 

72 self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning 

73 # For now this are only used for handlers that inherits OpenAIHandler and don't need to override base methods 

74 self.api_key_name = getattr(self, "api_key_name", self.name) 

75 self.api_base = getattr(self, "api_base", OPENAI_API_BASE) 

76 

77 def create_engine(self, connection_args: Dict) -> None: 

78 """ 

79 Validate the OpenAI API credentials on engine creation. 

80 

81 Args: 

82 connection_args (Dict): Parameters for the engine. 

83 

84 Raises: 

85 Exception: If the handler is not configured with valid API credentials. 

86 

87 Returns: 

88 None 

89 """ 

90 connection_args = {k.lower(): v for k, v in connection_args.items()} 

91 api_key = connection_args.get("openai_api_key") 

92 if api_key is not None: 

93 org = connection_args.get("api_organization") 

94 api_base = connection_args.get("api_base") or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE) 

95 client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args) 

96 OpenAIHandler._check_client_connection(client) 

97 

98 @staticmethod 

99 def is_chat_model(model_name): 

100 for prefix in CHAT_MODELS_PREFIXES: 

101 if model_name.startswith(prefix): 

102 return True 

103 return False 

104 

105 @staticmethod 

106 def _check_client_connection(client: OpenAI) -> None: 

107 """ 

108 Check the OpenAI engine client connection by retrieving a model. 

109 

110 Args: 

111 client (openai.OpenAI): OpenAI client configured with the API credentials. 

112 

113 Raises: 

114 Exception: If the client connection (API key) is invalid or something else goes wrong. 

115 

116 Returns: 

117 None 

118 """ 

119 try: 

120 client.models.retrieve("test") 

121 except NotFoundError: 

122 pass 

123 except AuthenticationError as e: 

124 if isinstance(e.body, dict) and e.body.get("code") == "invalid_api_key": 124 ↛ 126line 124 didn't jump to line 126 because the condition on line 124 was always true

125 raise Exception("Invalid api key") 

126 raise Exception(f"Something went wrong: {e}") 

127 

128 @staticmethod 

129 def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None: 

130 """ 

131 Validate the OpenAI API credentials on model creation. 

132 

133 Args: 

134 target (Text): Target column name. 

135 args (Dict): Parameters for the model. 

136 kwargs (Any): Other keyword arguments. 

137 

138 Raises: 

139 Exception: If the handler is not configured with valid API credentials. 

140 

141 Returns: 

142 None 

143 """ 

144 if "using" not in args: 

145 raise Exception("OpenAI engine requires a USING clause! Refer to its documentation for more details.") 

146 else: 

147 args = args["using"] 

148 

149 if len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) == 0: 

150 raise Exception("One of `question_column`, `prompt_template` or `prompt` is required for this engine.") 

151 

152 keys_collection = [ 

153 ["prompt_template"], 

154 ["question_column", "context_column"], 

155 ["prompt", "user_column", "assistant_column"], 

156 ] 

157 for keys in keys_collection: 

158 if keys[0] in args and any(x[0] in args for x in keys_collection if x != keys): 

159 raise Exception( 

160 textwrap.dedent( 

161 """\ 

162 Please provide one of 

163 1) a `prompt_template` 

164 2) a `question_column` and an optional `context_column` 

165 3) a `prompt', 'user_column' and 'assistant_column` 

166 """ 

167 ) 

168 ) 

169 

170 # for all args that are not expected, raise an error 

171 known_args = set() 

172 # flatten of keys_collection 

173 for keys in keys_collection: 

174 known_args = known_args.union(set(keys)) 

175 

176 # TODO: need a systematic way to maintain a list of known args 

177 known_args = known_args.union( 

178 { 

179 "target", 

180 "model_name", 

181 "mode", 

182 "predict_params", 

183 "json_struct", 

184 "ft_api_info", 

185 "ft_result_stats", 

186 "runtime", 

187 "max_tokens", 

188 "temperature", 

189 "openai_api_key", 

190 "api_organization", 

191 "api_base", 

192 "api_version", 

193 "provider", 

194 } 

195 ) 

196 

197 unknown_args = set(args.keys()) - known_args 

198 if unknown_args: 

199 # return a list of unknown args as a string 

200 raise Exception( 

201 f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}" 

202 ) 

203 

204 engine_storage = kwargs["handler_storage"] 

205 connection_args = engine_storage.get_connection_args() 

206 api_key = get_api_key("openai", args, engine_storage=engine_storage) 

207 api_base = ( 

208 args.get("api_base") 

209 or connection_args.get("api_base") 

210 or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE) 

211 ) 

212 org = args.get("api_organization") 

213 client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args) 

214 OpenAIHandler._check_client_connection(client) 

215 

216 def create(self, target, args: Dict = None, **kwargs: Any) -> None: 

217 """ 

218 Create a model by connecting to the OpenAI API. 

219 

220 Args: 

221 target (Text): Target column name. 

222 args (Dict): Parameters for the model. 

223 kwargs (Any): Other keyword arguments. 

224 

225 Raises: 

226 Exception: If the model is not configured with valid parameters. 

227 

228 Returns: 

229 None 

230 """ 

231 args = args["using"] 

232 args["target"] = target 

233 try: 

234 api_key = get_api_key(self.api_key_name, args, self.engine_storage) 

235 connection_args = self.engine_storage.get_connection_args() 

236 api_base = ( 

237 args.get("api_base") 

238 or connection_args.get("api_base") 

239 or os.environ.get("OPENAI_API_BASE") 

240 or self.api_base 

241 ) 

242 client = self._get_client(api_key=api_key, base_url=api_base, org=args.get("api_organization"), args=args) 

243 available_models = get_available_models(client) 

244 

245 mode = args.get("mode") 

246 if mode is not None: 

247 mode = Mode(mode) 

248 else: 

249 mode = self.default_mode 

250 

251 if not args.get("model_name"): 

252 if mode is Mode.embedding: 252 ↛ 253line 252 didn't jump to line 253 because the condition on line 252 was never true

253 args["model_name"] = self.default_embedding_model 

254 elif mode is Mode.image: 254 ↛ 255line 254 didn't jump to line 255 because the condition on line 254 was never true

255 args["model_name"] = self.default_image_model 

256 else: 

257 args["model_name"] = self.default_model 

258 elif (args["model_name"] not in available_models) and (mode is not Mode.embedding): 258 ↛ 261line 258 didn't jump to line 261 because the condition on line 258 was always true

259 raise Exception(f"Invalid model name. Please use one of {available_models}") 

260 finally: 

261 self.model_storage.json_set("args", args) 

262 

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Make predictions using a model connected to the OpenAI API.

        Dispatches on the configured mode (embedding / image / chat or normal
        completion), builds one prompt per input row, runs completion via
        `_completion`, and returns a single-column DataFrame named after the
        target column.

        Args:
            df (pd.DataFrame): Input data to make predictions on.
            args (Dict): Parameters passed when making predictions.

        Raises:
            Exception: If the model is not configured with valid parameters or if the input data is not in the expected format.

        Returns:
            pd.DataFrame: Input data with the predicted values in a new column.
        """  # noqa
        # TODO: support for edits, embeddings and moderation

        pred_args = args["predict_params"] if args else {}
        args = self.model_storage.json_get("args")
        connection_args = self.engine_storage.get_connection_args()

        # API base precedence: predict-time args > create-time args > engine
        # connection args > environment > handler default.
        args["api_base"] = (
            pred_args.get("api_base")
            or args.get("api_base")
            or connection_args.get("api_base")
            or os.environ.get("OPENAI_API_BASE")
            or self.api_base
        )

        if pred_args.get("api_organization"):
            args["api_organization"] = pred_args["api_organization"]
        df = df.reset_index(drop=True)

        # Predict-time mode overrides the stored one and is written back into args.
        if pred_args.get("mode"):
            mode = Mode(pred_args["mode"])
            args["mode"] = mode.value
        elif args.get("mode"):
            mode = Mode(args["mode"])
        else:
            mode = Mode(self.default_mode)

        # A predict-time prompt_template is allowed to leave placeholders unmatched.
        strict_prompt_template = True
        if pred_args.get("prompt_template", False):
            base_template = pred_args["prompt_template"]  # override with predict-time template if available
            strict_prompt_template = False
        elif args.get("prompt_template", False):
            base_template = args["prompt_template"]
        else:
            base_template = None

        # Embedding mode
        if mode is Mode.embedding:
            api_args = {
                "question_column": pred_args.get("question_column", None),
                "model": pred_args.get("model_name") or args.get("model_name"),
            }
            model_name = "embedding"
            if args.get("question_column"):
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
                # Rows whose question is NaN get a null completion later.
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
            else:
                raise Exception("Embedding mode needs a question_column")

        # Image mode
        elif mode is Mode.image:
            api_args = {
                "n": pred_args.get("n", None),
                "size": pred_args.get("size", None),
                "response_format": pred_args.get("response_format", None),
            }
            api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
            model_name = pred_args.get("model_name") or args.get("model_name")

            if args.get("question_column"):
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
            elif args.get("prompt_template"):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df)
            else:
                raise Exception("Image mode needs either `prompt_template` or `question_column`.")

        # Chat or normal completion mode
        else:
            if args.get("question_column", False) and args["question_column"] not in df.columns:
                raise Exception(f"This model expects a question to answer in the '{args['question_column']}' column.")

            if args.get("context_column", False) and args["context_column"] not in df.columns:
                raise Exception(f"This model expects context in the '{args['context_column']}' column.")

            # API argument validation
            model_name = args.get("model_name", self.default_model)
            api_args = {
                "max_tokens": pred_args.get("max_tokens", args.get("max_tokens", self.default_max_tokens)),
                # Temperature is clamped into [0.0, 1.0].
                "temperature": min(
                    1.0,
                    max(0.0, pred_args.get("temperature", args.get("temperature", 0.0))),
                ),
                "top_p": pred_args.get("top_p", None),
                "n": pred_args.get("n", None),
                "stop": pred_args.get("stop", None),
                "presence_penalty": pred_args.get("presence_penalty", None),
                "frequency_penalty": pred_args.get("frequency_penalty", None),
                "best_of": pred_args.get("best_of", None),
                "logit_bias": pred_args.get("logit_bias", None),
                "user": pred_args.get("user", None),
            }

            if args.get("prompt_template", False):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df, strict=strict_prompt_template)

            elif args.get("context_column", False):
                empty_prompt_ids = np.where(
                    df[[args["context_column"], args["question_column"]]].isna().all(axis=1).values
                )[0]
                contexts = list(df[args["context_column"]].apply(lambda x: str(x)))
                questions = list(df[args["question_column"]].apply(lambda x: str(x)))
                prompts = [f"Context: {c}\nQuestion: {q}\nAnswer: " for c, q in zip(contexts, questions)]

            elif "prompt" in args:
                empty_prompt_ids = []
                prompts = list(df[args["user_column"]])
            else:
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))

        # add json struct if available
        if args.get("json_struct", False):
            for i, prompt in enumerate(prompts):
                json_struct = ""
                if "json_struct" in df.columns and i not in empty_prompt_ids:
                    # if row has a specific json, we try to use it instead of the base prompt template
                    try:
                        if isinstance(df["json_struct"][i], str):
                            # NOTE(review): chained assignment on a DataFrame — may not
                            # write through on newer pandas (copy-on-write); verify.
                            df["json_struct"][i] = json.loads(df["json_struct"][i])
                        for ind, val in enumerate(df["json_struct"][i].values()):
                            json_struct = json_struct + f"{ind}. {val}\n"
                    except Exception:
                        pass  # if the row's json is invalid, we use the prompt template instead

                if json_struct == "":
                    # Fall back to the model-level json_struct definition.
                    for ind, val in enumerate(args["json_struct"].values()):
                        json_struct = json_struct + f"{ind + 1}. {val}\n"

                p = textwrap.dedent(
                    f"""\
                    Based on the text following 'The reference text is:', assign values to the following {len(args["json_struct"])} JSON attributes:
                    {{{{json_struct}}}}

                    Values should follow the same order as the attributes above.
                    Each line in the answer should start with a dotted number, and should not repeat the name of the attribute, just the value.
                    Each answer must end with new line.
                    If there is no valid value to a given attribute in the text, answer with a - character.
                    Values should be as short as possible, ideally 1-2 words (unless otherwise specified).

                    Here is an example input of 3 attributes:
                    1. rental price
                    2. location
                    3. number of bathrooms

                    Here is an example output for the input:
                    1. 3000
                    2. Manhattan
                    3. 2

                    Now for the real task. The reference text is:
                    {prompt}
                    """
                )

                # The literal "{{json_struct}}" placeholder survives the f-string
                # (written as four braces) and is substituted here.
                p = p.replace("{{json_struct}}", json_struct)
                prompts[i] = p

        # remove prompts without signal from completion queue
        prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]

        api_key = get_api_key(self.api_key_name, args, self.engine_storage)
        api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
        completion = self._completion(model_name, prompts, api_key, api_args, args, df)

        # add null completion for empty prompts
        for i in sorted(empty_prompt_ids):
            completion.insert(i, None)

        pred_df = pd.DataFrame(completion, columns=[args["target"]])

        # restore json struct
        if args.get("json_struct", False):
            for i in pred_df.index:
                try:
                    if "json_struct" in df.columns:
                        json_keys = df["json_struct"][i].keys()
                    else:
                        json_keys = args["json_struct"].keys()
                    responses = pred_df[args["target"]][i].split("\n")
                    responses = [x[3:] for x in responses]  # del question index

                    # NOTE(review): chained assignment again — same pandas caveat as above.
                    pred_df[args["target"]][i] = {key: val for key, val in zip(json_keys, responses)}
                except Exception:
                    # Any parsing failure yields a null prediction for that row.
                    pred_df[args["target"]][i] = None

        return pred_df

463 

464 def _completion( 

465 self, 

466 model_name: Text, 

467 prompts: List[Text], 

468 api_key: Text, 

469 api_args: Dict, 

470 args: Dict, 

471 df: pd.DataFrame, 

472 parallel: bool = True, 

473 ) -> List[Any]: 

474 """ 

475 Handles completion for an arbitrary amount of rows using a model connected to the OpenAI API. 

476 

477 This method consists of several inner methods: 

478 - _submit_completion: Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task. 

479 - _submit_normal_completion: Submit a request to the completion endpoint of the OpenAI API. 

480 - _submit_embedding_completion: Submit a request to the embeddings endpoint of the OpenAI API. 

481 - _submit_chat_completion: Submit a request to the chat completion endpoint of the OpenAI API. 

482 - _submit_image_completion: Submit a request to the image completion endpoint of the OpenAI API. 

483 - _log_api_call: Log the API call made to the OpenAI API. 

484 

485 There are a couple checks that should be done when calling OpenAI's API: 

486 - account max batch size, to maximize batch size first 

487 - account rate limit, to maximize parallel calls second 

488 

489 Additionally, single completion calls are done with exponential backoff to guarantee all prompts are processed, 

490 because even with previous checks the tokens-per-minute limit may apply. 

491 

492 Args: 

493 model_name (Text): OpenAI Model name. 

494 prompts (List[Text]): List of prompts. 

495 api_key (Text): OpenAI API key. 

496 api_args (Dict): OpenAI API arguments. 

497 args (Dict): Parameters for the model. 

498 df (pd.DataFrame): Input data to run completion on. 

499 parallel (bool): Whether to use parallel processing. 

500 

501 Returns: 

502 List[Any]: List of completions. The type of completion depends on the task type. 

503 """ 

504 

505 @retry_with_exponential_backoff() 

506 def _submit_completion( 

507 model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame 

508 ) -> List[Text]: 

509 """ 

510 Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task. 

511 

512 Args: 

513 model_name (Text): OpenAI Model name. 

514 prompts (List[Text]): List of prompts. 

515 api_args (Dict): OpenAI API arguments. 

516 args (Dict): Parameters for the model. 

517 df (pd.DataFrame): Input data to run completion on. 

518 

519 Returns: 

520 List[Text]: List of completions. 

521 """ 

522 kwargs = { 

523 "model": model_name, 

524 } 

525 try: 

526 mode = Mode(args.get("mode")) 

527 except ValueError: 

528 if model_name in IMAGE_MODELS: 

529 mode = Mode.image 

530 elif model_name == "embedding": 

531 mode = Mode.embedding 

532 elif self.is_chat_model(model_name) and model_name != "gpt-3.5-turbo-instruct": 

533 mode = Mode.conversational 

534 elif model_name == "gpt-3.5-turbo-instruct": 

535 mode = Mode.legacy 

536 else: 

537 mode = Mode.default 

538 

539 match mode: 

540 case Mode.image: 

541 return _submit_image_completion(kwargs, prompts, api_args) 

542 case Mode.embedding: 

543 return _submit_embedding_completion(kwargs, prompts, api_args) 

544 case Mode.conversational | Mode.conversational_full | Mode.default: 

545 return _submit_chat_completion( 

546 kwargs, 

547 prompts, 

548 api_args, 

549 df, 

550 mode=args.get("mode", "conversational"), 

551 ) 

552 case Mode.legacy: 552 ↛ exitline 552 didn't return from function '_submit_completion' because the pattern on line 552 always matched

553 return _submit_normal_completion(kwargs, prompts, api_args) 

554 

555 def _log_api_call(params: Dict, response: Any) -> None: 

556 """ 

557 Log the API call made to the OpenAI API. 

558 

559 Args: 

560 params (Dict): Parameters for the API call. 

561 response (Any): Response from the API. 

562 

563 Returns: 

564 None 

565 """ 

566 after_openai_query(params, response) 

567 

568 params2 = params.copy() 

569 params2.pop("api_key", None) 

570 params2.pop("user", None) 

571 logger.debug(f">>>openai call: {params2}:\n{response}") 

572 

573 def _submit_normal_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]: 

574 """ 

575 Submit a request to the completion endpoint of the OpenAI API. 

576 

577 This method consists of an inner method: 

578 - _tidy: Parse and tidy up the response from the completion endpoint of the OpenAI API. 

579 

580 Args: 

581 kwargs (Dict): OpenAI API arguments, including the model to use. 

582 prompts (List[Text]): List of prompts. 

583 api_args (Dict): Other OpenAI API arguments. 

584 

585 Returns: 

586 List[Text]: List of text completions. 

587 """ 

588 

589 def _tidy(comp: openai.types.completion.Completion) -> List[Text]: 

590 """ 

591 Parse and tidy up the response from the completion endpoint of the OpenAI API. 

592 

593 Args: 

594 comp (openai.types.completion.Completion): Completion object. 

595 

596 Returns: 

597 List[Text]: List of completions as text. 

598 """ 

599 tidy_comps = [] 

600 for c in comp.choices: 

601 if hasattr(c, "text"): 601 ↛ 600line 601 didn't jump to line 600 because the condition on line 601 was always true

602 tidy_comps.append(c.text.strip("\n").strip("")) 

603 return tidy_comps 

604 

605 kwargs = {**kwargs, **api_args} 

606 

607 before_openai_query(kwargs) 

608 responses = [] 

609 for prompt in prompts: 

610 responses.extend(_tidy(client.completions.create(prompt=prompt, **kwargs))) 

611 _log_api_call(kwargs, responses) 

612 return responses 

613 

614 def _submit_embedding_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[float]: 

615 """ 

616 Submit a request to the embeddings endpoint of the OpenAI API. 

617 

618 This method consists of an inner method: 

619 - _tidy: Parse and tidy up the response from the embeddings endpoint of the OpenAI API. 

620 

621 Args: 

622 kwargs (Dict): OpenAI API arguments, including the model to use. 

623 prompts (List[Text]): List of prompts. 

624 api_args (Dict): Other OpenAI API arguments. 

625 

626 Returns: 

627 List[float]: List of embeddings as numbers. 

628 """ 

629 

630 def _tidy(comp: openai.types.create_embedding_response.CreateEmbeddingResponse) -> List[float]: 

631 """ 

632 Parse and tidy up the response from the embeddings endpoint of the OpenAI API. 

633 

634 Args: 

635 comp (openai.types.create_embedding_response.CreateEmbeddingResponse): Embedding object. 

636 

637 Returns: 

638 List[float]: List of embeddings as numbers. 

639 """ 

640 tidy_comps = [] 

641 for c in comp.data: 

642 if hasattr(c, "embedding"): 642 ↛ 641line 642 didn't jump to line 641 because the condition on line 642 was always true

643 tidy_comps.append([c.embedding]) 

644 return tidy_comps 

645 

646 kwargs["input"] = prompts 

647 kwargs = {**kwargs, **api_args} 

648 

649 before_openai_query(kwargs) 

650 resp = _tidy(client.embeddings.create(**kwargs)) 

651 _log_api_call(kwargs, resp) 

652 return resp 

653 

654 def _submit_chat_completion( 

655 kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = "conversational" 

656 ) -> List[Text]: 

657 """ 

658 Submit a request to the chat completion endpoint of the OpenAI API. 

659 

660 This method consists of an inner method: 

661 - _tidy: Parse and tidy up the response from the chat completion endpoint of the OpenAI API. 

662 

663 Args: 

664 kwargs (Dict): OpenAI API arguments, including the model to use. 

665 prompts (List[Text]): List of prompts. 

666 api_args (Dict): Other OpenAI API arguments. 

667 df (pd.DataFrame): Input data to run chat completion on. 

668 mode (Text): Mode of operation. 

669 

670 Returns: 

671 List[Text]: List of chat completions as text. 

672 """ 

673 

674 def _tidy(comp: openai.types.chat.chat_completion.ChatCompletion) -> List[Text]: 

675 """ 

676 Parse and tidy up the response from the chat completion endpoint of the OpenAI API. 

677 

678 Args: 

679 comp (openai.types.chat.chat_completion.ChatCompletion): Chat completion object. 

680 

681 Returns: 

682 List[Text]: List of chat completions as text. 

683 """ 

684 tidy_comps = [] 

685 for c in comp.choices: 

686 if hasattr(c, "message"): 686 ↛ 685line 686 didn't jump to line 685 because the condition on line 686 was always true

687 tidy_comps.append(c.message.content.strip("\n").strip("")) 

688 return tidy_comps 

689 

690 mode = Mode(mode) 

691 completions = [] 

692 if mode is not Mode.conversational or "prompt" not in args: 

693 initial_prompt = { 

694 "role": "system", 

695 "content": "You are a helpful assistant. Your task is to continue the chat.", 

696 } # noqa 

697 else: 

698 # get prompt from model 

699 initial_prompt = {"role": "system", "content": args["prompt"]} # noqa 

700 

701 kwargs["messages"] = [initial_prompt] 

702 last_completion_content = None 

703 

704 for pidx in range(len(prompts)): 

705 if mode is not Mode.conversational: 

706 kwargs["messages"].append({"role": "user", "content": prompts[pidx]}) 

707 else: 

708 question = prompts[pidx] 

709 if question: 709 ↛ 712line 709 didn't jump to line 712 because the condition on line 709 was always true

710 kwargs["messages"].append({"role": "user", "content": question}) 

711 

712 assistant_column = args.get("assistant_column") 

713 if assistant_column in df.columns: 713 ↛ 714line 713 didn't jump to line 714 because the condition on line 713 was never true

714 answer = df.iloc[pidx][assistant_column] 

715 else: 

716 answer = None 

717 if answer: 717 ↛ 718line 717 didn't jump to line 718 because the condition on line 717 was never true

718 kwargs["messages"].append({"role": "assistant", "content": answer}) 

719 

720 if mode is Mode.conversational_full or (mode is Mode.conversational and pidx == len(prompts) - 1): 

721 kwargs["messages"] = truncate_msgs_for_token_limit( 

722 kwargs["messages"], kwargs["model"], api_args["max_tokens"] 

723 ) 

724 pkwargs = {**kwargs, **api_args} 

725 

726 before_openai_query(kwargs) 

727 resp = _tidy(client.chat.completions.create(**pkwargs)) 

728 _log_api_call(pkwargs, resp) 

729 

730 completions.extend(resp) 

731 elif mode is Mode.default: 

732 kwargs["messages"] = [initial_prompt] + [kwargs["messages"][-1]] 

733 pkwargs = {**kwargs, **api_args} 

734 

735 before_openai_query(kwargs) 

736 resp = _tidy(client.chat.completions.create(**pkwargs)) 

737 _log_api_call(pkwargs, resp) 

738 

739 completions.extend(resp) 

740 else: 

741 # in "normal" conversational mode, we request completions only for the last row 

742 last_completion_content = None 

743 completions.extend([""]) 

744 

745 if last_completion_content: 745 ↛ 747line 745 didn't jump to line 747 because the condition on line 745 was never true

746 # interleave assistant responses with user input 

747 kwargs["messages"].append({"role": "assistant", "content": last_completion_content[0]}) 

748 

749 return completions 

750 

751 def _submit_image_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]: 

752 """ 

753 Submit a request to the image generation endpoint of the OpenAI API. 

754 

755 This method consists of an inner method: 

756 - _tidy: Parse and tidy up the response from the image generation endpoint of the OpenAI API. 

757 

758 Args: 

759 kwargs (Dict): OpenAI API arguments, including the model to use. 

760 prompts (List[Text]): List of prompts. 

761 api_args (Dict): Other OpenAI API arguments. 

762 

763 Raises: 

764 Exception: If the maximum batch size is reached. 

765 

766 Returns: 

767 List[Text]: List of image completions as URLs or base64 encoded images. 

768 """ 

769 

770 def _tidy(comp: List[openai.types.image.Image]) -> List[Text]: 

771 """ 

772 Parse and tidy up the response from the image generation endpoint of the OpenAI API. 

773 

774 Args: 

775 comp (List[openai.types.image.Image]): Image completion objects. 

776 

777 Returns: 

778 List[Text]: List of image completions as URLs or base64 encoded images. 

779 """ 

780 return [c.url if hasattr(c, "url") else c.b64_json for c in comp] 

781 

782 completions = [client.images.generate(**{"prompt": p, **kwargs, **api_args}).data[0] for p in prompts] 

783 return _tidy(completions) 

784 

785 client = self._get_client( 

786 api_key=api_key, 

787 base_url=args.get("api_base"), 

788 org=args.pop("api_organization") if "api_organization" in args else None, 

789 args=args, 

790 ) 

791 

792 try: 

793 # check if simple completion works 

794 completion = _submit_completion(model_name, prompts, api_args, args, df) 

795 return completion 

796 except Exception as e: 

797 # else, we get the max batch size 

798 if "you can currently request up to at most a total of" in str(e): 

799 pattern = "a total of" 

800 max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(").")[0]) 

801 else: 

802 max_batch_size = self.max_batch_size # guards against changes in the API message 

803 

804 if not parallel: 

805 completion = None 

806 for i in range(math.ceil(len(prompts) / max_batch_size)): 

807 partial = _submit_completion( 

808 model_name, 

809 prompts[i * max_batch_size : (i + 1) * max_batch_size], 

810 api_args, 

811 args, 

812 df, 

813 ) 

814 if not completion: 

815 completion = partial 

816 else: 

817 completion.extend(partial) 

818 else: 

819 promises = [] 

820 with concurrent.futures.ThreadPoolExecutor() as executor: 

821 for i in range(math.ceil(len(prompts) / max_batch_size)): 

822 logger.debug(f"{i * max_batch_size}:{(i + 1) * max_batch_size}/{len(prompts)}") 

823 future = executor.submit( 

824 _submit_completion, 

825 model_name, 

826 prompts[i * max_batch_size : (i + 1) * max_batch_size], 

827 api_args, 

828 args, 

829 df, 

830 ) 

831 promises.append({"choices": future}) 

832 completion = None 

833 for p in promises: 

834 if not completion: 

835 completion = p["choices"].result() 

836 else: 

837 completion.extend(p["choices"].result()) 

838 

839 return completion 

840 

841 def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame: 

842 """ 

843 Get the metadata or arguments of a model. 

844 

845 Args: 

846 attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'. 

847 

848 Returns: 

849 pd.DataFrame: Model metadata or model arguments. 

850 """ 

851 # TODO: Update to use update() artifacts 

852 

853 args = self.model_storage.json_get("args") 

854 api_key = get_api_key(self.api_key_name, args, self.engine_storage) 

855 if attribute == "args": 

856 return pd.DataFrame(args.items(), columns=["key", "value"]) 

857 elif attribute == "metadata": 

858 model_name = args.get("model_name", self.default_model) 

859 try: 

860 client = self._get_client( 

861 api_key=api_key, 

862 base_url=args.get("api_base"), 

863 org=args.get("api_organization"), 

864 args=args, 

865 ) 

866 meta = client.models.retrieve(model_name) 

867 except Exception as e: 

868 meta = {"error": str(e)} 

869 return pd.DataFrame(dict(meta).items(), columns=["key", "value"]) 

870 else: 

871 tables = ["args", "metadata"] 

872 return pd.DataFrame(tables, columns=["tables"]) 

873 

874 def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: 

875 """ 

876 Fine-tune OpenAI GPT models via a MindsDB model connected to the OpenAI API. 

877 Steps are roughly: 

878 - Analyze input data and modify it according to suggestions made by the OpenAI utility tool 

879 - Get a training and validation file 

880 - Determine base model to use 

881 - Submit a fine-tuning job via the OpenAI API 

882 - Monitor progress with exponential backoff (which has been modified for greater control given a time budget in hours), 

883 - Gather stats once fine-tuning finishes 

884 - Modify model metadata so that the new version triggers the fine-tuned version of the model (stored in the user's OpenAI account) 

885 

886 Caveats: 

887 - As base fine-tuning models, OpenAI only supports the original GPT ones: `ada`, `babbage`, `curie`, `davinci`. This means if you fine-tune successively more than once, any fine-tuning other than the most recent one is lost. 

888 - A bunch of helper methods exist to be overridden in other handlers that follow the OpenAI API, e.g. Anyscale 

889 

890 Args: 

891 df (Optional[pd.DataFrame]): Input data to fine-tune on. 

892 args (Optional[Dict]): Parameters for the fine-tuning process. 

893 

894 Raises: 

895 Exception: If the model does not support fine-tuning. 

896 

897 Returns: 

898 None 

899 """ # noqa 

900 args = args if args else {} 

901 

902 api_key = get_api_key(self.api_key_name, args, self.engine_storage) 

903 

904 using_args = args.pop("using") if "using" in args else {} 

905 

906 api_base = using_args.get("api_base", os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)) 

907 org = using_args.get("api_organization") 

908 client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args) 

909 

910 args = {**using_args, **args} 

911 prev_model_name = self.base_model_storage.json_get("args").get("model_name", "") 

912 

913 if prev_model_name not in self.supported_ft_models: 913 ↛ 923line 913 didn't jump to line 923 because the condition on line 913 was always true

914 # base model may be already FTed, check prefixes 

915 for model in self.supported_ft_models: 

916 if model in prev_model_name: 916 ↛ 917line 916 didn't jump to line 917 because the condition on line 916 was never true

917 break 

918 else: 

919 raise Exception( 

920 f"This model cannot be finetuned. Supported base models are {self.supported_ft_models}." 

921 ) 

922 

923 finetune_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 

924 

925 temp_storage_path = tempfile.mkdtemp() 

926 temp_file_name = f"ft_{finetune_time}" 

927 temp_model_storage_path = f"{temp_storage_path}/{temp_file_name}.jsonl" 

928 

929 file_names = self._prepare_ft_jsonl(df, temp_storage_path, temp_file_name, temp_model_storage_path) 

930 

931 jsons = {k: None for k in file_names.keys()} 

932 for split, file_name in file_names.items(): 

933 if os.path.isfile(os.path.join(temp_storage_path, file_name)): 

934 jsons[split] = client.files.create( 

935 file=open(f"{temp_storage_path}/{file_name}", "rb"), purpose="fine-tune" 

936 ) 

937 

938 if type(jsons["train"]) is openai.types.FileObject: 

939 train_file_id = jsons["train"].id 

940 else: 

941 train_file_id = jsons["base"].id 

942 

943 if type(jsons["val"]) is openai.types.FileObject: 

944 val_file_id = jsons["val"].id 

945 else: 

946 val_file_id = None 

947 

948 # `None` values are internally imputed by OpenAI to `null` or default values 

949 ft_params = { 

950 "training_file": train_file_id, 

951 "validation_file": val_file_id, 

952 "model": self._get_ft_model_type(prev_model_name), 

953 } 

954 ft_params = self._add_extra_ft_params(ft_params, using_args) 

955 

956 start_time = datetime.datetime.now() 

957 

958 ft_stats, result_file_id = self._ft_call(ft_params, client, args.get("hour_budget", 8)) 

959 ft_model_name = ft_stats.fine_tuned_model 

960 

961 end_time = datetime.datetime.now() 

962 runtime = end_time - start_time 

963 name_extension = client.files.retrieve(file_id=result_file_id).filename 

964 result_path = f"{temp_storage_path}/ft_{finetune_time}_result_{name_extension}" 

965 

966 try: 

967 client.files.content(file_id=result_file_id).stream_to_file(result_path) 

968 if ".csv" in name_extension: 

969 # legacy endpoint 

970 train_stats = pd.read_csv(result_path) 

971 if "validation_token_accuracy" in train_stats.columns: 

972 train_stats = train_stats[train_stats["validation_token_accuracy"].notnull()] 

973 args["ft_api_info"] = ft_stats.dict() 

974 args["ft_result_stats"] = train_stats.to_dict() 

975 

976 elif ".json" in name_extension: 

977 train_stats = pd.read_json(path_or_buf=result_path, lines=True) # new endpoint 

978 args["ft_api_info"] = args["ft_result_stats"] = train_stats.to_dict() 

979 

980 except Exception: 

981 logger.info( 

982 f"Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id})." 

983 ) 

984 

985 args["model_name"] = ft_model_name 

986 args["runtime"] = runtime.total_seconds() 

987 args["mode"] = self.base_model_storage.json_get("args").get("mode", self.default_mode) 

988 

989 self.model_storage.json_set("args", args) 

990 shutil.rmtree(temp_storage_path) 

991 

992 @staticmethod 

993 def _prepare_ft_jsonl(df: pd.DataFrame, _, temp_filename: Text, temp_model_path: Text) -> Dict: 

994 """ 

995 Prepare the input data for fine-tuning. 

996 

997 Args: 

998 df (pd.DataFrame): Input data to fine-tune on. 

999 temp_filename (Text): Temporary filename. 

1000 temp_model_path (Text): Temporary model path. 

1001 

1002 Returns: 

1003 Dict: File names for the fine-tuning process. 

1004 """ 

1005 df.to_json(temp_model_path, orient="records", lines=True) 

1006 

1007 # TODO avoid subprocess usage once OpenAI enables non-CLI access, or refactor to use our own LLM utils instead 

1008 subprocess.run( 

1009 [ 

1010 "openai", 

1011 "tools", 

1012 "fine_tunes.prepare_data", 

1013 "-f", 

1014 temp_model_path, # from file 

1015 "-q", # quiet mode (accepts all suggestions) 

1016 ], 

1017 stdout=subprocess.PIPE, 

1018 stderr=subprocess.PIPE, 

1019 encoding="utf-8", 

1020 ) 

1021 

1022 file_names = { 

1023 "original": f"{temp_filename}.jsonl", 

1024 "base": f"{temp_filename}_prepared.jsonl", 

1025 "train": f"{temp_filename}_prepared_train.jsonl", 

1026 "val": f"{temp_filename}_prepared_valid.jsonl", 

1027 } 

1028 return file_names 

1029 

1030 def _get_ft_model_type(self, model_name: Text) -> Text: 

1031 """ 

1032 Get the model to use for fine-tuning. If the model is not supported, the default model (babbage-002) is used. 

1033 

1034 Args: 

1035 model_name (Text): Model name. 

1036 

1037 Returns: 

1038 Text: Model to use for fine-tuning. 

1039 """ 

1040 for model_type in self.supported_ft_models: 

1041 if model_type in model_name.lower(): 

1042 return model_type 

1043 return "babbage-002" 

1044 

1045 @staticmethod 

1046 def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict: 

1047 """ 

1048 Add extra parameters to the fine-tuning process. 

1049 

1050 Args: 

1051 ft_params (Dict): Parameters for the fine-tuning process required by the OpenAI API. 

1052 using_args (Dict): Parameters passed when calling the fine-tuning process via a model. 

1053 

1054 Returns: 

1055 Dict: Fine-tuning parameters with extra parameters. 

1056 """ 

1057 extra_params = { 

1058 "n_epochs": using_args.get("n_epochs", None), 

1059 "batch_size": using_args.get("batch_size", None), 

1060 "learning_rate_multiplier": using_args.get("learning_rate_multiplier", None), 

1061 "prompt_loss_weight": using_args.get("prompt_loss_weight", None), 

1062 "compute_classification_metrics": using_args.get("compute_classification_metrics", None), 

1063 "classification_n_classes": using_args.get("classification_n_classes", None), 

1064 "classification_positive_class": using_args.get("classification_positive_class", None), 

1065 "classification_betas": using_args.get("classification_betas", None), 

1066 } 

1067 return {**ft_params, **extra_params} 

1068 

    def _ft_call(self, ft_params: Dict, client: OpenAI, hour_budget: int) -> Tuple[FineTuningJob, Text]:
        """
        Submit a fine-tuning job via the OpenAI API.
        This method handles requests to both the legacy and new endpoints.
        Currently, `OpenAIHandler` uses the legacy endpoint. Others, like `AnyscaleEndpointsHandler`, use the new endpoint.

        This method consists of an inner method:
        - _check_ft_status: Check the status of a fine-tuning job via the OpenAI API.

        Args:
            ft_params (Dict): Fine-tuning parameters.
            client (openai.OpenAI): OpenAI client.
            hour_budget (int): Hour budget for the fine-tuning process.

        Raises:
            PendingFT: If the fine-tuning process is still pending.

        Returns:
            Tuple[FineTuningJob, Text]: Fine-tuning stats and result file ID.
        """
        # Drop `None`-valued params; the API imputes its own defaults for them.
        ft_result = client.fine_tuning.jobs.create(**{k: v for k, v in ft_params.items() if v is not None})

        @retry_with_exponential_backoff(
            hour_budget=hour_budget,
        )
        def _check_ft_status(job_id: Text) -> FineTuningJob:
            """
            Check the status of a fine-tuning job via the OpenAI API.

            Args:
                job_id (Text): Fine-tuning job ID.

            Raises:
                PendingFT: If the fine-tuning process is still pending.

            Returns:
                FineTuningJob: Fine-tuning stats.
            """
            ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
            # Only terminal states return; anything else raises PendingFT, which
            # the backoff decorator catches to schedule the next poll within the
            # hour budget.
            if ft_retrieved.status in ("succeeded", "failed", "cancelled"):
                return ft_retrieved
            else:
                raise PendingFT("Fine-tuning still pending!")

        # Blocks (with exponential backoff) until the job reaches a terminal state.
        ft_stats = _check_ft_status(ft_result.id)

        if ft_stats.status != "succeeded":
            # Terminal but unsuccessful (failed/cancelled): surface the last
            # event message if the endpoint provides one (legacy responses do).
            err_message = ft_stats.events[-1].message if hasattr(ft_stats, "events") else "could not retrieve!"
            ft_status = ft_stats.status if hasattr(ft_stats, "status") else "N/A"
            raise Exception(
                f"Fine-tuning did not complete successfully (status: {ft_status}). Error message: {err_message}"
            )  # noqa

        # Result files may be plain IDs (new endpoint) or objects with an `.id`
        # attribute (legacy endpoint); normalize to the string ID.
        result_file_id = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=ft_result.id).result_files[0]
        if hasattr(result_file_id, "id"):
            result_file_id = result_file_id.id  # legacy endpoint

        return ft_stats, result_file_id

1127 

1128 @staticmethod 

1129 def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None, args: dict = None) -> OpenAI: 

1130 """ 

1131 Get an OpenAI client with the given API key, base URL, and organization. 

1132 

1133 Args: 

1134 api_key (Text): OpenAI API key. 

1135 base_url (Text): OpenAI base URL. 

1136 org (Optional[Text]): OpenAI organization. 

1137 

1138 Returns: 

1139 openai.OpenAI: OpenAI client. 

1140 """ 

1141 if args is not None and args.get("provider") == "azure": 1141 ↛ 1142line 1141 didn't jump to line 1142 because the condition on line 1141 was never true

1142 return AzureOpenAI( 

1143 api_key=api_key, azure_endpoint=base_url, api_version=args.get("api_version"), organization=org 

1144 ) 

1145 return OpenAI(api_key=api_key, base_url=base_url, organization=org)