Coverage for mindsdb / integrations / handlers / openai_handler / helpers.py: 53%
83 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import functools
import math
import random
import time
from typing import Text, List, Dict

import openai
import tiktoken

import mindsdb.utilities.profiler as profiler
class PendingFT(openai.OpenAIError):
    """
    Custom exception to handle pending fine-tuning status.
    """

    message: str

    def __init__(self, message) -> None:
        # Forward the message to the base Exception so `str(exc)` and tracebacks
        # actually show it (previously the base was initialized with no args,
        # leaving `str(exc)` empty).
        super().__init__(message)
        self.message = message
def retry_with_exponential_backoff(
    initial_delay: float = 1,
    hour_budget: float = 0.3,
    jitter: bool = False,
    exponential_base: int = 2,
    wait_errors: tuple = (openai.APITimeoutError, openai.APIConnectionError, PendingFT),
    status_errors: tuple = (openai.APIStatusError, openai.APIResponseValidationError),
):
    """
    Wrapper to enable optional arguments. It means this decorator always needs to be called with parenthesis:

    > @retry_with_exponential_backoff()  # optional argument override here
    > def f(): [...]
    """  # noqa

    @profiler.profile()
    def _retry_with_exponential_backoff(func):
        """
        Exponential backoff to retry requests on a rate-limited API call, as recommended by OpenAI.
        Loops the call until a successful response or max_retries is hit or an exception is raised.

        Slight changes in the implementation, but originally from:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb

        Args:
            func: Function to be wrapped
            initial_delay: Initial delay in seconds
            hour_budget: Time budget in hours used to derive the maximum number of retries
            jitter: Adds randomness to the delay
            exponential_base: Base for the exponential backoff
            wait_errors: Tuple of errors to retry on
            status_errors: Tuple of status errors to raise

        Returns:
            Wrapper function with exponential backoff
        """  # noqa

        @functools.wraps(func)  # preserve the wrapped function's metadata for introspection/logging
        def wrapper(*args, **kwargs):
            num_retries = 0
            delay = initial_delay

            # Derive how many geometrically-growing delays fit inside the time budget:
            # solve initial_delay * base^n = hour_budget * 3600 for n.
            if isinstance(hour_budget, (int, float)):
                try:
                    max_retries = round(math.log((hour_budget * 3600) / initial_delay) / math.log(exponential_base))
                except ValueError:
                    # Non-positive budget/delay ratio makes log() blow up; use a sane default.
                    max_retries = 10
            else:
                max_retries = 10
            max_retries = max(1, max_retries)

            while True:
                try:
                    return func(*args, **kwargs)

                except status_errors as e:
                    # Surface the API-provided message when the error body carries one.
                    error_message = e.body
                    if isinstance(error_message, dict):
                        error_message = error_message.get(
                            "message",
                            "Please refer to `https://platform.openai.com/docs/guides/error-codes` for more information.",
                        )
                    raise Exception(f"Error status {e.status_code} raised by OpenAI API: {error_message}")

                except wait_errors:
                    num_retries += 1
                    if num_retries > max_retries:
                        raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
                    # Increment the delay and wait
                    delay *= exponential_base * (1 + jitter * random.random())
                    time.sleep(delay)

                except openai.OpenAIError as e:
                    raise Exception(
                        f"General {str(e)} error raised by OpenAI. Please refer to `https://platform.openai.com/docs/guides/error-codes` for more information."  # noqa
                    )

        return wrapper

    return _retry_with_exponential_backoff
def truncate_msgs_for_token_limit(messages: List[Dict], model_name: Text, max_tokens: int, truncate: Text = "first"):
    """
    Truncates message list to fit within the token limit.
    The first message for chat completion models are general directives with the system role, which will ideally be kept at all times.

    Slight changes in the implementation, but originally from:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        messages (List[Dict]): List of messages
        model_name (Text): Model name
        max_tokens (int): Maximum token limit
        truncate (Text): Truncate strategy, either 'first' or 'last'

    Returns:
        List[Dict]: Truncated message list
    """  # noqa
    try:
        encoder = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown model: default to the cl100k_base encoding. This is applicable
        # for handlers that extend the OpenAI handler, such as Anyscale.
        model_name = "gpt-3.5-turbo-0301"
        encoder = tiktoken.get_encoding("cl100k_base")

    sys_priming = messages[0:1]
    while count_tokens(messages, encoder, model_name) > max_tokens:
        if len(messages) == 2:
            # Edge case: the limit is surpassed by just one input, so drop the initial instruction.
            return messages[:-1]
        elif len(messages) == 1:
            return messages

        # Keep the system priming message; drop either the oldest or the newest of the rest.
        remainder = messages[2:] if truncate == "first" else messages[1:-1]
        messages = sys_priming + remainder
    return messages
def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_name: Text = "gpt-3.5-turbo-0301"):
    """
    Counts the number of tokens in a list of messages.

    Args:
        messages: List of messages
        encoder: Tokenizer
        model_name: Model name
    """
    # Per-message overhead differs across model families.
    if "gpt-3.5-turbo" in model_name:  # note: future models may deviate from this (only 0301 really complies)
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_message, tokens_per_name = 4, -1
    else:
        tokens_per_message, tokens_per_name = 3, 1

    total = 0
    for msg in messages:
        total += tokens_per_message
        for field, text in msg.items():
            total += len(encoder.encode(text))
            if field == "name":  # if there's a name, the role is omitted
                total += tokens_per_name
    # every reply is primed with <im_start>assistant
    return total + 2
def get_available_models(client) -> List[Text]:
    """
    Returns a list of available openai models for the given API key.
    NOTE: writer's 'get models list' response differs from openai's
    https://dev.writer.com/api-reference/completion-api/list-models
    https://platform.openai.com/docs/api-reference/models/list

    Args:
        client: openai sdk client

    Returns:
        List[Text]: List of available models
    """
    response = client.models.list()

    # Writer's endpoint nests models under `.models` as dicts, unlike openai's `.data` objects.
    is_writer_api = str(client.base_url.netloc).lower() == "api.writer.com"
    if is_writer_api:
        return [entry["id"] for entry in response.models]

    return [entry.id for entry in response.data]