Coverage for mindsdb/integrations/handlers/litellm_handler/litellm_handler.py: 33%
87 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import ast
2from typing import Dict, Optional, List
4import litellm
5from litellm import completion, batch_completion, embedding, acompletion, supports_response_schema
7import pandas as pd
9from mindsdb.integrations.libs.base import BaseMLEngine
10from mindsdb.utilities import log
12from mindsdb.integrations.handlers.litellm_handler.settings import CompletionParameters
15logger = log.getLogger(__name__)
17litellm.drop_params = True
class LiteLLMHandler(BaseMLEngine):
    """
    LiteLLMHandler is a MindsDB handler for litellm - https://docs.litellm.ai/docs/

    Dispatches sync/async completions, batch completions and embeddings to any
    provider supported by litellm.  Model arguments are validated with
    ``CompletionParameters`` and persisted in ``model_storage``.
    """

    name = "litellm"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Mark this engine as generative so MindsDB treats predictions as text generation.
        self.generative = True

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """
        Validate CREATE MODEL arguments.

        Raises:
            Exception: if no USING clause was provided.
        """
        # Guard against args being None (its default) as well as a missing
        # "using" key; `"using" not in None` would raise a confusing TypeError.
        if not args or "using" not in args:
            raise Exception("Litellm engine requires a USING clause. See settings.py for more info on supported args.")

    @classmethod
    def prepare_arguments(cls, provider, model_name, args):
        """
        Normalize provider, model name and call args into litellm's expected form.

        Returns:
            tuple: ("provider/model" string, the args dict with renamed keys).
        """
        # litellm identifies Google models under the "gemini" provider key.
        if provider == "google":
            provider = "gemini"
        # litellm expects "api_base" rather than "base_url".
        if "base_url" in args:
            args["api_base"] = args.pop("base_url")

        model_name = f"{provider}/{model_name}"
        return model_name, args

    @classmethod
    def embeddings(cls, provider: str, model: str, messages: List[str], args: dict) -> List[list]:
        """Run an embedding call and return one embedding vector per input string."""
        model, args = cls.prepare_arguments(provider, model, args)
        response = embedding(model=model, input=messages, **args)
        return [rec["embedding"] for rec in response.data]

    @classmethod
    async def acompletion(cls, provider: str, model: str, messages: List[dict], args: dict):
        """Async, non-streaming completion call; returns the raw litellm response."""
        model, args = cls.prepare_arguments(provider, model, args)
        return await acompletion(model=model, messages=messages, stream=False, **args)

    @classmethod
    def completion(cls, provider: str, model: str, messages: List[dict], args: dict):
        """Sync, non-streaming completion call; returns the raw litellm response."""
        model, args = cls.prepare_arguments(provider, model, args)
        json_output = args.pop("json_output", False)

        # Only request JSON mode when the target model actually supports it.
        supports_json_output = supports_response_schema(model=model, custom_llm_provider=provider)

        if json_output and supports_json_output:
            args["response_format"] = {"type": "json_object"}
        else:
            args["response_format"] = None

        return completion(model=model, messages=messages, stream=False, **args)

    def create(
        self,
        target: str,
        df: pd.DataFrame = None,
        args: Optional[Dict] = None,
    ):
        """
        Dispatch is validating args and storing args in model_storage
        """
        # get api key from user input on create ML_ENGINE or create MODEL
        input_args = args["using"]

        # get api key from engine_storage; engine-level args override model-level ones
        ml_engine_args = self.engine_storage.get_connection_args()
        input_args.update(ml_engine_args)
        input_args["target"] = target

        # validate args
        export_args = CompletionParameters(**input_args).model_dump()

        # store args
        self.model_storage.json_set("args", export_args)

    def predict(self, df: pd.DataFrame = None, args: dict = None):
        """
        Dispatch is getting args from model_storage, validating args and running completion
        """
        input_args = self.model_storage.json_get("args")

        # validate args
        args = CompletionParameters(**input_args).model_dump()

        target = args.pop("target")

        # build messages (mutates args in place)
        self._build_messages(args, df)

        # remove prompt_template from args: it is not a litellm completion kwarg
        args.pop("prompt_template", None)

        if len(args["messages"]) > 1:
            # if more than one message, use batch completion
            responses = batch_completion(**args)
            return pd.DataFrame({target: [response.choices[0].message.content for response in responses]})

        # run completion
        response = completion(**args)
        return pd.DataFrame({target: [response.choices[0].message.content]})

    @staticmethod
    def _prompt_to_messages(prompt: str, **kwargs) -> List[Dict]:
        """
        Convert a prompt to a list of messages
        """
        if kwargs:
            # if kwargs are passed in, format the prompt with kwargs
            prompt = prompt.format(**kwargs)

        return [{"content": prompt, "role": "user"}]

    def _build_messages(self, args: dict, df: pd.DataFrame):
        """
        Build messages for completion

        Populates ``args["messages"]`` from either an explicit 'messages'
        column, or from 'prompt_template' formatted with the first row's values.
        """
        prompt_kwargs = df.iloc[0].to_dict()

        if "prompt_template" in prompt_kwargs:
            # if prompt_template is passed in predict query, use it
            logger.info(
                "Using 'prompt_template' passed in SELECT Predict query. "
                "Note this will overwrite a 'prompt_template' passed in create MODEL query."
            )
            args["prompt_template"] = prompt_kwargs.pop("prompt_template")

        if "mock_response" in prompt_kwargs:
            # used for testing to save on real completion api calls
            args["mock_response"] = prompt_kwargs.pop("mock_response")

        if "messages" in prompt_kwargs and len(prompt_kwargs) > 1:
            # if user passes in messages, no other args can be passed in
            raise Exception("If 'messages' is passed in SELECT Predict query, no other args can be passed in.")

        # if user passes in messages, use those instead
        if "messages" in prompt_kwargs:
            logger.info("Using messages passed in SELECT Predict query. 'prompt_template' will be ignored.")
            # the column holds a string repr of a list of dicts; literal_eval is
            # safe here (no arbitrary code execution, unlike eval)
            args["messages"] = ast.literal_eval(df["messages"].iloc[0])
        else:
            # if user passes in prompt_template, use that to create messages
            if len(prompt_kwargs) == 1:
                args["messages"] = (
                    self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
                    if args["prompt_template"]
                    else self._prompt_to_messages(df.iloc[0][0])
                )
            elif len(prompt_kwargs) > 1:
                try:
                    args["messages"] = self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
                except KeyError as e:
                    raise Exception(
                        f"{e}: Please pass in either a prompt_template on create MODEL or "
                        f"a single where clause in predict query."
                    )