Coverage for mindsdb / integrations / handlers / togetherai_handler / togetherai_handler.py: 0%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import textwrap 

3from typing import Optional, Dict, Any 

4import requests 

5import pandas as pd 

6from openai import OpenAI, AuthenticationError 

7from mindsdb.integrations.handlers.openai_handler import Handler as OpenAIHandler 

8from mindsdb.integrations.utilities.handler_utils import get_api_key 

9from mindsdb.integrations.handlers.togetherai_handler.settings import ( 

10 togetherai_handler_config, 

11) 

12 

13from mindsdb.utilities import log 

14 

15logger = log.getLogger(__name__) 

16 

17 

18class TogetherAIHandler(OpenAIHandler): 

19 """ 

20 This handler handles connection to the TogetherAI. 

21 """ 

22 

23 name = "togetherai" 

24 

25 def __init__(self, *args, **kwargs): 

26 super().__init__(*args, **kwargs) 

27 self.generative = True 

28 self.api_base = togetherai_handler_config.BASE_URL 

29 self.default_model = togetherai_handler_config.DEFAULT_MODEL 

30 self.default_embedding_model = togetherai_handler_config.DEFAULT_EMBEDDING_MODEL 

31 self.default_mode = togetherai_handler_config.DEFAULT_MODE 

32 self.supported_modes = togetherai_handler_config.SUPPORTED_MODES 

33 

34 @staticmethod 

35 def _check_client_connection(client: OpenAI): 

36 """ 

37 Check the TogetherAI engine client connection by listing models. 

38 

39 Args: 

40 client (OpenAI): OpenAI client configured with the TogetherAI API credentials. 

41 

42 Raises: 

43 Exception: If the client connection (API key) is invalid. 

44 

45 Returns: 

46 None 

47 """ 

48 

49 try: 

50 TogetherAIHandler._get_supported_models(client.api_key, client.base_url) 

51 

52 except Exception as e: 

53 raise Exception(f"Something went wrong: {e}") 

54 

55 def create_engine(self, connection_args): 

56 """ 

57 Validate the TogetherAI API credentials on engine creation. 

58 

59 Args: 

60 connection_args (dict): Connection arguments. 

61 

62 Raises: 

63 Exception: If the handler is not configured with valid API credentials. 

64 

65 Returns: 

66 None 

67 """ 

68 

69 connection_args = {k.lower(): v for k, v in connection_args.items()} 

70 api_key = connection_args.get("togetherai_api_key") 

71 if api_key is not None: 

72 api_base = connection_args.get("api_base") or os.environ.get( 

73 "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL 

74 ) 

75 client = self._get_client(api_key=api_key, base_url=api_base) 

76 TogetherAIHandler._check_client_connection(client) 

77 

78 @staticmethod 

79 def create_validation(target, args=None, **kwargs): 

80 """ 

81 Validate the TogetherAI API credentials on model creation. 

82 

83 Args: 

84 target (str): Target column, not required for LLMs. 

85 args (dict): Handler arguments. 

86 kwargs (dict): Handler keyword arguments. 

87 

88 Raises: 

89 Exception: If the handler is not configured with valid API credentials. 

90 

91 Returns: 

92 None 

93 """ 

94 if "using" not in args: 

95 raise Exception( 

96 "TogetherAI engine require a USING clause! Refer to its documentation for more details" 

97 ) 

98 else: 

99 args = args["using"] 

100 

101 if ( 

102 len(set(args.keys()) & {"question_column", "prompt_template", "prompt"}) 

103 == 0 

104 ): 

105 raise Exception( 

106 "One of `question_column`, `prompt_template` or `prompt` is required for this engine." 

107 ) 

108 

109 keys_collection = [ 

110 ["prompt_template"], 

111 ["question_column", "context_column"], 

112 ["prompt", "user_column", "assistant_column"], 

113 ] 

114 for keys in keys_collection: 

115 if keys[0] in args and any( 

116 x[0] in args for x in keys_collection if x != keys 

117 ): 

118 raise Exception( 

119 textwrap.dedent( 

120 """\ 

121 Please provide one of 

122 1) a `prompt_template` 

123 2) a `question_column` and an optional `context_column` 

124 3) a `prompt`, `user_column` and `assistant_column` 

125 """ 

126 ) 

127 ) 

128 

129 engine_storage = kwargs["handler_storage"] 

130 connection_args = engine_storage.get_connection_args() 

131 api_key = get_api_key("togetherai", args, engine_storage=engine_storage) 

132 api_base = connection_args.get("api_base") or os.environ.get( 

133 "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL 

134 ) 

135 client = TogetherAIHandler._get_client(api_key=api_key, base_url=api_base) 

136 TogetherAIHandler._check_client_connection(client) 

137 

138 def create(self, target, args: Dict = None, **kwargs: Any) -> None: 

139 """ 

140 Create a model for TogetherAI engine. 

141 

142 Args: 

143 target (str): Target column, not required for LLMs. 

144 args (dict): Handler arguments. 

145 kwargs (dict): Handler keyword arguments. 

146 

147 Raises: 

148 Exception: If the handler is not configured with valid API credentials. 

149 

150 Returns: 

151 None 

152 """ 

153 args = args["using"] 

154 args["target"] = target 

155 try: 

156 api_key = get_api_key(self.api_key_name, args, self.engine_storage) 

157 connection_args = self.engine_storage.get_connection_args() 

158 api_base = ( 

159 args.get("api_base") 

160 or connection_args.get("api_base") 

161 or os.environ.get("TOGETHERAI_API_BASE") 

162 or self.api_base 

163 ) 

164 available_models = self._get_supported_models(api_key, api_base) 

165 

166 if args.get("mode") is None: 

167 args["mode"] = self.default_mode 

168 elif args["mode"] not in self.supported_modes: 

169 raise Exception( 

170 f"Invalid operation mode. Please use one of {self.supported_modes}" 

171 ) 

172 

173 if not args.get("model_name"): 

174 if args["mode"] == "embedding": 

175 args["model_name"] = self.default_embedding_model 

176 else: 

177 args["model_name"] = self.default_model 

178 elif args["model_name"] not in available_models: 

179 raise Exception( 

180 f"Invalid model name. Please use one of {available_models}" 

181 ) 

182 finally: 

183 self.model_storage.json_set("args", args) 

184 

185 def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame: 

186 """ 

187 Call the TogetherAI engine to predict the next token. 

188 

189 Args: 

190 df (pd.DataFrame): Input data. 

191 args (dict): Handler arguments. 

192 

193 Returns: 

194 pd.DataFrame: Predicted data. 

195 """ 

196 

197 api_key = get_api_key("togetherai", args, engine_storage=self.engine_storage) 

198 supported_models = self._get_supported_models(api_key, self.api_base) 

199 self.chat_completion_models = supported_models 

200 return super().predict(df, args) 

201 

202 @staticmethod 

203 def _get_supported_models(api_key, base_url): 

204 """ 

205 Get the list of supported models from the TogetherAI engine. 

206 

207 Args: 

208 api_key (str): TogetherAI API key. 

209 base_url (str): TogetherAI API base URL. 

210 

211 Returns: 

212 list: List of supported models. 

213 """ 

214 

215 list_model_endpoint = f"{base_url}/models" 

216 headers = { 

217 "accept": "application/json", 

218 "authorization": f"Bearer {api_key}", 

219 } 

220 response = requests.get(url=list_model_endpoint, headers=headers) 

221 

222 if response.status_code == 200: 

223 model_list = response.json() 

224 chat_completion_models = list(map(lambda model: model["id"], model_list)) 

225 return chat_completion_models 

226 elif response.status_code == 401: 

227 raise AuthenticationError(message="Invalid API key") 

228 else: 

229 raise Exception(f"Failed to get supported models: {response.text}") 

230 

231 def finetune( 

232 self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None 

233 ) -> None: 

234 raise NotImplementedError("Fine-tuning is not supported for TogetherAI engine")