Coverage for mindsdb / integrations / handlers / togetherai_handler / togetherai_handler.py: 0%
89 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import textwrap
3from typing import Optional, Dict, Any
4import requests
5import pandas as pd
6from openai import OpenAI, AuthenticationError
7from mindsdb.integrations.handlers.openai_handler import Handler as OpenAIHandler
8from mindsdb.integrations.utilities.handler_utils import get_api_key
9from mindsdb.integrations.handlers.togetherai_handler.settings import (
10 togetherai_handler_config,
11)
13from mindsdb.utilities import log
15logger = log.getLogger(__name__)
class TogetherAIHandler(OpenAIHandler):
    """
    This handler handles connection to the TogetherAI API.

    TogetherAI exposes an OpenAI-compatible API, so this handler reuses the
    OpenAIHandler machinery and only overrides the base URL, default models,
    and credential validation.
    """

    name = "togetherai"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True
        # Endpoint and model defaults come from the handler's settings module.
        self.api_base = togetherai_handler_config.BASE_URL
        self.default_model = togetherai_handler_config.DEFAULT_MODEL
        self.default_embedding_model = togetherai_handler_config.DEFAULT_EMBEDDING_MODEL
        self.default_mode = togetherai_handler_config.DEFAULT_MODE
        self.supported_modes = togetherai_handler_config.SUPPORTED_MODES

    @staticmethod
    def _check_client_connection(client: OpenAI):
        """
        Check the TogetherAI engine client connection by listing models.

        Args:
            client (OpenAI): OpenAI client configured with the TogetherAI API credentials.

        Raises:
            Exception: If the client connection (API key) is invalid.

        Returns:
            None
        """
        try:
            TogetherAIHandler._get_supported_models(client.api_key, client.base_url)
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise Exception(f"Something went wrong: {e}") from e

    def create_engine(self, connection_args):
        """
        Validate the TogetherAI API credentials on engine creation.

        Args:
            connection_args (dict): Connection arguments.

        Raises:
            Exception: If the handler is not configured with valid API credentials.

        Returns:
            None
        """
        connection_args = {k.lower(): v for k, v in connection_args.items()}
        api_key = connection_args.get("togetherai_api_key")
        # NOTE(review): when no key is supplied here, validation is deferred
        # (e.g. to model creation) rather than failing engine creation.
        if api_key is not None:
            api_base = connection_args.get("api_base") or os.environ.get(
                "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL
            )
            client = self._get_client(api_key=api_key, base_url=api_base)
            TogetherAIHandler._check_client_connection(client)

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """
        Validate the TogetherAI API credentials and arguments on model creation.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments; must contain a `using` clause.
            kwargs (dict): Handler keyword arguments; must contain `handler_storage`.

        Raises:
            Exception: If the USING clause is missing/invalid or the handler is
                not configured with valid API credentials.

        Returns:
            None
        """
        # Guard against args=None as well: the default value would otherwise
        # make `"using" not in args` raise TypeError instead of the
        # informative error below.
        if not args or "using" not in args:
            raise Exception(
                "TogetherAI engine require a USING clause! Refer to its documentation for more details"
            )
        else:
            args = args["using"]

        if (
            len(set(args.keys()) & {"question_column", "prompt_template", "prompt"})
            == 0
        ):
            raise Exception(
                "One of `question_column`, `prompt_template` or `prompt` is required for this engine."
            )

        # The three prompting styles are mutually exclusive; detect any mix of
        # their leading keys.
        keys_collection = [
            ["prompt_template"],
            ["question_column", "context_column"],
            ["prompt", "user_column", "assistant_column"],
        ]
        for keys in keys_collection:
            if keys[0] in args and any(
                x[0] in args for x in keys_collection if x != keys
            ):
                raise Exception(
                    textwrap.dedent(
                        """\
                        Please provide one of
                        1) a `prompt_template`
                        2) a `question_column` and an optional `context_column`
                        3) a `prompt`, `user_column` and `assistant_column`
                        """
                    )
                )

        engine_storage = kwargs["handler_storage"]
        connection_args = engine_storage.get_connection_args()
        api_key = get_api_key("togetherai", args, engine_storage=engine_storage)
        api_base = connection_args.get("api_base") or os.environ.get(
            "TOGETHERAI_API_BASE", togetherai_handler_config.BASE_URL
        )
        client = TogetherAIHandler._get_client(api_key=api_key, base_url=api_base)
        TogetherAIHandler._check_client_connection(client)

    def create(self, target, args: Dict = None, **kwargs: Any) -> None:
        """
        Create a model for TogetherAI engine.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments.
            kwargs (dict): Handler keyword arguments.

        Raises:
            Exception: If the mode or model name is invalid, or the handler is
                not configured with valid API credentials.

        Returns:
            None
        """
        args = args["using"]
        args["target"] = target
        try:
            api_key = get_api_key(
                self.api_key_name, args, engine_storage=self.engine_storage
            )
            connection_args = self.engine_storage.get_connection_args()
            # Resolution order: model args > engine args > env var > default.
            api_base = (
                args.get("api_base")
                or connection_args.get("api_base")
                or os.environ.get("TOGETHERAI_API_BASE")
                or self.api_base
            )
            available_models = self._get_supported_models(api_key, api_base)

            if args.get("mode") is None:
                args["mode"] = self.default_mode
            elif args["mode"] not in self.supported_modes:
                raise Exception(
                    f"Invalid operation mode. Please use one of {self.supported_modes}"
                )

            if not args.get("model_name"):
                if args["mode"] == "embedding":
                    args["model_name"] = self.default_embedding_model
                else:
                    args["model_name"] = self.default_model
            elif args["model_name"] not in available_models:
                raise Exception(
                    f"Invalid model name. Please use one of {available_models}"
                )
        finally:
            # NOTE(review): args are persisted even when validation above
            # raises — presumably so the failed attempt's args are inspectable;
            # confirm this is intentional before changing.
            self.model_storage.json_set("args", args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Call the TogetherAI engine to predict the next token.

        Args:
            df (pd.DataFrame): Input data.
            args (dict): Handler arguments.

        Returns:
            pd.DataFrame: Predicted data.
        """
        api_key = get_api_key("togetherai", args, engine_storage=self.engine_storage)
        # Refresh the chat-completion model list so the parent handler
        # validates against TogetherAI's catalog rather than OpenAI's.
        self.chat_completion_models = self._get_supported_models(
            api_key, self.api_base
        )
        return super().predict(df, args)

    @staticmethod
    def _get_supported_models(api_key, base_url):
        """
        Get the list of supported models from the TogetherAI engine.

        Args:
            api_key (str): TogetherAI API key.
            base_url (str): TogetherAI API base URL.

        Raises:
            AuthenticationError: If the API key is rejected (HTTP 401).
            Exception: For any other non-200 response.

        Returns:
            list: List of supported model ids.
        """
        # base_url may be an httpx.URL with a trailing slash (from the OpenAI
        # client); normalize to avoid a double slash in the endpoint.
        list_model_endpoint = f"{str(base_url).rstrip('/')}/models"
        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {api_key}",
        }
        # Timeout so a hung endpoint cannot block engine/model creation forever.
        response = requests.get(url=list_model_endpoint, headers=headers, timeout=60)

        if response.status_code == 200:
            model_list = response.json()
            return [model["id"] for model in model_list]
        elif response.status_code == 401:
            # openai>=1.0 APIStatusError requires `response` and `body` kwargs;
            # the requests.Response duck-types the attributes it reads
            # (status_code, request, headers) — TODO confirm against the pinned
            # openai version.
            raise AuthenticationError(
                message="Invalid API key", response=response, body=response.text
            )
        else:
            raise Exception(f"Failed to get supported models: {response.text}")

    def finetune(
        self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None
    ) -> None:
        """Fine-tuning is not offered by this handler."""
        raise NotImplementedError("Fine-tuning is not supported for TogetherAI engine")