Coverage for mindsdb / integrations / handlers / groq_handler / groq_handler.py: 0%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import pandas as pd 

3import openai 

4from openai import OpenAI, NotFoundError, AuthenticationError 

5from typing import Dict, Optional 

6from mindsdb.integrations.handlers.openai_handler import Handler as OpenAIHandler 

7from mindsdb.integrations.utilities.handler_utils import get_api_key 

8from mindsdb.integrations.handlers.groq_handler.settings import groq_handler_config 

9from mindsdb.utilities import log 

10 

# Module-level logger for this handler module (not referenced by the class code visible here).
logger = log.getLogger(__name__)

12 

13 

class GroqHandler(OpenAIHandler):
    """
    This handler handles connection to Groq.

    Groq is accessed through the OpenAI client pointed at a Groq base URL, so this
    handler reuses the OpenAI handler and only overrides the endpoint, the default
    model/mode, credential validation and model discovery.
    """

    name = "groq"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Redirect the inherited OpenAI logic to Groq's endpoint and defaults.
        self.api_base = groq_handler_config.BASE_URL
        self.default_model = groq_handler_config.DEFAULT_MODEL
        self.default_mode = groq_handler_config.DEFAULT_MODE
        self.supported_modes = groq_handler_config.SUPPORTED_MODES

    @staticmethod
    def _check_client_connection(client: OpenAI) -> None:
        """
        Check the Groq engine client connection by listing models.

        Args:
            client (OpenAI): OpenAI client configured with the Groq API credentials.

        Raises:
            Exception: If the API rejects the credentials with an authentication
                error ("Invalid api key" when the error code is ``invalid_api_key``).
                The original ``AuthenticationError`` is chained as the cause.
                Other errors from ``client.models.list()`` (except ``NotFoundError``)
                propagate unchanged.

        Returns:
            None
        """
        try:
            client.models.list()
        except NotFoundError:
            # A missing models endpoint is not treated as a credential failure.
            pass
        except AuthenticationError as e:
            if isinstance(e.body, dict) and e.body.get("code") == "invalid_api_key":
                raise Exception("Invalid api key") from e
            raise Exception(f"Something went wrong: {e}") from e

    def create_engine(self, connection_args: Dict) -> None:
        """
        Validate the Groq API credentials on engine creation.

        Validation is only performed when ``groq_api_key`` is present in the
        connection arguments (keys are matched case-insensitively); otherwise the
        engine is created without contacting the API.

        Args:
            connection_args (dict): Connection arguments.

        Raises:
            Exception: If a supplied API key is rejected by the Groq API.

        Returns:
            None
        """
        connection_args = {k.lower(): v for k, v in connection_args.items()}
        api_key = connection_args.get("groq_api_key")
        if api_key is None:
            return

        org = connection_args.get("api_organization")
        # Explicit api_base wins, then the GROQ_BASE env var, then the configured default.
        api_base = connection_args.get("api_base") or os.environ.get("GROQ_BASE", groq_handler_config.BASE_URL)
        client = self._get_client(api_key=api_key, base_url=api_base, org=org)
        GroqHandler._check_client_connection(client)

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """
        Validate the Groq API credentials on model creation.

        Args:
            target (str): Target column, not required for LLMs.
            args (dict): Handler arguments; must contain a ``using`` entry.
            kwargs (dict): Handler keyword arguments; must contain ``handler_storage``.

        Raises:
            Exception: If no USING clause was given (including ``args`` being
                ``None``), or if the API rejects the resolved credentials.

        Returns:
            None
        """
        # ``args`` defaults to None, so guard before the membership test; a missing
        # args mapping means there is no USING clause either.
        if not args or "using" not in args:
            raise Exception("Groq engine requires a USING clause! Refer to its documentation for more details.")
        args = args["using"]

        engine_storage = kwargs["handler_storage"]
        connection_args = engine_storage.get_connection_args()
        api_key = get_api_key("groq", args, engine_storage=engine_storage)
        # Engine-level api_base takes precedence over the model's USING api_base.
        api_base = (
            connection_args.get("api_base")
            or args.get("api_base")
            or os.environ.get("GROQ_BASE", groq_handler_config.BASE_URL)
        )
        org = args.get("api_organization")
        client = OpenAIHandler._get_client(api_key=api_key, base_url=api_base, org=org)
        GroqHandler._check_client_connection(client)

    @staticmethod
    def is_chat_model(model_name):
        """
        All Groq models use the chat completions endpoint, hence every model is a chat model.
        """
        return True

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Run the Groq engine on the input data.

        Refreshes ``self.chat_completion_models`` with the model ids reported by the
        Groq API, then delegates the actual prediction to the OpenAI handler.

        Args:
            df (pd.DataFrame): Input data.
            args (dict): Handler arguments.

        Returns:
            pd.DataFrame: Predicted data
        """
        # Use an empty mapping for the key lookup when no args were given, instead
        # of handing None to get_api_key (presumably expects a dict — TODO confirm).
        api_key = get_api_key("groq", args or {}, self.engine_storage)
        # NOTE(review): self.api_base is always groq_handler_config.BASE_URL here, so
        # a custom api_base / GROQ_BASE honoured at engine/model creation is ignored
        # when listing models — confirm whether that is intended.
        supported_models = self._get_supported_models(api_key, self.api_base)
        self.chat_completion_models = [model.id for model in supported_models]
        return super().predict(df, args)

    @staticmethod
    def _get_supported_models(api_key, base_url, org=None):
        """
        Get the list of supported models for the Groq engine.

        Args:
            api_key (str): API key.
            base_url (str): Base URL.
            org (str): Organization name.

        Returns:
            The result of ``client.models.list()`` — an iterable of model objects
            exposing an ``id`` attribute.
        """
        client = openai.OpenAI(api_key=api_key, base_url=base_url, organization=org)
        return client.models.list()

    def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None):
        """
        Fine-tuning is not available for this engine.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Fine-tuning is not supported for Groq AI engine.")