Coverage for mindsdb/interfaces/knowledge_base/llm_client.py: 44%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

import functools
import os
import time
from typing import List

from openai import OpenAI, AzureOpenAI

from mindsdb.integrations.utilities.handler_utils import get_api_key

8 

9 

def retry_with_exponential_backoff(func):
    """Retry ``func`` on transient network failures with exponential backoff.

    Only exceptions whose message contains "connection error" or "timeout"
    (case-insensitive) are retried; any other exception is re-raised
    immediately. After ``max_retries`` (3) failed retries a wrapping
    ``Exception`` is raised, chained to the last underlying error.

    Note: sleeps between attempts (2s, 4s, 8s), so a fully exhausted retry
    cycle blocks for ~14 seconds.
    """

    @functools.wraps(func)
    def decorator(*args, **kwargs):
        max_retries = 3
        num_retries = 0
        delay = 1
        exponential_base = 2

        while True:
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # `message` is already lower-cased here, so both substring
                # checks are case-insensitive (the original redundantly
                # lower-cased it a second time).
                message = str(e).lower()
                if "connection error" not in message and "timeout" not in message:
                    # Non-transient error (auth, bad request, ...) — do not retry.
                    raise e

                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from e
                # Increment the delay and wait
                delay *= exponential_base
                time.sleep(delay)

    return decorator

33 

34 

def run_in_batches(batch_size):
    """
    Decorator factory: split the ``messages`` argument of a method into
    chunks of at most ``batch_size`` items, call the wrapped method once per
    chunk, and concatenate the per-chunk result lists.
    """

    def decorator(func):
        def wrapper(self, messages, *args, **kwargs):
            # Small inputs bypass the batching machinery entirely.
            if len(messages) <= batch_size:
                return func(self, messages, *args, **kwargs)

            collected = []
            for start in range(0, len(messages), batch_size):
                batch = messages[start : start + batch_size]
                collected.extend(func(self, batch, *args, **kwargs))
            return collected

        return wrapper

    return decorator

57 

58 

class LLMClient:
    """
    Class for accession to LLM.
    It chooses openai client or litellm handler depending on the config:
    "azure_openai", "openai" and "ollama" providers use the native OpenAI
    client classes; any other provider falls back to the litellm handler.
    """

    def __init__(self, params: dict = None, session=None):
        """
        :param params: provider configuration; recognized keys include
            "provider" (default "openai"), "model_name", "api_key",
            "base_url" and "api_version".
        :param session: optional session controller, used to locate the
            litellm handler for non-native providers.
        """
        # Guard: the original crashed on `params.get(...)` when the default
        # None was used.
        if params is None:
            params = {}
        self._session = session
        self.params = params

        self.provider = params.get("provider", "openai")

        if "api_key" not in params:
            api_key = get_api_key(self.provider, params, strict=False)
            if api_key is not None:
                params["api_key"] = api_key

        self.engine = "openai"

        if self.provider == "azure_openai":
            # Fall back to the standard Azure environment variables for any
            # setting not supplied in params.
            azure_api_key = params.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY")
            azure_api_endpoint = params.get("base_url") or os.environ.get("AZURE_OPENAI_ENDPOINT")
            azure_api_version = params.get("api_version") or os.environ.get("AZURE_OPENAI_API_VERSION")
            self.client = AzureOpenAI(
                api_key=azure_api_key, azure_endpoint=azure_api_endpoint, api_version=azure_api_version, max_retries=2
            )
        elif self.provider == "openai":
            openai_api_key = params.get("api_key") or os.getenv("OPENAI_API_KEY")
            kwargs = {"api_key": openai_api_key, "max_retries": 2}
            base_url = params.get("base_url")
            if base_url:
                kwargs["base_url"] = base_url
            self.client = OpenAI(**kwargs)
        elif self.provider == "ollama":
            kwargs = params.copy()
            kwargs.pop("model_name")
            kwargs.pop("provider", None)
            if kwargs.get("api_key") is None:
                # The OpenAI client requires some api_key value even though
                # ollama does not validate it.
                kwargs["api_key"] = "n/a"
            self.client = OpenAI(**kwargs)
        else:
            # try to use litellm
            if self._session is None:
                from mindsdb.api.executor.controllers.session_controller import SessionController

                self._session = SessionController()
            module = self._session.integration_controller.get_handler_module("litellm")

            if module is None or module.Handler is None:
                raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')

            self.client = module.Handler
            self.engine = "litellm"

    @run_in_batches(1000)
    @retry_with_exponential_backoff
    def embeddings(self, messages: List[str]):
        """
        Embed `messages` and return one embedding vector per input string.
        Batched (1000 per request) and retried on transient network errors.
        """
        params = self.params
        if self.engine == "openai":
            response = self.client.embeddings.create(
                model=params["model_name"],
                input=messages,
            )
            return [item.embedding for item in response.data]
        else:
            kwargs = params.copy()
            model = kwargs.pop("model_name")
            kwargs.pop("provider", None)

            return self.client.embeddings(self.provider, model=model, messages=messages, args=kwargs)

    @run_in_batches(100)
    def completion(self, messages: List[dict], json_output: bool = False) -> List[str]:
        """
        Call LLM completion and get response.

        :param messages: chat messages in the {"role": ..., "content": ...} form.
        :param json_output: forwarded to the litellm handler; NOTE(review):
            the openai engine has never honored this flag — left unchanged
            to avoid altering request parameters.
        """
        params = self.params
        if self.engine == "openai":
            response = self.client.chat.completions.create(
                model=params["model_name"],
                messages=messages,
            )
            return [item.message.content for item in response.choices]
        else:
            # Work on a copy so `json_output` does not get written into
            # self.params (the original mutated self.params, which then
            # leaked the key into later embeddings() calls).
            kwargs = params.copy()
            kwargs["json_output"] = json_output
            model = kwargs.pop("model_name")
            kwargs.pop("provider", None)
            response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs)
            return [item.message.content for item in response.choices]