Coverage for mindsdb / integrations / handlers / litellm_handler / litellm_handler.py: 33%

87 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import ast 

2from typing import Dict, Optional, List 

3 

4import litellm 

5from litellm import completion, batch_completion, embedding, acompletion, supports_response_schema 

6 

7import pandas as pd 

8 

9from mindsdb.integrations.libs.base import BaseMLEngine 

10from mindsdb.utilities import log 

11 

12from mindsdb.integrations.handlers.litellm_handler.settings import CompletionParameters 

13 

14 

15logger = log.getLogger(__name__) 

16 

17litellm.drop_params = True 

18 

19 

class LiteLLMHandler(BaseMLEngine):
    """
    LiteLLMHandler is a MindsDB handler for litellm - https://docs.litellm.ai/docs/
    """

    name = "litellm"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # this handler produces generative (LLM) output
        self.generative = True

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate CREATE args: a USING clause is mandatory.

        Fix: guard against ``args`` being None (its default) before the
        membership test — ``"using" not in None`` raises TypeError instead
        of the intended, descriptive error.
        """
        if args is None or "using" not in args:
            raise Exception("Litellm engine requires a USING clause. See settings.py for more info on supported args.")

    @classmethod
    def prepare_arguments(cls, provider, model_name, args):
        """Normalize provider and arg names into the form litellm expects.

        Returns the "<provider>/<model>" model string plus the (mutated
        in place) args dict.
        """
        # litellm uses "gemini" as the provider slug for Google models
        if provider == "google":
            provider = "gemini"
        # litellm's keyword is api_base, not base_url
        if "base_url" in args:
            args["api_base"] = args.pop("base_url")

        model_name = f"{provider}/{model_name}"
        return model_name, args

    @classmethod
    def embeddings(cls, provider: str, model: str, messages: List[str], args: dict) -> List[list]:
        """Return one embedding vector per input string in ``messages``."""
        model, args = cls.prepare_arguments(provider, model, args)
        response = embedding(model=model, input=messages, **args)
        return [rec["embedding"] for rec in response.data]

    @classmethod
    async def acompletion(cls, provider: str, model: str, messages: List[dict], args: dict):
        """Async, non-streaming chat completion."""
        model, args = cls.prepare_arguments(provider, model, args)
        return await acompletion(model=model, messages=messages, stream=False, **args)

    @classmethod
    def completion(cls, provider: str, model: str, messages: List[dict], args: dict):
        """Sync, non-streaming chat completion, optionally requesting JSON output."""
        model, args = cls.prepare_arguments(provider, model, args)
        json_output = args.pop("json_output", False)

        # NOTE(review): prepare_arguments may rewrite "google" to "gemini" in
        # the model string, but the original provider name is passed here —
        # confirm supports_response_schema accepts "google" as a provider slug.
        supports_json_output = supports_response_schema(model=model, custom_llm_provider=provider)

        # only request a JSON response when both asked for and supported;
        # otherwise explicitly clear response_format
        if json_output and supports_json_output:
            args["response_format"] = {"type": "json_object"}
        else:
            args["response_format"] = None

        return completion(model=model, messages=messages, stream=False, **args)

    def create(
        self,
        target: str,
        df: pd.DataFrame = None,
        args: Optional[Dict] = None,
    ):
        """
        Dispatch is validating args and storing args in model_storage
        """
        # get api key from user input on create ML_ENGINE or create MODEL
        input_args = args["using"]

        # get api key from engine_storage
        ml_engine_args = self.engine_storage.get_connection_args()

        # check engine_storage for api_key; engine-level args take precedence
        # (plain dict.update — the original built a pointless intermediate dict)
        input_args.update(ml_engine_args)
        input_args["target"] = target

        # validate args via the pydantic settings model
        export_args = CompletionParameters(**input_args).model_dump()

        # store args
        self.model_storage.json_set("args", export_args)

    def predict(self, df: pd.DataFrame = None, args: dict = None):
        """
        Dispatch is getting args from model_storage, validating args and running completion
        """

        input_args = self.model_storage.json_get("args")

        # validate args
        args = CompletionParameters(**input_args).model_dump()

        target = args.pop("target")

        # build messages (sets args["messages"] from df / prompt_template)
        self._build_messages(args, df)

        # remove prompt_template from args: litellm's completion() does not accept it
        args.pop("prompt_template", None)

        if len(args["messages"]) > 1:
            # if more than one message, use batch completion
            responses = batch_completion(**args)
            return pd.DataFrame({target: [response.choices[0].message.content for response in responses]})

        # run completion
        response = completion(**args)

        return pd.DataFrame({target: [response.choices[0].message.content]})

    @staticmethod
    def _prompt_to_messages(prompt: str, **kwargs) -> List[Dict]:
        """
        Convert a prompt to a single-element list of chat messages.
        """

        if kwargs:
            # if kwargs are passed in, format the prompt template with them
            prompt = prompt.format(**kwargs)

        return [{"content": prompt, "role": "user"}]

    def _build_messages(self, args: dict, df: pd.DataFrame):
        """
        Build messages for completion from the first row of the predict dataframe.

        Mutates ``args`` in place: sets args["messages"] and may override
        args["prompt_template"] / args["mock_response"].
        """

        # assumes df has at least one row — TODO confirm callers guarantee this
        prompt_kwargs = df.iloc[0].to_dict()

        if "prompt_template" in prompt_kwargs:
            # if prompt_template is passed in predict query, use it
            logger.info(
                "Using 'prompt_template' passed in SELECT Predict query. "
                "Note this will overwrite a 'prompt_template' passed in create MODEL query."
            )

            args["prompt_template"] = prompt_kwargs.pop("prompt_template")

        if "mock_response" in prompt_kwargs:
            # used for testing to save on real completion api calls
            # (plain assignment — the original's annotation on a subscript
            # target was silently ignored by Python)
            args["mock_response"] = prompt_kwargs.pop("mock_response")

        if "messages" in prompt_kwargs and len(prompt_kwargs) > 1:
            # if user passes in messages, no other args can be passed in
            raise Exception("If 'messages' is passed in SELECT Predict query, no other args can be passed in.")

        # if user passes in messages, use those instead
        if "messages" in prompt_kwargs:
            logger.info("Using messages passed in SELECT Predict query. 'prompt_template' will be ignored.")

            # messages arrive as a stringified Python literal in the dataframe
            args["messages"] = ast.literal_eval(df["messages"].iloc[0])

        else:
            # if user passes in prompt_template, use that to create messages
            if len(prompt_kwargs) == 1:
                args["messages"] = (
                    self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
                    if args["prompt_template"]
                    else self._prompt_to_messages(df.iloc[0][0])
                )

            elif len(prompt_kwargs) > 1:
                try:
                    args["messages"] = self._prompt_to_messages(args["prompt_template"], **prompt_kwargs)
                except KeyError as e:
                    raise Exception(
                        f"{e}: Please pass in either a prompt_template on create MODEL or "
                        f"a single where clause in predict query."
                    )