Coverage for mindsdb / integrations / handlers / bedrock_handler / settings.py: 0%
92 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import textwrap
2from pydantic_settings import BaseSettings
3from botocore.exceptions import ClientError
4from typing import Text, List, Dict, Optional, Any, ClassVar
5from pydantic import BaseModel, Field, model_validator, field_validator
7from mindsdb.integrations.handlers.bedrock_handler.utilities import create_amazon_bedrock_client
8from mindsdb.integrations.utilities.handlers.validation_utilities import ParameterValidationUtilities
class AmazonBedrockHandlerSettings(BaseSettings):
    """
    Settings for the Amazon Bedrock handler.

    Holds the handler-wide constants: the default operating mode, the set of
    supported modes, and the default Bedrock model ID used for text modes.

    Attributes
    ----------
    DEFAULT_MODE : Text
        The default mode for the handler.

    SUPPORTED_MODES : List
        List of supported modes for the handler.

    DEFAULT_TEXT_MODEL_ID : Text
        The default model ID to use for text generation. This will be the default model ID for the default and conversational modes.
    """
    # Modes.
    # TODO: Add other modes.
    DEFAULT_MODE: ClassVar[Text] = 'default'
    SUPPORTED_MODES: ClassVar[List] = ['default', 'conversational']

    # TODO: Set the default model ID for other modes.
    # Model IDs.
    DEFAULT_TEXT_MODEL_ID: ClassVar[Text] = 'amazon.titan-text-premier-v1:0'
class AmazonBedrockHandlerEngineConfig(BaseModel):
    """
    Configuration model for engines created via the Amazon Bedrock handler.

    Validation happens in two stages: a "before" validator that rejects
    misspelled parameter names, then an "after" validator that checks the
    credentials actually work against Amazon Bedrock.

    Attributes
    ----------
    aws_access_key_id : Text
        The AWS access key ID.

    aws_secret_access_key : Text
        The AWS secret access key.

    region_name : Text
        The AWS region name.

    aws_session_token : Text, Optional
        The AWS session token. Optional, but required for temporary security credentials.
    """
    aws_access_key_id: Text
    aws_secret_access_key: Text
    region_name: Text
    aws_session_token: Optional[Text] = None

    class Config:
        # Reject any parameter not declared above so unknown keys fail fast.
        extra = "forbid"

    @model_validator(mode="before")
    @classmethod
    def check_if_params_contain_typos(cls, values: Any) -> Any:
        """
        Checks if there are any typos in the parameters.

        Args:
            values (Any): The parameters provided when creating an engine via the Amazon Bedrock handler.

        Raises:
            ValueError: If there are any typos in the parameters.
        """
        ParameterValidationUtilities.validate_parameter_spelling(cls, values)

        return values

    @model_validator(mode="after")
    @classmethod
    def check_access_to_amazon_bedrock(cls, model: BaseModel) -> BaseModel:
        """
        Checks if the AWS credentials provided are valid and Amazon Bedrock is accessible.

        Args:
            model (BaseModel): The parameters provided when creating an engine via the Amazon Bedrock handler.

        Raises:
            ValueError: If the AWS credentials are invalid or Amazon Bedrock is not accessible.
        """
        # Build a low-level 'bedrock' client with the supplied credentials.
        bedrock_client = create_amazon_bedrock_client(
            "bedrock",
            model.aws_access_key_id,
            model.aws_secret_access_key,
            model.region_name,
            model.aws_session_token
        )

        try:
            # Read-only call used purely to prove the credentials/region work.
            bedrock_client.list_foundation_models()
        except ClientError as e:
            raise ValueError(f"Invalid Amazon Bedrock credentials: {e}!")

        return model
class AmazonBedrockHandlerModelConfig(BaseModel):
    """
    Configuration model for models created via the Amazon Bedrock handler.

    Attributes
    ----------
    id : Text
        The ID of the model in Amazon Bedrock. Defaults to the mode's default model ID when omitted.

    mode : Optional[Text]
        The mode to run the handler model in. The default mode and the supported modes are defined in the AmazonBedrockHandlerSettings class.

    prompt_template : Optional[Text]
        The base template for prompts with placeholders.

    question_column : Optional[Text]
        The column name for questions to be asked.

    context_column : Optional[Text]
        The column name for context to be provided with the questions.

    temperature : Optional[float]
        The setting for the randomness in the responses generated by the model.

    top_p : Optional[float]
        The setting for the probability of the tokens in the responses generated by the model.

    max_tokens : Optional[int]
        The maximum number of tokens to generate in the responses.

    stop : Optional[List[Text]]
        The list of sequences to stop the generation of tokens in the responses.

    connection_args : Dict
        The connection arguments required to connect to Amazon Bedrock. These are the AWS credentials provided when creating the engine.
    """
    # User-provided handler model parameters: specific to the MindsDB handler for Amazon Bedrock.
    id: Text = Field(None)
    mode: Optional[Text] = Field(AmazonBedrockHandlerSettings.DEFAULT_MODE)
    prompt_template: Optional[Text] = Field(None)
    question_column: Optional[Text] = Field(None)
    context_column: Optional[Text] = Field(None)

    # Amazon Bedrock model parameters: forwarded to the Bedrock inference API.
    # `bedrock_model_param_name` records the camelCase key Bedrock expects (see model_dump).
    temperature: Optional[float] = Field(None, bedrock_model_param=True, bedrock_model_param_name='temperature')
    top_p: Optional[float] = Field(None, bedrock_model_param=True, bedrock_model_param_name='topP')
    max_tokens: Optional[int] = Field(None, bedrock_model_param=True, bedrock_model_param_name='maxTokens')
    stop: Optional[List[Text]] = Field(None, bedrock_model_param=True, bedrock_model_param_name='stopSequences')

    # System-provided handler model parameters: injected by the system, never serialized.
    connection_args: Dict = Field(None, exclude=True)

    class Config:
        # Reject any parameter not declared above so unknown keys fail fast.
        extra = "forbid"

    @model_validator(mode="before")
    @classmethod
    def check_if_params_contain_typos(cls, values: Any) -> Any:
        """
        Checks if there are any typos in the parameters.

        Args:
            values (Any): The parameters provided when creating a model via the Amazon Bedrock handler.

        Raises:
            ValueError: If there are any typos in the parameters.
        """
        ParameterValidationUtilities.validate_parameter_spelling(cls, values)

        return values

    @field_validator("mode")
    @classmethod
    def check_if_mode_is_supported(cls, mode: Text) -> Text:
        """
        Checks if the mode provided is supported.

        Args:
            mode (Text): The mode to run the handler model in.

        Raises:
            ValueError: If the mode provided is not supported.
        """
        if mode not in AmazonBedrockHandlerSettings.SUPPORTED_MODES:
            # ', '.join keeps the list of modes readable in the error message.
            raise ValueError(f"Mode {mode} is not supported. The supported modes are {', '.join(AmazonBedrockHandlerSettings.SUPPORTED_MODES)}!")

        return mode

    @model_validator(mode="after")
    @classmethod
    def check_if_model_id_is_valid_and_correct_for_mode(cls, model: BaseModel) -> BaseModel:
        """
        Checks if the model ID and the parameters provided for the model are valid.
        If a model ID is not provided, the default model ID for that mode will be used.

        Args:
            model (BaseModel): The parameters provided when creating a model via the Amazon Bedrock handler.

        Raises:
            ValueError: If the model ID provided is invalid or the parameters provided are invalid for the chosen model.
        """
        # TODO: Set the default model ID for other modes.
        if model.id is None:
            if model.mode in ['default', 'conversational']:
                model.id = AmazonBedrockHandlerSettings.DEFAULT_TEXT_MODEL_ID

        bedrock_client = create_amazon_bedrock_client(
            "bedrock",
            **model.connection_args
        )

        try:
            # Check if the model ID is valid and accessible.
            response = bedrock_client.get_foundation_model(modelIdentifier=model.id)
        except ClientError as e:
            raise ValueError(f"Invalid Amazon Bedrock model ID: {e}!")

        # Check if the model is suitable for the mode provided.
        if model.mode in ['default', 'conversational']:
            if 'TEXT' not in response['modelDetails']['outputModalities']:
                raise ValueError(f"The models used for the {model.mode} should support text generation!")

        return model

    @model_validator(mode="after")
    @classmethod
    def check_if_params_are_valid_for_mode(cls, model: BaseModel) -> BaseModel:
        """
        Checks if the parameters required for the chosen mode provided are valid.

        Args:
            model (BaseModel): The parameters provided when creating a model via the Amazon Bedrock handler.

        Raises:
            ValueError: If the parameters provided are invalid for the mode provided.
        """
        # If the mode is default, one of the following need to be provided:
        # 1. prompt_template.
        # 2. question_column with an optional context_column.
        # TODO: Find the other possible parameters/combinations for the default mode.
        if model.mode in ['default', 'conversational']:
            error_message = textwrap.dedent(
                f"""\
                For the {model.mode} mode, one of the following need to be provided:
                1) A `prompt_template`
                2) A `question_column` and an optional `context_column`
                """
            )
            # Exactly one of prompt_template / question_column must be set.
            if model.prompt_template is None and model.question_column is None:
                raise ValueError(error_message)

            if model.prompt_template is not None and model.question_column is not None:
                raise ValueError(error_message)

            # A context_column only makes sense alongside a question_column.
            if model.context_column is not None and model.question_column is None:
                raise ValueError(error_message)

        # TODO: Add validations for other modes.

        return model

    def model_dump(self) -> Dict:
        """
        Dumps the model configuration to a dictionary.

        Bedrock inference parameters (fields flagged with `bedrock_model_param`)
        are collected under `inference_config` keyed by their Bedrock camelCase
        names; every other (non-excluded) field is emitted at the top level.

        Returns:
            Dict: The configuration of the model.
        """
        # Compute the serialization schema once instead of once per lookup.
        properties = self.model_json_schema(mode='serialization')['properties']

        inference_config = {}
        handler_model_params = {}
        for key, prop in properties.items():
            if prop.get("bedrock_model_param"):
                value = getattr(self, key)
                # Only forward parameters the user actually set.
                if value is not None:
                    inference_config[prop["bedrock_model_param_name"]] = value
            else:
                handler_model_params[key] = getattr(self, key)

        return {
            "inference_config": inference_config,
            **handler_model_params
        }