Coverage for mindsdb / integrations / handlers / openai_handler / helpers.py: 53%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

from typing import Text, List, Dict

import functools
import math
import random
import time

import openai

import tiktoken

import mindsdb.utilities.profiler as profiler

11 

12 

class PendingFT(openai.OpenAIError):
    """
    Raised while a fine-tuning job has not finished yet, so callers can back off and retry.
    """

    message: str

    def __init__(self, message) -> None:
        super().__init__()
        self.message = message

23 

24 

def retry_with_exponential_backoff(
    initial_delay: float = 1,
    hour_budget: float = 0.3,
    jitter: bool = False,
    exponential_base: int = 2,
    wait_errors: tuple = (openai.APITimeoutError, openai.APIConnectionError, PendingFT),
    status_errors: tuple = (openai.APIStatusError, openai.APIResponseValidationError),
):
    """
    Wrapper to enable optional arguments. It means this decorator always needs to be called with parenthesis:

    > @retry_with_exponential_backoff()  # optional argument override here
    > def f(): [...]

    Args:
        initial_delay: Initial delay in seconds before the first retry.
        hour_budget: Time budget for all retries combined, expressed in hours.
        jitter: If True, adds randomness to each delay.
        exponential_base: Base for the exponential backoff.
        wait_errors: Exception types that trigger a backoff-and-retry.
        status_errors: Exception types re-raised immediately with their API status code.
    """  # noqa

    @profiler.profile()
    def _retry_with_exponential_backoff(func):
        """
        Exponential backoff to retry requests on a rate-limited API call, as recommended by OpenAI.
        Loops the call until a successful response or max_retries is hit or an exception is raised.

        Slight changes in the implementation, but originally from:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb

        Args:
            func: Function to be wrapped

        Returns:
            Wrapper function with exponential backoff
        """  # noqa

        @functools.wraps(func)  # preserve the wrapped function's name/docstring for debugging & profiling
        def wrapper(*args, **kwargs):
            num_retries = 0
            delay = initial_delay

            # Derive the retry cap from the hour budget: with exponential growth the final
            # delay dominates, so solve exponential_base ** max_retries ~= budget / initial_delay.
            if isinstance(hour_budget, (float, int)):
                try:
                    max_retries = round((math.log((hour_budget * 3600) / initial_delay)) / math.log(exponential_base))
                except ValueError:
                    # Non-positive log argument (e.g. zero/negative budget): fall back to a default.
                    max_retries = 10
            else:
                max_retries = 10
            max_retries = max(1, max_retries)

            while True:
                try:
                    return func(*args, **kwargs)

                except status_errors as e:
                    # API returned an explicit error status: surface it without retrying.
                    error_message = e.body
                    if isinstance(error_message, dict):
                        error_message = error_message.get(
                            "message",
                            "Please refer to `https://platform.openai.com/docs/guides/error-codes` for more information.",
                        )
                    raise Exception(f"Error status {e.status_code} raised by OpenAI API: {error_message}")

                except wait_errors:
                    num_retries += 1
                    if num_retries > max_retries:
                        raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
                    # Increment the delay and wait
                    delay *= exponential_base * (1 + jitter * random.random())
                    time.sleep(delay)

                except openai.OpenAIError as e:
                    raise Exception(
                        f"General {str(e)} error raised by OpenAI. Please refer to `https://platform.openai.com/docs/guides/error-codes` for more information."  # noqa
                    )
                # NOTE: any other exception simply propagates; the previous
                # `except Exception as e: raise e` clause was a redundant no-op.

        return wrapper

    return _retry_with_exponential_backoff

108 

109 

def truncate_msgs_for_token_limit(messages: List[Dict], model_name: Text, max_tokens: int, truncate: Text = "first"):
    """
    Drop messages until the conversation fits within the token limit.

    The first message for chat completion models holds general directives with the system role,
    and is kept whenever possible.

    Slight changes in the implementation, but originally from:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        messages (List[Dict]): List of messages
        model_name (Text): Model name
        max_tokens (int): Maximum token limit
        truncate (Text): Truncate strategy, either 'first' or 'last'

    Returns:
        List[Dict]: Truncated message list
    """  # noqa
    try:
        encoder = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # If the encoding is not found, default to cl100k_base.
        # This is applicable for handlers that extend the OpenAI handler such as Anyscale.
        model_name = "gpt-3.5-turbo-0301"
        encoder = tiktoken.get_encoding("cl100k_base")

    sys_priming = messages[0:1]
    n_tokens = count_tokens(messages, encoder, model_name)

    while n_tokens > max_tokens:
        remaining = len(messages)
        if remaining == 2:
            # Edge case: only the directive plus a single input remain and the limit
            # is still surpassed — drop the trailing message and return.
            return messages[:-1]
        elif remaining == 1:
            return messages

        if truncate == "first":
            # Discard the oldest non-system message.
            messages = sys_priming + messages[2:]
        else:
            # Discard the newest message instead.
            messages = sys_priming + messages[1:-1]

        n_tokens = count_tokens(messages, encoder, model_name)

    return messages

150 

151 

def count_tokens(messages: List[Dict], encoder: tiktoken.core.Encoding, model_name: Text = "gpt-3.5-turbo-0301"):
    """
    Counts the number of tokens a list of chat messages will consume.

    Args:
        messages: List of messages
        encoder: Tokenizer
        model_name: Model name
    """
    if "gpt-3.5-turbo" in model_name:  # note: future models may deviate from this (only 0301 really complies)
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_message, tokens_per_name = 4, -1
    else:
        tokens_per_message, tokens_per_name = 3, 1

    total = 2  # every reply is primed with <im_start>assistant
    for message in messages:
        total += tokens_per_message
        for field, text in message.items():
            total += len(encoder.encode(text))
            if field == "name":  # if there's a name, the role is omitted
                total += tokens_per_name
    return total

178 

179 

def get_available_models(client) -> List[Text]:
    """
    Returns a list of available openai models for the given API key.
    NOTE: writer's 'get models list' response differs from openai's
    https://dev.writer.com/api-reference/completion-api/list-models
    https://platform.openai.com/docs/api-reference/models/list

    Args:
        client: openai sdk client

    Returns:
        List[Text]: List of available models
    """
    response = client.models.list()

    host = str(client.base_url.netloc).lower()
    if host == "api.writer.com":
        # Writer's payload exposes a `models` list of plain dicts rather than model objects.
        return [entry["id"] for entry in response.models]

    return [entry.id for entry in response.data]