Coverage for mindsdb / integrations / handlers / openai_handler / openai_handler.py: 61%
456 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import math
3import json
4import shutil
5import tempfile
6import datetime
7import textwrap
8import subprocess
9from enum import Enum
10import concurrent.futures
11from typing import Text, Tuple, Dict, List, Optional, Any
12import openai
13from openai.types.fine_tuning import FineTuningJob
14from openai import OpenAI, AzureOpenAI, NotFoundError, AuthenticationError
15import numpy as np
16import pandas as pd
18from mindsdb.utilities.hooks import before_openai_query, after_openai_query
19from mindsdb.utilities import log
20from mindsdb.integrations.libs.base import BaseMLEngine
21from mindsdb.integrations.handlers.openai_handler.helpers import (
22 retry_with_exponential_backoff,
23 truncate_msgs_for_token_limit,
24 get_available_models,
25 PendingFT,
26)
27from mindsdb.integrations.handlers.openai_handler.constants import (
28 CHAT_MODELS_PREFIXES,
29 IMAGE_MODELS,
30 FINETUNING_MODELS,
31 OPENAI_API_BASE,
32 DEFAULT_CHAT_MODEL,
33 DEFAULT_EMBEDDING_MODEL,
34 DEFAULT_IMAGE_MODEL,
35)
36from mindsdb.integrations.libs.llm.utils import get_completed_prompts
37from mindsdb.integrations.utilities.handler_utils import get_api_key
39logger = log.getLogger(__name__)
class Mode(Enum):
    """Operation modes supported by the OpenAI handler.

    The string values are what users supply via the `mode` argument;
    `conversational-full` deliberately uses a dash, so `Mode(value)` lookups
    must go through the value, not the attribute name.
    """

    default = "default"
    conversational = "conversational"
    conversational_full = "conversational-full"
    image = "image"
    embedding = "embedding"
    legacy = "legacy"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum machinery when `Mode(value)` finds no member: fail
        # loudly with the list of valid mode names instead of returning None.
        valid_names = [member.name for member in cls]
        raise ValueError(f"Invalid operation mode '{value}'. Please use one of: {valid_names}")
55class OpenAIHandler(BaseMLEngine):
56 """
57 This handler handles connection and inference with the OpenAI API.
58 """
60 name = "openai"
62 def __init__(self, *args, **kwargs):
63 super().__init__(*args, **kwargs)
64 self.generative = True
65 self.default_model = DEFAULT_CHAT_MODEL
66 self.default_embedding_model = DEFAULT_EMBEDDING_MODEL
67 self.default_image_model = DEFAULT_IMAGE_MODEL
68 self.default_mode = Mode.default # can also be 'conversational' or 'conversational-full'
69 self.rate_limit = 60 # requests per minute
70 self.max_batch_size = 20
71 self.default_max_tokens = 100
72 self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning
73 # For now this are only used for handlers that inherits OpenAIHandler and don't need to override base methods
74 self.api_key_name = getattr(self, "api_key_name", self.name)
75 self.api_base = getattr(self, "api_base", OPENAI_API_BASE)
77 def create_engine(self, connection_args: Dict) -> None:
78 """
79 Validate the OpenAI API credentials on engine creation.
81 Args:
82 connection_args (Dict): Parameters for the engine.
84 Raises:
85 Exception: If the handler is not configured with valid API credentials.
87 Returns:
88 None
89 """
90 connection_args = {k.lower(): v for k, v in connection_args.items()}
91 api_key = connection_args.get("openai_api_key")
92 if api_key is not None:
93 org = connection_args.get("api_organization")
94 api_base = connection_args.get("api_base") or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
95 client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=connection_args)
96 OpenAIHandler._check_client_connection(client)
98 @staticmethod
99 def is_chat_model(model_name):
100 for prefix in CHAT_MODELS_PREFIXES:
101 if model_name.startswith(prefix):
102 return True
103 return False
105 @staticmethod
106 def _check_client_connection(client: OpenAI) -> None:
107 """
108 Check the OpenAI engine client connection by retrieving a model.
110 Args:
111 client (openai.OpenAI): OpenAI client configured with the API credentials.
113 Raises:
114 Exception: If the client connection (API key) is invalid or something else goes wrong.
116 Returns:
117 None
118 """
119 try:
120 client.models.retrieve("test")
121 except NotFoundError:
122 pass
123 except AuthenticationError as e:
124 if isinstance(e.body, dict) and e.body.get("code") == "invalid_api_key": 124 ↛ 126line 124 didn't jump to line 126 because the condition on line 124 was always true
125 raise Exception("Invalid api key")
126 raise Exception(f"Something went wrong: {e}")
128 @staticmethod
129 def create_validation(target: Text, args: Dict = None, **kwargs: Any) -> None:
130 """
131 Validate the OpenAI API credentials on model creation.
133 Args:
134 target (Text): Target column name.
135 args (Dict): Parameters for the model.
136 kwargs (Any): Other keyword arguments.
138 Raises:
139 Exception: If the handler is not configured with valid API credentials.
141 Returns:
142 None
143 """
144 if "using" not in args:
145 raise Exception("OpenAI engine requires a USING clause! Refer to its documentation for more details.")
146 else:
147 args = args["using"]
149 if len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) == 0:
150 raise Exception("One of `question_column`, `prompt_template` or `prompt` is required for this engine.")
152 keys_collection = [
153 ["prompt_template"],
154 ["question_column", "context_column"],
155 ["prompt", "user_column", "assistant_column"],
156 ]
157 for keys in keys_collection:
158 if keys[0] in args and any(x[0] in args for x in keys_collection if x != keys):
159 raise Exception(
160 textwrap.dedent(
161 """\
162 Please provide one of
163 1) a `prompt_template`
164 2) a `question_column` and an optional `context_column`
165 3) a `prompt', 'user_column' and 'assistant_column`
166 """
167 )
168 )
170 # for all args that are not expected, raise an error
171 known_args = set()
172 # flatten of keys_collection
173 for keys in keys_collection:
174 known_args = known_args.union(set(keys))
176 # TODO: need a systematic way to maintain a list of known args
177 known_args = known_args.union(
178 {
179 "target",
180 "model_name",
181 "mode",
182 "predict_params",
183 "json_struct",
184 "ft_api_info",
185 "ft_result_stats",
186 "runtime",
187 "max_tokens",
188 "temperature",
189 "openai_api_key",
190 "api_organization",
191 "api_base",
192 "api_version",
193 "provider",
194 }
195 )
197 unknown_args = set(args.keys()) - known_args
198 if unknown_args:
199 # return a list of unknown args as a string
200 raise Exception(
201 f"Unknown arguments: {', '.join(unknown_args)}.\n Known arguments are: {', '.join(known_args)}"
202 )
204 engine_storage = kwargs["handler_storage"]
205 connection_args = engine_storage.get_connection_args()
206 api_key = get_api_key("openai", args, engine_storage=engine_storage)
207 api_base = (
208 args.get("api_base")
209 or connection_args.get("api_base")
210 or os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE)
211 )
212 org = args.get("api_organization")
213 client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
214 OpenAIHandler._check_client_connection(client)
216 def create(self, target, args: Dict = None, **kwargs: Any) -> None:
217 """
218 Create a model by connecting to the OpenAI API.
220 Args:
221 target (Text): Target column name.
222 args (Dict): Parameters for the model.
223 kwargs (Any): Other keyword arguments.
225 Raises:
226 Exception: If the model is not configured with valid parameters.
228 Returns:
229 None
230 """
231 args = args["using"]
232 args["target"] = target
233 try:
234 api_key = get_api_key(self.api_key_name, args, self.engine_storage)
235 connection_args = self.engine_storage.get_connection_args()
236 api_base = (
237 args.get("api_base")
238 or connection_args.get("api_base")
239 or os.environ.get("OPENAI_API_BASE")
240 or self.api_base
241 )
242 client = self._get_client(api_key=api_key, base_url=api_base, org=args.get("api_organization"), args=args)
243 available_models = get_available_models(client)
245 mode = args.get("mode")
246 if mode is not None:
247 mode = Mode(mode)
248 else:
249 mode = self.default_mode
251 if not args.get("model_name"):
252 if mode is Mode.embedding: 252 ↛ 253line 252 didn't jump to line 253 because the condition on line 252 was never true
253 args["model_name"] = self.default_embedding_model
254 elif mode is Mode.image: 254 ↛ 255line 254 didn't jump to line 255 because the condition on line 254 was never true
255 args["model_name"] = self.default_image_model
256 else:
257 args["model_name"] = self.default_model
258 elif (args["model_name"] not in available_models) and (mode is not Mode.embedding): 258 ↛ 261line 258 didn't jump to line 261 because the condition on line 258 was always true
259 raise Exception(f"Invalid model name. Please use one of {available_models}")
260 finally:
261 self.model_storage.json_set("args", args)
    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Make predictions using a model connected to the OpenAI API.

        Dispatches on the resolved operation mode (embedding / image / chat-completion),
        builds one prompt per input row, runs completion, and returns a DataFrame with
        a single target column.

        Args:
            df (pd.DataFrame): Input data to make predictions on.
            args (Dict): Parameters passed when making predictions.

        Raises:
            Exception: If the model is not configured with valid parameters or if the input data is not in the expected format.

        Returns:
            pd.DataFrame: Input data with the predicted values in a new column.
        """  # noqa
        # TODO: support for edits, embeddings and moderation

        # Predict-time params override the stored (create-time) args loaded below.
        pred_args = args["predict_params"] if args else {}
        args = self.model_storage.json_get("args")
        connection_args = self.engine_storage.get_connection_args()

        # Resolution order: predict args > stored args > engine args > env > handler default.
        args["api_base"] = (
            pred_args.get("api_base")
            or args.get("api_base")
            or connection_args.get("api_base")
            or os.environ.get("OPENAI_API_BASE")
            or self.api_base
        )

        if pred_args.get("api_organization"):
            args["api_organization"] = pred_args["api_organization"]
        # Positional row indices are used below (empty_prompt_ids), so reset the index.
        df = df.reset_index(drop=True)

        if pred_args.get("mode"):
            mode = Mode(pred_args["mode"])
            args["mode"] = mode.value
        elif args.get("mode"):
            mode = Mode(args["mode"])
        else:
            # self.default_mode is already a Mode member; Mode(member) returns it unchanged.
            mode = Mode(self.default_mode)

        # A predict-time template is applied non-strictly (missing columns tolerated).
        strict_prompt_template = True
        if pred_args.get("prompt_template", False):
            base_template = pred_args["prompt_template"]  # override with predict-time template if available
            strict_prompt_template = False
        elif args.get("prompt_template", False):
            base_template = args["prompt_template"]
        else:
            base_template = None

        # Embedding mode
        if mode is Mode.embedding:
            api_args = {
                "question_column": pred_args.get("question_column", None),
                "model": pred_args.get("model_name") or args.get("model_name"),
            }
            # Sentinel name: _completion routes "embedding" to the embeddings endpoint.
            model_name = "embedding"
            if args.get("question_column"):
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
                # Rows whose question cell is NaN get no API call and a None result.
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
            else:
                raise Exception("Embedding mode needs a question_column")

        # Image mode
        elif mode is Mode.image:
            api_args = {
                "n": pred_args.get("n", None),
                "size": pred_args.get("size", None),
                "response_format": pred_args.get("response_format", None),
            }
            api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
            model_name = pred_args.get("model_name") or args.get("model_name")

            if args.get("question_column"):
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
            elif args.get("prompt_template"):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df)
            else:
                raise Exception("Image mode needs either `prompt_template` or `question_column`.")

        # Chat or normal completion mode
        else:
            if args.get("question_column", False) and args["question_column"] not in df.columns:
                raise Exception(f"This model expects a question to answer in the '{args['question_column']}' column.")

            if args.get("context_column", False) and args["context_column"] not in df.columns:
                raise Exception(f"This model expects context in the '{args['context_column']}' column.")

            # API argument validation
            model_name = args.get("model_name", self.default_model)
            api_args = {
                "max_tokens": pred_args.get("max_tokens", args.get("max_tokens", self.default_max_tokens)),
                # Temperature is clamped into [0.0, 1.0].
                "temperature": min(
                    1.0,
                    max(0.0, pred_args.get("temperature", args.get("temperature", 0.0))),
                ),
                "top_p": pred_args.get("top_p", None),
                "n": pred_args.get("n", None),
                "stop": pred_args.get("stop", None),
                "presence_penalty": pred_args.get("presence_penalty", None),
                "frequency_penalty": pred_args.get("frequency_penalty", None),
                "best_of": pred_args.get("best_of", None),
                "logit_bias": pred_args.get("logit_bias", None),
                "user": pred_args.get("user", None),
            }

            if args.get("prompt_template", False):
                prompts, empty_prompt_ids = get_completed_prompts(base_template, df, strict=strict_prompt_template)

            elif args.get("context_column", False):
                empty_prompt_ids = np.where(
                    df[[args["context_column"], args["question_column"]]].isna().all(axis=1).values
                )[0]
                contexts = list(df[args["context_column"]].apply(lambda x: str(x)))
                questions = list(df[args["question_column"]].apply(lambda x: str(x)))
                prompts = [f"Context: {c}\nQuestion: {q}\nAnswer: " for c, q in zip(contexts, questions)]

            elif "prompt" in args:
                # Conversational mode: user turns come straight from the user column.
                empty_prompt_ids = []
                prompts = list(df[args["user_column"]])
            else:
                empty_prompt_ids = np.where(df[[args["question_column"]]].isna().all(axis=1).values)[0]
                prompts = list(df[args["question_column"]].apply(lambda x: str(x)))

        # add json struct if available
        if args.get("json_struct", False):
            for i, prompt in enumerate(prompts):
                json_struct = ""
                if "json_struct" in df.columns and i not in empty_prompt_ids:
                    # if row has a specific json, we try to use it instead of the base prompt template
                    try:
                        if isinstance(df["json_struct"][i], str):
                            # NOTE(review): chained assignment mutates the caller's df
                            # in place and may warn/no-op under pandas copy-on-write.
                            df["json_struct"][i] = json.loads(df["json_struct"][i])
                        for ind, val in enumerate(df["json_struct"][i].values()):
                            # NOTE(review): per-row enumeration starts at 0, while the
                            # fallback below starts at 1 — presumably intentional? verify.
                            json_struct = json_struct + f"{ind}. {val}\n"
                    except Exception:
                        pass  # if the row's json is invalid, we use the prompt template instead

                if json_struct == "":
                    for ind, val in enumerate(args["json_struct"].values()):
                        json_struct = json_struct + f"{ind + 1}. {val}\n"

                p = textwrap.dedent(
                    f"""\
                    Based on the text following 'The reference text is:', assign values to the following {len(args["json_struct"])} JSON attributes:
                    {{{{json_struct}}}}

                    Values should follow the same order as the attributes above.
                    Each line in the answer should start with a dotted number, and should not repeat the name of the attribute, just the value.
                    Each answer must end with new line.
                    If there is no valid value to a given attribute in the text, answer with a - character.
                    Values should be as short as possible, ideally 1-2 words (unless otherwise specified).

                    Here is an example input of 3 attributes:
                    1. rental price
                    2. location
                    3. number of bathrooms

                    Here is an example output for the input:
                    1. 3000
                    2. Manhattan
                    3. 2

                    Now for the real task. The reference text is:
                    {prompt}
                    """
                )

                # Placeholder substituted after dedent so attribute text keeps its own formatting.
                p = p.replace("{{json_struct}}", json_struct)
                prompts[i] = p

        # remove prompts without signal from completion queue
        prompts = [j for i, j in enumerate(prompts) if i not in empty_prompt_ids]

        api_key = get_api_key(self.api_key_name, args, self.engine_storage)
        api_args = {k: v for k, v in api_args.items() if v is not None}  # filter out non-specified api args
        completion = self._completion(model_name, prompts, api_key, api_args, args, df)

        # add null completion for empty prompts
        for i in sorted(empty_prompt_ids):
            completion.insert(i, None)

        pred_df = pd.DataFrame(completion, columns=[args["target"]])

        # restore json struct
        if args.get("json_struct", False):
            for i in pred_df.index:
                try:
                    if "json_struct" in df.columns:
                        json_keys = df["json_struct"][i].keys()
                    else:
                        json_keys = args["json_struct"].keys()
                    responses = pred_df[args["target"]][i].split("\n")
                    responses = [x[3:] for x in responses]  # del question index

                    # NOTE(review): chained assignment on pred_df — works on a fresh
                    # DataFrame today but is fragile under pandas copy-on-write.
                    pred_df[args["target"]][i] = {key: val for key, val in zip(json_keys, responses)}
                except Exception:
                    pred_df[args["target"]][i] = None

        return pred_df
464 def _completion(
465 self,
466 model_name: Text,
467 prompts: List[Text],
468 api_key: Text,
469 api_args: Dict,
470 args: Dict,
471 df: pd.DataFrame,
472 parallel: bool = True,
473 ) -> List[Any]:
474 """
475 Handles completion for an arbitrary amount of rows using a model connected to the OpenAI API.
477 This method consists of several inner methods:
478 - _submit_completion: Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task.
479 - _submit_normal_completion: Submit a request to the completion endpoint of the OpenAI API.
480 - _submit_embedding_completion: Submit a request to the embeddings endpoint of the OpenAI API.
481 - _submit_chat_completion: Submit a request to the chat completion endpoint of the OpenAI API.
482 - _submit_image_completion: Submit a request to the image completion endpoint of the OpenAI API.
483 - _log_api_call: Log the API call made to the OpenAI API.
485 There are a couple checks that should be done when calling OpenAI's API:
486 - account max batch size, to maximize batch size first
487 - account rate limit, to maximize parallel calls second
489 Additionally, single completion calls are done with exponential backoff to guarantee all prompts are processed,
490 because even with previous checks the tokens-per-minute limit may apply.
492 Args:
493 model_name (Text): OpenAI Model name.
494 prompts (List[Text]): List of prompts.
495 api_key (Text): OpenAI API key.
496 api_args (Dict): OpenAI API arguments.
497 args (Dict): Parameters for the model.
498 df (pd.DataFrame): Input data to run completion on.
499 parallel (bool): Whether to use parallel processing.
501 Returns:
502 List[Any]: List of completions. The type of completion depends on the task type.
503 """
505 @retry_with_exponential_backoff()
506 def _submit_completion(
507 model_name: Text, prompts: List[Text], api_args: Dict, args: Dict, df: pd.DataFrame
508 ) -> List[Text]:
509 """
510 Submit a request to the relevant completion endpoint of the OpenAI API based on the type of task.
512 Args:
513 model_name (Text): OpenAI Model name.
514 prompts (List[Text]): List of prompts.
515 api_args (Dict): OpenAI API arguments.
516 args (Dict): Parameters for the model.
517 df (pd.DataFrame): Input data to run completion on.
519 Returns:
520 List[Text]: List of completions.
521 """
522 kwargs = {
523 "model": model_name,
524 }
525 try:
526 mode = Mode(args.get("mode"))
527 except ValueError:
528 if model_name in IMAGE_MODELS:
529 mode = Mode.image
530 elif model_name == "embedding":
531 mode = Mode.embedding
532 elif self.is_chat_model(model_name) and model_name != "gpt-3.5-turbo-instruct":
533 mode = Mode.conversational
534 elif model_name == "gpt-3.5-turbo-instruct":
535 mode = Mode.legacy
536 else:
537 mode = Mode.default
539 match mode:
540 case Mode.image:
541 return _submit_image_completion(kwargs, prompts, api_args)
542 case Mode.embedding:
543 return _submit_embedding_completion(kwargs, prompts, api_args)
544 case Mode.conversational | Mode.conversational_full | Mode.default:
545 return _submit_chat_completion(
546 kwargs,
547 prompts,
548 api_args,
549 df,
550 mode=args.get("mode", "conversational"),
551 )
552 case Mode.legacy: 552 ↛ exitline 552 didn't return from function '_submit_completion' because the pattern on line 552 always matched
553 return _submit_normal_completion(kwargs, prompts, api_args)
555 def _log_api_call(params: Dict, response: Any) -> None:
556 """
557 Log the API call made to the OpenAI API.
559 Args:
560 params (Dict): Parameters for the API call.
561 response (Any): Response from the API.
563 Returns:
564 None
565 """
566 after_openai_query(params, response)
568 params2 = params.copy()
569 params2.pop("api_key", None)
570 params2.pop("user", None)
571 logger.debug(f">>>openai call: {params2}:\n{response}")
573 def _submit_normal_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]:
574 """
575 Submit a request to the completion endpoint of the OpenAI API.
577 This method consists of an inner method:
578 - _tidy: Parse and tidy up the response from the completion endpoint of the OpenAI API.
580 Args:
581 kwargs (Dict): OpenAI API arguments, including the model to use.
582 prompts (List[Text]): List of prompts.
583 api_args (Dict): Other OpenAI API arguments.
585 Returns:
586 List[Text]: List of text completions.
587 """
589 def _tidy(comp: openai.types.completion.Completion) -> List[Text]:
590 """
591 Parse and tidy up the response from the completion endpoint of the OpenAI API.
593 Args:
594 comp (openai.types.completion.Completion): Completion object.
596 Returns:
597 List[Text]: List of completions as text.
598 """
599 tidy_comps = []
600 for c in comp.choices:
601 if hasattr(c, "text"): 601 ↛ 600line 601 didn't jump to line 600 because the condition on line 601 was always true
602 tidy_comps.append(c.text.strip("\n").strip(""))
603 return tidy_comps
605 kwargs = {**kwargs, **api_args}
607 before_openai_query(kwargs)
608 responses = []
609 for prompt in prompts:
610 responses.extend(_tidy(client.completions.create(prompt=prompt, **kwargs)))
611 _log_api_call(kwargs, responses)
612 return responses
614 def _submit_embedding_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[float]:
615 """
616 Submit a request to the embeddings endpoint of the OpenAI API.
618 This method consists of an inner method:
619 - _tidy: Parse and tidy up the response from the embeddings endpoint of the OpenAI API.
621 Args:
622 kwargs (Dict): OpenAI API arguments, including the model to use.
623 prompts (List[Text]): List of prompts.
624 api_args (Dict): Other OpenAI API arguments.
626 Returns:
627 List[float]: List of embeddings as numbers.
628 """
630 def _tidy(comp: openai.types.create_embedding_response.CreateEmbeddingResponse) -> List[float]:
631 """
632 Parse and tidy up the response from the embeddings endpoint of the OpenAI API.
634 Args:
635 comp (openai.types.create_embedding_response.CreateEmbeddingResponse): Embedding object.
637 Returns:
638 List[float]: List of embeddings as numbers.
639 """
640 tidy_comps = []
641 for c in comp.data:
642 if hasattr(c, "embedding"): 642 ↛ 641line 642 didn't jump to line 641 because the condition on line 642 was always true
643 tidy_comps.append([c.embedding])
644 return tidy_comps
646 kwargs["input"] = prompts
647 kwargs = {**kwargs, **api_args}
649 before_openai_query(kwargs)
650 resp = _tidy(client.embeddings.create(**kwargs))
651 _log_api_call(kwargs, resp)
652 return resp
654 def _submit_chat_completion(
655 kwargs: Dict, prompts: List[Text], api_args: Dict, df: pd.DataFrame, mode: Text = "conversational"
656 ) -> List[Text]:
657 """
658 Submit a request to the chat completion endpoint of the OpenAI API.
660 This method consists of an inner method:
661 - _tidy: Parse and tidy up the response from the chat completion endpoint of the OpenAI API.
663 Args:
664 kwargs (Dict): OpenAI API arguments, including the model to use.
665 prompts (List[Text]): List of prompts.
666 api_args (Dict): Other OpenAI API arguments.
667 df (pd.DataFrame): Input data to run chat completion on.
668 mode (Text): Mode of operation.
670 Returns:
671 List[Text]: List of chat completions as text.
672 """
674 def _tidy(comp: openai.types.chat.chat_completion.ChatCompletion) -> List[Text]:
675 """
676 Parse and tidy up the response from the chat completion endpoint of the OpenAI API.
678 Args:
679 comp (openai.types.chat.chat_completion.ChatCompletion): Chat completion object.
681 Returns:
682 List[Text]: List of chat completions as text.
683 """
684 tidy_comps = []
685 for c in comp.choices:
686 if hasattr(c, "message"): 686 ↛ 685line 686 didn't jump to line 685 because the condition on line 686 was always true
687 tidy_comps.append(c.message.content.strip("\n").strip(""))
688 return tidy_comps
690 mode = Mode(mode)
691 completions = []
692 if mode is not Mode.conversational or "prompt" not in args:
693 initial_prompt = {
694 "role": "system",
695 "content": "You are a helpful assistant. Your task is to continue the chat.",
696 } # noqa
697 else:
698 # get prompt from model
699 initial_prompt = {"role": "system", "content": args["prompt"]} # noqa
701 kwargs["messages"] = [initial_prompt]
702 last_completion_content = None
704 for pidx in range(len(prompts)):
705 if mode is not Mode.conversational:
706 kwargs["messages"].append({"role": "user", "content": prompts[pidx]})
707 else:
708 question = prompts[pidx]
709 if question: 709 ↛ 712line 709 didn't jump to line 712 because the condition on line 709 was always true
710 kwargs["messages"].append({"role": "user", "content": question})
712 assistant_column = args.get("assistant_column")
713 if assistant_column in df.columns: 713 ↛ 714line 713 didn't jump to line 714 because the condition on line 713 was never true
714 answer = df.iloc[pidx][assistant_column]
715 else:
716 answer = None
717 if answer: 717 ↛ 718line 717 didn't jump to line 718 because the condition on line 717 was never true
718 kwargs["messages"].append({"role": "assistant", "content": answer})
720 if mode is Mode.conversational_full or (mode is Mode.conversational and pidx == len(prompts) - 1):
721 kwargs["messages"] = truncate_msgs_for_token_limit(
722 kwargs["messages"], kwargs["model"], api_args["max_tokens"]
723 )
724 pkwargs = {**kwargs, **api_args}
726 before_openai_query(kwargs)
727 resp = _tidy(client.chat.completions.create(**pkwargs))
728 _log_api_call(pkwargs, resp)
730 completions.extend(resp)
731 elif mode is Mode.default:
732 kwargs["messages"] = [initial_prompt] + [kwargs["messages"][-1]]
733 pkwargs = {**kwargs, **api_args}
735 before_openai_query(kwargs)
736 resp = _tidy(client.chat.completions.create(**pkwargs))
737 _log_api_call(pkwargs, resp)
739 completions.extend(resp)
740 else:
741 # in "normal" conversational mode, we request completions only for the last row
742 last_completion_content = None
743 completions.extend([""])
745 if last_completion_content: 745 ↛ 747line 745 didn't jump to line 747 because the condition on line 745 was never true
746 # interleave assistant responses with user input
747 kwargs["messages"].append({"role": "assistant", "content": last_completion_content[0]})
749 return completions
751 def _submit_image_completion(kwargs: Dict, prompts: List[Text], api_args: Dict) -> List[Text]:
752 """
753 Submit a request to the image generation endpoint of the OpenAI API.
755 This method consists of an inner method:
756 - _tidy: Parse and tidy up the response from the image generation endpoint of the OpenAI API.
758 Args:
759 kwargs (Dict): OpenAI API arguments, including the model to use.
760 prompts (List[Text]): List of prompts.
761 api_args (Dict): Other OpenAI API arguments.
763 Raises:
764 Exception: If the maximum batch size is reached.
766 Returns:
767 List[Text]: List of image completions as URLs or base64 encoded images.
768 """
770 def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
771 """
772 Parse and tidy up the response from the image generation endpoint of the OpenAI API.
774 Args:
775 comp (List[openai.types.image.Image]): Image completion objects.
777 Returns:
778 List[Text]: List of image completions as URLs or base64 encoded images.
779 """
780 return [c.url if hasattr(c, "url") else c.b64_json for c in comp]
782 completions = [client.images.generate(**{"prompt": p, **kwargs, **api_args}).data[0] for p in prompts]
783 return _tidy(completions)
785 client = self._get_client(
786 api_key=api_key,
787 base_url=args.get("api_base"),
788 org=args.pop("api_organization") if "api_organization" in args else None,
789 args=args,
790 )
792 try:
793 # check if simple completion works
794 completion = _submit_completion(model_name, prompts, api_args, args, df)
795 return completion
796 except Exception as e:
797 # else, we get the max batch size
798 if "you can currently request up to at most a total of" in str(e):
799 pattern = "a total of"
800 max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(").")[0])
801 else:
802 max_batch_size = self.max_batch_size # guards against changes in the API message
804 if not parallel:
805 completion = None
806 for i in range(math.ceil(len(prompts) / max_batch_size)):
807 partial = _submit_completion(
808 model_name,
809 prompts[i * max_batch_size : (i + 1) * max_batch_size],
810 api_args,
811 args,
812 df,
813 )
814 if not completion:
815 completion = partial
816 else:
817 completion.extend(partial)
818 else:
819 promises = []
820 with concurrent.futures.ThreadPoolExecutor() as executor:
821 for i in range(math.ceil(len(prompts) / max_batch_size)):
822 logger.debug(f"{i * max_batch_size}:{(i + 1) * max_batch_size}/{len(prompts)}")
823 future = executor.submit(
824 _submit_completion,
825 model_name,
826 prompts[i * max_batch_size : (i + 1) * max_batch_size],
827 api_args,
828 args,
829 df,
830 )
831 promises.append({"choices": future})
832 completion = None
833 for p in promises:
834 if not completion:
835 completion = p["choices"].result()
836 else:
837 completion.extend(p["choices"].result())
839 return completion
841 def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
842 """
843 Get the metadata or arguments of a model.
845 Args:
846 attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'.
848 Returns:
849 pd.DataFrame: Model metadata or model arguments.
850 """
851 # TODO: Update to use update() artifacts
853 args = self.model_storage.json_get("args")
854 api_key = get_api_key(self.api_key_name, args, self.engine_storage)
855 if attribute == "args":
856 return pd.DataFrame(args.items(), columns=["key", "value"])
857 elif attribute == "metadata":
858 model_name = args.get("model_name", self.default_model)
859 try:
860 client = self._get_client(
861 api_key=api_key,
862 base_url=args.get("api_base"),
863 org=args.get("api_organization"),
864 args=args,
865 )
866 meta = client.models.retrieve(model_name)
867 except Exception as e:
868 meta = {"error": str(e)}
869 return pd.DataFrame(dict(meta).items(), columns=["key", "value"])
870 else:
871 tables = ["args", "metadata"]
872 return pd.DataFrame(tables, columns=["tables"])
874 def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
875 """
876 Fine-tune OpenAI GPT models via a MindsDB model connected to the OpenAI API.
877 Steps are roughly:
878 - Analyze input data and modify it according to suggestions made by the OpenAI utility tool
879 - Get a training and validation file
880 - Determine base model to use
881 - Submit a fine-tuning job via the OpenAI API
882 - Monitor progress with exponential backoff (which has been modified for greater control given a time budget in hours),
883 - Gather stats once fine-tuning finishes
884 - Modify model metadata so that the new version triggers the fine-tuned version of the model (stored in the user's OpenAI account)
886 Caveats:
887 - As base fine-tuning models, OpenAI only supports the original GPT ones: `ada`, `babbage`, `curie`, `davinci`. This means if you fine-tune successively more than once, any fine-tuning other than the most recent one is lost.
888 - A bunch of helper methods exist to be overridden in other handlers that follow the OpenAI API, e.g. Anyscale
890 Args:
891 df (Optional[pd.DataFrame]): Input data to fine-tune on.
892 args (Optional[Dict]): Parameters for the fine-tuning process.
894 Raises:
895 Exception: If the model does not support fine-tuning.
897 Returns:
898 None
899 """ # noqa
900 args = args if args else {}
902 api_key = get_api_key(self.api_key_name, args, self.engine_storage)
904 using_args = args.pop("using") if "using" in args else {}
906 api_base = using_args.get("api_base", os.environ.get("OPENAI_API_BASE", OPENAI_API_BASE))
907 org = using_args.get("api_organization")
908 client = self._get_client(api_key=api_key, base_url=api_base, org=org, args=args)
910 args = {**using_args, **args}
911 prev_model_name = self.base_model_storage.json_get("args").get("model_name", "")
913 if prev_model_name not in self.supported_ft_models: 913 ↛ 923line 913 didn't jump to line 923 because the condition on line 913 was always true
914 # base model may be already FTed, check prefixes
915 for model in self.supported_ft_models:
916 if model in prev_model_name: 916 ↛ 917line 916 didn't jump to line 917 because the condition on line 916 was never true
917 break
918 else:
919 raise Exception(
920 f"This model cannot be finetuned. Supported base models are {self.supported_ft_models}."
921 )
923 finetune_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
925 temp_storage_path = tempfile.mkdtemp()
926 temp_file_name = f"ft_{finetune_time}"
927 temp_model_storage_path = f"{temp_storage_path}/{temp_file_name}.jsonl"
929 file_names = self._prepare_ft_jsonl(df, temp_storage_path, temp_file_name, temp_model_storage_path)
931 jsons = {k: None for k in file_names.keys()}
932 for split, file_name in file_names.items():
933 if os.path.isfile(os.path.join(temp_storage_path, file_name)):
934 jsons[split] = client.files.create(
935 file=open(f"{temp_storage_path}/{file_name}", "rb"), purpose="fine-tune"
936 )
938 if type(jsons["train"]) is openai.types.FileObject:
939 train_file_id = jsons["train"].id
940 else:
941 train_file_id = jsons["base"].id
943 if type(jsons["val"]) is openai.types.FileObject:
944 val_file_id = jsons["val"].id
945 else:
946 val_file_id = None
948 # `None` values are internally imputed by OpenAI to `null` or default values
949 ft_params = {
950 "training_file": train_file_id,
951 "validation_file": val_file_id,
952 "model": self._get_ft_model_type(prev_model_name),
953 }
954 ft_params = self._add_extra_ft_params(ft_params, using_args)
956 start_time = datetime.datetime.now()
958 ft_stats, result_file_id = self._ft_call(ft_params, client, args.get("hour_budget", 8))
959 ft_model_name = ft_stats.fine_tuned_model
961 end_time = datetime.datetime.now()
962 runtime = end_time - start_time
963 name_extension = client.files.retrieve(file_id=result_file_id).filename
964 result_path = f"{temp_storage_path}/ft_{finetune_time}_result_{name_extension}"
966 try:
967 client.files.content(file_id=result_file_id).stream_to_file(result_path)
968 if ".csv" in name_extension:
969 # legacy endpoint
970 train_stats = pd.read_csv(result_path)
971 if "validation_token_accuracy" in train_stats.columns:
972 train_stats = train_stats[train_stats["validation_token_accuracy"].notnull()]
973 args["ft_api_info"] = ft_stats.dict()
974 args["ft_result_stats"] = train_stats.to_dict()
976 elif ".json" in name_extension:
977 train_stats = pd.read_json(path_or_buf=result_path, lines=True) # new endpoint
978 args["ft_api_info"] = args["ft_result_stats"] = train_stats.to_dict()
980 except Exception:
981 logger.info(
982 f"Error retrieving fine-tuning results. Please check manually for information on job {ft_stats.id} (result file {result_file_id})."
983 )
985 args["model_name"] = ft_model_name
986 args["runtime"] = runtime.total_seconds()
987 args["mode"] = self.base_model_storage.json_get("args").get("mode", self.default_mode)
989 self.model_storage.json_set("args", args)
990 shutil.rmtree(temp_storage_path)
992 @staticmethod
993 def _prepare_ft_jsonl(df: pd.DataFrame, _, temp_filename: Text, temp_model_path: Text) -> Dict:
994 """
995 Prepare the input data for fine-tuning.
997 Args:
998 df (pd.DataFrame): Input data to fine-tune on.
999 temp_filename (Text): Temporary filename.
1000 temp_model_path (Text): Temporary model path.
1002 Returns:
1003 Dict: File names for the fine-tuning process.
1004 """
1005 df.to_json(temp_model_path, orient="records", lines=True)
1007 # TODO avoid subprocess usage once OpenAI enables non-CLI access, or refactor to use our own LLM utils instead
1008 subprocess.run(
1009 [
1010 "openai",
1011 "tools",
1012 "fine_tunes.prepare_data",
1013 "-f",
1014 temp_model_path, # from file
1015 "-q", # quiet mode (accepts all suggestions)
1016 ],
1017 stdout=subprocess.PIPE,
1018 stderr=subprocess.PIPE,
1019 encoding="utf-8",
1020 )
1022 file_names = {
1023 "original": f"{temp_filename}.jsonl",
1024 "base": f"{temp_filename}_prepared.jsonl",
1025 "train": f"{temp_filename}_prepared_train.jsonl",
1026 "val": f"{temp_filename}_prepared_valid.jsonl",
1027 }
1028 return file_names
1030 def _get_ft_model_type(self, model_name: Text) -> Text:
1031 """
1032 Get the model to use for fine-tuning. If the model is not supported, the default model (babbage-002) is used.
1034 Args:
1035 model_name (Text): Model name.
1037 Returns:
1038 Text: Model to use for fine-tuning.
1039 """
1040 for model_type in self.supported_ft_models:
1041 if model_type in model_name.lower():
1042 return model_type
1043 return "babbage-002"
1045 @staticmethod
1046 def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict:
1047 """
1048 Add extra parameters to the fine-tuning process.
1050 Args:
1051 ft_params (Dict): Parameters for the fine-tuning process required by the OpenAI API.
1052 using_args (Dict): Parameters passed when calling the fine-tuning process via a model.
1054 Returns:
1055 Dict: Fine-tuning parameters with extra parameters.
1056 """
1057 extra_params = {
1058 "n_epochs": using_args.get("n_epochs", None),
1059 "batch_size": using_args.get("batch_size", None),
1060 "learning_rate_multiplier": using_args.get("learning_rate_multiplier", None),
1061 "prompt_loss_weight": using_args.get("prompt_loss_weight", None),
1062 "compute_classification_metrics": using_args.get("compute_classification_metrics", None),
1063 "classification_n_classes": using_args.get("classification_n_classes", None),
1064 "classification_positive_class": using_args.get("classification_positive_class", None),
1065 "classification_betas": using_args.get("classification_betas", None),
1066 }
1067 return {**ft_params, **extra_params}
    def _ft_call(self, ft_params: Dict, client: OpenAI, hour_budget: int) -> Tuple[FineTuningJob, Text]:
        """
        Submit a fine-tuning job via the OpenAI API.
        This method handles requests to both the legacy and new endpoints.
        Currently, `OpenAIHandler` uses the legacy endpoint. Others, like `AnyscaleEndpointsHandler`, use the new endpoint.

        This method consists of an inner method:
        - _check_ft_status: Check the status of a fine-tuning job via the OpenAI API.

        Args:
            ft_params (Dict): Fine-tuning parameters.
            client (openai.OpenAI): OpenAI client.
            hour_budget (int): Hour budget for the fine-tuning process.

        Raises:
            PendingFT: If the fine-tuning process is still pending.

        Returns:
            Tuple[FineTuningJob, Text]: Fine-tuning stats and result file ID.
        """
        # `None`-valued params are dropped so the API applies its own defaults.
        ft_result = client.fine_tuning.jobs.create(**{k: v for k, v in ft_params.items() if v is not None})

        @retry_with_exponential_backoff(
            hour_budget=hour_budget,
        )
        def _check_ft_status(job_id: Text) -> FineTuningJob:
            """
            Check the status of a fine-tuning job via the OpenAI API.

            Args:
                job_id (Text): Fine-tuning job ID.

            Raises:
                PendingFT: If the fine-tuning process is still pending.

            Returns:
                FineTuningJob: Fine-tuning stats.
            """
            ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
            if ft_retrieved.status in ("succeeded", "failed", "cancelled"):
                # Terminal state: stop polling and hand the job back to the caller.
                return ft_retrieved
            else:
                # Raising PendingFT signals the backoff decorator to retry
                # until a terminal state is reached or the hour budget is spent.
                raise PendingFT("Fine-tuning still pending!")

        ft_stats = _check_ft_status(ft_result.id)

        if ft_stats.status != "succeeded":
            # Best-effort failure details; the fields may be absent depending on endpoint/version.
            err_message = ft_stats.events[-1].message if hasattr(ft_stats, "events") else "could not retrieve!"
            ft_status = ft_stats.status if hasattr(ft_stats, "status") else "N/A"
            raise Exception(
                f"Fine-tuning did not complete successfully (status: {ft_status}). Error message: {err_message}"
            )  # noqa

        # New endpoint returns plain file IDs; legacy returns file objects, hence the `.id` unwrap below.
        result_file_id = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=ft_result.id).result_files[0]
        if hasattr(result_file_id, "id"):
            result_file_id = result_file_id.id  # legacy endpoint

        return ft_stats, result_file_id
1128 @staticmethod
1129 def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None, args: dict = None) -> OpenAI:
1130 """
1131 Get an OpenAI client with the given API key, base URL, and organization.
1133 Args:
1134 api_key (Text): OpenAI API key.
1135 base_url (Text): OpenAI base URL.
1136 org (Optional[Text]): OpenAI organization.
1138 Returns:
1139 openai.OpenAI: OpenAI client.
1140 """
1141 if args is not None and args.get("provider") == "azure": 1141 ↛ 1142line 1141 didn't jump to line 1142 because the condition on line 1141 was never true
1142 return AzureOpenAI(
1143 api_key=api_key, azure_endpoint=base_url, api_version=args.get("api_version"), organization=org
1144 )
1145 return OpenAI(api_key=api_key, base_url=base_url, organization=org)