Coverage for mindsdb / integrations / handlers / groq_handler / groq_handler.py: 0%
62 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import pandas as pd
3import openai
4from openai import OpenAI, NotFoundError, AuthenticationError
5from typing import Dict, Optional
6from mindsdb.integrations.handlers.openai_handler import Handler as OpenAIHandler
7from mindsdb.integrations.utilities.handler_utils import get_api_key
8from mindsdb.integrations.handlers.groq_handler.settings import groq_handler_config
9from mindsdb.utilities import log
# Logger scoped to this module's import path, obtained via mindsdb's logging utilities.
logger = log.getLogger(__name__)
class GroqHandler(OpenAIHandler):
    """
    This handler handles connection to the Groq.

    Groq is accessed through the OpenAI client (see the ``OpenAI`` client usage
    below), so this class reuses ``OpenAIHandler`` and only overrides the base
    URL, defaults, credential validation and the model-catalogue lookup.
    """

    name = "groq"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the inherited OpenAI machinery at Groq's endpoint and defaults.
        self.api_base = groq_handler_config.BASE_URL
        self.default_model = groq_handler_config.DEFAULT_MODEL
        self.default_mode = groq_handler_config.DEFAULT_MODE
        self.supported_modes = groq_handler_config.SUPPORTED_MODES

    @staticmethod
    def _resolve_api_base(connection_args: Optional[Dict] = None, model_args: Optional[Dict] = None) -> str:
        """
        Pick the Groq API base URL to talk to.

        Precedence: ``api_base`` from the engine connection args, then ``api_base``
        from the model args, then the ``GROQ_BASE`` environment variable, and
        finally the configured default ``groq_handler_config.BASE_URL``.

        Args:
            connection_args (dict, optional): Engine connection arguments.
            model_args (dict, optional): Model ``USING`` arguments.

        Returns:
            str: The base URL to use.
        """
        for source in (connection_args, model_args):
            # Empty/None values fall through, matching the original `a or b or c` chain.
            if source and source.get("api_base"):
                return source["api_base"]
        return os.environ.get("GROQ_BASE", groq_handler_config.BASE_URL)

    @staticmethod
    def _check_client_connection(client: OpenAI):
        """
        Check the Groq engine client connection by listing models.

        Args:
            client (OpenAI): OpenAI client configured with the Groq API credentials.

        Raises:
            Exception: If the client connection (API key) is invalid. The underlying
                ``AuthenticationError`` is chained as the cause.

        Returns:
            None
        """
        try:
            client.models.list()
        except NotFoundError:
            # A missing models endpoint is deliberately tolerated: it does not by
            # itself indicate bad credentials (best-effort check).
            pass
        except AuthenticationError as e:
            if isinstance(e.body, dict) and e.body.get("code") == "invalid_api_key":
                raise Exception("Invalid api key") from e
            raise Exception(f"Something went wrong: {e}") from e

    def create_engine(self, connection_args):
        """
        Validate the Groq API credentials on engine creation.

        Validation only happens when a ``groq_api_key`` is supplied; otherwise the
        engine is created without contacting the API.

        Args:
            connection_args (dict): Connection arguments (keys are matched case-insensitively).

        Raises:
            Exception: If a supplied API key is rejected by the Groq API.

        Returns:
            None
        """
        connection_args = {k.lower(): v for k, v in connection_args.items()}
        api_key = connection_args.get("groq_api_key")
        if api_key is None:
            # No key provided at engine level; nothing to validate here.
            return
        org = connection_args.get("api_organization")
        api_base = self._resolve_api_base(connection_args)
        client = self._get_client(api_key=api_key, base_url=api_base, org=org)
        GroqHandler._check_client_connection(client)

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """
        Validate the Groq API credentials on model creation.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments; must contain a ``using`` entry.
            kwargs (dict): Handler keyword arguments; must contain ``handler_storage``.

        Raises:
            Exception: If no USING clause is given (including when ``args`` is None)
                or the handler is not configured with valid API credentials.

        Returns:
            None
        """
        # Guard against the default `args=None`, which previously raised TypeError.
        if not args or "using" not in args:
            raise Exception("Groq engine requires a USING clause! Refer to its documentation for more details.")
        args = args["using"]

        engine_storage = kwargs["handler_storage"]
        connection_args = engine_storage.get_connection_args()
        api_key = get_api_key("groq", args, engine_storage=engine_storage)
        api_base = GroqHandler._resolve_api_base(connection_args, args)
        org = args.get("api_organization")
        client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org)
        GroqHandler._check_client_connection(client)

    @staticmethod
    def is_chat_model(model_name):
        """
        All Groq models use the chat completions endpoint, hence every model is a chat model
        """
        return True

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Run predictions through the Groq engine.

        Refreshes ``self.chat_completion_models`` from the Groq model catalogue and
        then delegates to ``OpenAIHandler.predict``.

        Args:
            df (pd.DataFrame): Input data.
            args (dict): Handler arguments.

        Returns:
            pd.DataFrame: Predicted data
        """
        api_key = get_api_key("groq", args, self.engine_storage)
        # Resolve the base URL the same way engine/model validation does, so a
        # custom `api_base` or GROQ_BASE is honoured here as well.
        api_base = self._resolve_api_base(self.engine_storage.get_connection_args())
        supported_models = self._get_supported_models(api_key, api_base)
        self.chat_completion_models = [model.id for model in supported_models]
        return super().predict(df, args)

    @staticmethod
    def _get_supported_models(api_key, base_url, org=None):
        """
        Get the list of supported models for the Groq engine.

        Args:
            api_key (str): API key.
            base_url (str): Base URL.
            org (str): Organization name.

        Returns:
            The result of ``client.models.list()``: an iterable of model objects,
            each exposing an ``id`` attribute.
        """
        client = openai.OpenAI(api_key=api_key, base_url=base_url, organization=org)
        return client.models.list()

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None):
        """Fine-tuning is not available for Groq; always raises NotImplementedError."""
        raise NotImplementedError("Fine-tuning is not supported for Groq AI engine.")