Coverage for mindsdb / interfaces / knowledge_base / llm_client.py: 44%
98 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
import functools
import os
import time
from typing import List

from openai import OpenAI, AzureOpenAI

from mindsdb.integrations.utilities.handler_utils import get_api_key
def retry_with_exponential_backoff(func):
    """
    Retry ``func`` on transient failures with exponential backoff.

    Only exceptions whose message contains "connection error" or "timeout"
    (case-insensitive) are retried, up to 3 times with delays of 2s, 4s, 8s.
    Any other exception is re-raised immediately; exhausting the retries
    raises ``Exception`` chained to the last underlying error.
    """

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def decorator(*args, **kwargs):
        max_retries = 3
        num_retries = 0
        delay = 1
        exponential_base = 2

        while True:
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # `message` is already lowercased, so one check per keyword suffices.
                message = str(e).lower()
                if "connection error" not in message and "timeout" not in message:
                    raise  # non-transient: propagate with the original traceback

                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from e

                # Increment the delay and wait
                delay *= exponential_base
                time.sleep(delay)

    return decorator
def run_in_batches(batch_size):
    """
    Decorator factory: split the ``messages`` argument into chunks of at most
    ``batch_size`` items, call the wrapped method once per chunk, and
    concatenate the per-chunk result lists.

    Inputs that already fit in a single batch are passed straight through,
    so the wrapped function's return value is returned unchanged in that case.
    """

    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped method's metadata
        def wrapper(self, messages, *args, **kwargs):
            if len(messages) <= batch_size:
                return func(self, messages, *args, **kwargs)

            results = []
            # Stride over the input in batch_size-sized slices.
            for start in range(0, len(messages), batch_size):
                results.extend(func(self, messages[start : start + batch_size], *args, **kwargs))
            return results

        return wrapper

    return decorator
class LLMClient:
    """
    Client for accessing an LLM.

    Chooses between a direct OpenAI-compatible client (engine "openai") and
    the litellm handler (engine "litellm") depending on the configured provider.
    """

    def __init__(self, params: dict = None, session=None):
        """
        Args:
            params: provider configuration. Recognised keys include
                "provider" (default "openai"), "model_name", "api_key",
                "base_url" and, for azure_openai, "api_version".
            session: optional mindsdb session controller; created lazily
                only when the litellm handler path is taken.

        Raises:
            ValueError: if the provider requires litellm and the handler
                is not installed.
        """
        self._session = session
        # Guard against params=None: .get() below would otherwise raise AttributeError.
        if params is None:
            params = {}
        self.params = params

        self.provider = params.get("provider", "openai")

        if "api_key" not in params:
            api_key = get_api_key(self.provider, params, strict=False)
            if api_key is not None:
                params["api_key"] = api_key

        self.engine = "openai"

        if self.provider == "azure_openai":
            # Fall back to the standard Azure environment variables.
            azure_api_key = params.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY")
            azure_api_endpoint = params.get("base_url") or os.environ.get("AZURE_OPENAI_ENDPOINT")
            azure_api_version = params.get("api_version") or os.environ.get("AZURE_OPENAI_API_VERSION")
            self.client = AzureOpenAI(
                api_key=azure_api_key, azure_endpoint=azure_api_endpoint, api_version=azure_api_version, max_retries=2
            )
        elif self.provider == "openai":
            openai_api_key = params.get("api_key") or os.getenv("OPENAI_API_KEY")
            kwargs = {"api_key": openai_api_key, "max_retries": 2}
            base_url = params.get("base_url")
            if base_url:
                kwargs["base_url"] = base_url
            self.client = OpenAI(**kwargs)
        elif self.provider == "ollama":
            # Ollama speaks the OpenAI wire protocol; forward params minus
            # the keys the OpenAI constructor does not accept.
            kwargs = params.copy()
            kwargs.pop("model_name", None)  # tolerate a missing model_name
            kwargs.pop("provider", None)
            if kwargs.get("api_key") is None:
                # The client requires a key even though ollama ignores it.
                kwargs["api_key"] = "n/a"
            self.client = OpenAI(**kwargs)
        else:
            # Any other provider is routed through the litellm handler.
            if self._session is None:
                from mindsdb.api.executor.controllers.session_controller import SessionController

                self._session = SessionController()
            module = self._session.integration_controller.get_handler_module("litellm")

            if module is None or module.Handler is None:
                raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')

            self.client = module.Handler
            self.engine = "litellm"

    @run_in_batches(1000)
    @retry_with_exponential_backoff
    def embeddings(self, messages: List[str]):
        """
        Embed a list of texts; batched (1000 per call) and retried on
        transient connection/timeout failures.

        Returns a list of embedding vectors, one per input message.
        """
        params = self.params
        if self.engine == "openai":
            response = self.client.embeddings.create(
                model=params["model_name"],
                input=messages,
            )
            return [item.embedding for item in response.data]

        # litellm path: forward the remaining params as handler args.
        kwargs = params.copy()
        model = kwargs.pop("model_name")
        kwargs.pop("provider", None)
        return self.client.embeddings(self.provider, model=model, messages=messages, args=kwargs)

    @run_in_batches(100)
    def completion(self, messages: List[dict], json_output: bool = False) -> List[str]:
        """
        Call LLM completion and get response contents.

        Args:
            messages: chat messages in OpenAI format.
            json_output: request JSON output; only forwarded on the litellm
                path (via args) — the direct openai path ignores it, as before.

        Returns:
            The content string of each returned choice.
        """
        params = self.params
        if self.engine == "openai":
            response = self.client.chat.completions.create(
                model=params["model_name"],
                messages=messages,
            )
            return [item.message.content for item in response.choices]

        kwargs = params.copy()
        # Set the flag on the per-call copy so self.params is not mutated
        # between calls (the original wrote it into the shared dict).
        kwargs["json_output"] = json_output
        model = kwargs.pop("model_name")
        kwargs.pop("provider", None)
        response = self.client.completion(self.provider, model=model, messages=messages, args=kwargs)
        return [item.message.content for item in response.choices]