Coverage for mindsdb / integrations / handlers / bedrock_handler / settings.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import textwrap 

2from pydantic_settings import BaseSettings 

3from botocore.exceptions import ClientError 

4from typing import Text, List, Dict, Optional, Any, ClassVar 

5from pydantic import BaseModel, Field, model_validator, field_validator 

6 

7from mindsdb.integrations.handlers.bedrock_handler.utilities import create_amazon_bedrock_client 

8from mindsdb.integrations.utilities.handlers.validation_utilities import ParameterValidationUtilities 

9 

10 

class AmazonBedrockHandlerSettings(BaseSettings):
    """
    Static defaults shared by the Amazon Bedrock handler.

    Every value below is declared as a ``ClassVar``, so it is a plain class-level
    constant read as ``AmazonBedrockHandlerSettings.<NAME>``.

    Attributes
    ----------
    SUPPORTED_MODES : List
        The modes the handler can run a model in.

    DEFAULT_MODE : Text
        The mode used when the user does not specify one.

    DEFAULT_TEXT_MODEL_ID : Text
        The model ID used for text generation when the user does not specify one.
        It is the default for both the default and the conversational modes.
    """
    # Model IDs.
    # TODO: Set the default model ID for other modes.
    DEFAULT_TEXT_MODEL_ID: ClassVar[Text] = 'amazon.titan-text-premier-v1:0'

    # Modes.
    # TODO: Add other modes.
    SUPPORTED_MODES: ClassVar[List] = ['default', 'conversational']
    DEFAULT_MODE: ClassVar[Text] = 'default'

34 

35 

class AmazonBedrockHandlerEngineConfig(BaseModel):
    """
    Configuration model for engines created via the Amazon Bedrock handler.

    Validation happens in two steps: the raw parameters are first checked for
    misspelled names, and once the fields are populated the credentials are
    exercised against Amazon Bedrock.

    Attributes
    ----------
    aws_access_key_id : Text
        The AWS access key ID.

    aws_secret_access_key : Text
        The AWS secret access key.

    region_name : Text
        The AWS region name.

    aws_session_token : Text, Optional
        The AWS session token. Optional, but required for temporary security credentials.
    """
    aws_access_key_id: Text
    aws_secret_access_key: Text
    region_name: Text
    aws_session_token: Optional[Text] = None

    class Config:
        # Reject any parameter that is not declared above.
        extra = "forbid"

    @model_validator(mode="before")
    @classmethod
    def check_if_params_contain_typos(cls, values: Any) -> Any:
        """
        Checks if there are any typos in the parameters.

        The check is delegated to `ParameterValidationUtilities.validate_parameter_spelling`.

        Args:
            values (Any): The parameters provided when creating an engine via the Amazon Bedrock handler.

        Returns:
            Any: The parameters, unchanged.

        Raises:
            ValueError: If there are any typos in the parameters (expected to be raised by the delegated helper).
        """
        ParameterValidationUtilities.validate_parameter_spelling(cls, values)

        return values

    @model_validator(mode="after")
    def check_access_to_amazon_bedrock(self) -> "AmazonBedrockHandlerEngineConfig":
        """
        Checks if the AWS credentials provided are valid and Amazon Bedrock is accessible.

        A "bedrock" client is created from the validated fields and used to list
        the foundation models; a `ClientError` from that call is treated as a
        credentials failure.

        Returns:
            AmazonBedrockHandlerEngineConfig: The validated configuration.

        Raises:
            ValueError: If the AWS credentials are invalid or Amazon Bedrock is not accessible.
        """
        # "After" model validators run on the constructed instance, so this is an
        # instance method rather than a classmethod.
        bedrock_client = create_amazon_bedrock_client(
            "bedrock",
            self.aws_access_key_id,
            self.aws_secret_access_key,
            self.region_name,
            self.aws_session_token
        )

        try:
            bedrock_client.list_foundation_models()
        except ClientError as e:
            # Chain the original error so the AWS failure stays visible in tracebacks.
            raise ValueError(f"Invalid Amazon Bedrock credentials: {e}!") from e

        return self

104 

105 

class AmazonBedrockHandlerModelConfig(BaseModel):
    """
    Configuration model for models created via the Amazon Bedrock handler.

    Attributes
    ----------
    id : Optional[Text]
        The ID of the model in Amazon Bedrock. When omitted, the default model ID for the chosen mode is used.

    mode : Optional[Text]
        The mode to run the handler model in. The default mode and the supported modes are defined in the AmazonBedrockHandlerSettings class.

    prompt_template : Optional[Text]
        The base template for prompts with placeholders.

    question_column : Optional[Text]
        The column name for questions to be asked.

    context_column : Optional[Text]
        The column name for context to be provided with the questions.

    temperature : Optional[float]
        The setting for the randomness in the responses generated by the model.

    top_p : Optional[float]
        The setting for the probability of the tokens in the responses generated by the model.

    max_tokens : Optional[int]
        The maximum number of tokens to generate in the responses.

    stop : Optional[List[Text]]
        The list of sequences to stop the generation of tokens in the responses.

    connection_args : Dict
        The connection arguments required to connect to Amazon Bedrock. These are the AWS credentials provided when creating the engine.
        They are unpacked as keyword arguments into `create_amazon_bedrock_client` and are excluded from serialization.
    """
    # User-provided Handler Model Parameters: These are parameters specific to the MindsDB handler for Amazon Bedrock provided by the user.
    id: Optional[Text] = Field(None)
    mode: Optional[Text] = Field(AmazonBedrockHandlerSettings.DEFAULT_MODE)
    prompt_template: Optional[Text] = Field(None)
    question_column: Optional[Text] = Field(None)
    context_column: Optional[Text] = Field(None)

    # Amazon Bedrock Model Parameters: These are parameters specific to the models in Amazon Bedrock. They are provided by the user.
    # The `json_schema_extra` entries surface in the JSON schema of each field, which is where `model_dump` reads them from:
    # `bedrock_model_param` marks the field as a Bedrock parameter and `bedrock_model_param_name` is the key it is dumped under.
    temperature: Optional[float] = Field(
        None,
        json_schema_extra={"bedrock_model_param": True, "bedrock_model_param_name": "temperature"},
    )
    top_p: Optional[float] = Field(
        None,
        json_schema_extra={"bedrock_model_param": True, "bedrock_model_param_name": "topP"},
    )
    max_tokens: Optional[int] = Field(
        None,
        json_schema_extra={"bedrock_model_param": True, "bedrock_model_param_name": "maxTokens"},
    )
    stop: Optional[List[Text]] = Field(
        None,
        json_schema_extra={"bedrock_model_param": True, "bedrock_model_param_name": "stopSequences"},
    )

    # System-provided Handler Model Parameters: These are parameters specific to the MindsDB handler for Amazon Bedrock provided by the system.
    connection_args: Dict = Field(None, exclude=True)

    class Config:
        # Reject any parameter that is not declared above.
        extra = "forbid"

    @model_validator(mode="before")
    @classmethod
    def check_if_params_contain_typos(cls, values: Any) -> Any:
        """
        Checks if there are any typos in the parameters.

        The check is delegated to `ParameterValidationUtilities.validate_parameter_spelling`.

        Args:
            values (Any): The parameters provided when creating a model via the Amazon Bedrock handler.

        Returns:
            Any: The parameters, unchanged.

        Raises:
            ValueError: If there are any typos in the parameters (expected to be raised by the delegated helper).
        """
        ParameterValidationUtilities.validate_parameter_spelling(cls, values)

        return values

    @field_validator("mode")
    @classmethod
    def check_if_mode_is_supported(cls, mode: Text) -> Text:
        """
        Checks if the mode provided is supported.

        Args:
            mode (Text): The mode to run the handler model in.

        Returns:
            Text: The mode, unchanged.

        Raises:
            ValueError: If the mode provided is not supported.
        """
        if mode not in AmazonBedrockHandlerSettings.SUPPORTED_MODES:
            # Separate the modes with ", " so the message lists them legibly.
            supported_modes = ', '.join(AmazonBedrockHandlerSettings.SUPPORTED_MODES)
            raise ValueError(f"Mode {mode} is not supported. The supported modes are {supported_modes}!")

        return mode

    @model_validator(mode="after")
    def check_if_model_id_is_valid_and_correct_for_mode(self) -> "AmazonBedrockHandlerModelConfig":
        """
        Checks if the model ID is valid and suitable for the mode provided.
        If a model ID is not provided, the default model ID for that mode will be used.

        Returns:
            AmazonBedrockHandlerModelConfig: The validated configuration, with `id` filled in when it was omitted.

        Raises:
            ValueError: If the model ID provided is invalid or the model is not suitable for the chosen mode.
        """
        # TODO: Set the default model ID for other modes.
        if self.id is None:
            if self.mode in ['default', 'conversational']:
                self.id = AmazonBedrockHandlerSettings.DEFAULT_TEXT_MODEL_ID

        bedrock_client = create_amazon_bedrock_client(
            "bedrock",
            **self.connection_args
        )

        try:
            # Check if the model ID is valid and accessible.
            response = bedrock_client.get_foundation_model(modelIdentifier=self.id)
        except ClientError as e:
            # Chain the original error so the AWS failure stays visible in tracebacks.
            raise ValueError(f"Invalid Amazon Bedrock model ID: {e}!") from e

        # Check if the model is suitable for the mode provided.
        if self.mode in ['default', 'conversational']:
            if 'TEXT' not in response['modelDetails']['outputModalities']:
                raise ValueError(f"The models used for the {self.mode} should support text generation!")

        return self

    @model_validator(mode="after")
    def check_if_params_are_valid_for_mode(self) -> "AmazonBedrockHandlerModelConfig":
        """
        Checks if the parameters required for the chosen mode are valid.

        Returns:
            AmazonBedrockHandlerModelConfig: The validated configuration.

        Raises:
            ValueError: If the parameters provided are invalid for the mode provided.
        """
        # If the mode is default or conversational, exactly one of the following needs to be provided:
        # 1. prompt_template.
        # 2. question_column with an optional context_column.
        # TODO: Find the other possible parameters/combinations for the default mode.
        if self.mode in ['default', 'conversational']:
            error_message = textwrap.dedent(
                f"""\
                For the {self.mode} mode, one of the following need to be provided:
                    1) A `prompt_template`
                    2) A `question_column` and an optional `context_column`
                """
            )
            # Neither option was provided.
            if self.prompt_template is None and self.question_column is None:
                raise ValueError(error_message)

            # Both options were provided; they are mutually exclusive.
            if self.prompt_template is not None and self.question_column is not None:
                raise ValueError(error_message)

            # A context column is only meaningful alongside a question column.
            if self.context_column is not None and self.question_column is None:
                raise ValueError(error_message)

        # TODO: Add validations for other modes.

        return self

    def model_dump(self) -> Dict:
        """
        Dumps the model configuration to a dictionary.

        Fields flagged as Amazon Bedrock model parameters are collected under the
        "inference_config" key, keyed by their Bedrock parameter name and omitted
        when unset. All other fields present in the serialization schema are
        emitted at the top level under their own names. Unlike
        `BaseModel.model_dump`, this override takes no arguments.

        Returns:
            Dict: The configuration of the model.
        """
        # Build the schema once; the serialization schema is the source of both the field names and the Bedrock flags.
        properties = self.model_json_schema(mode='serialization')['properties']

        inference_config = {}
        handler_model_params = {}
        for key, val in properties.items():
            value = getattr(self, key)
            if val.get("bedrock_model_param"):
                # Only forward the Bedrock parameters the user actually set.
                if value is not None:
                    inference_config[val.get("bedrock_model_param_name")] = value
            else:
                handler_model_params[key] = value

        return {
            "inference_config": inference_config,
            **handler_model_params
        }

287 }