Coverage for mindsdb/integrations/handlers/bedrock_handler/bedrock_handler.py: 0%

117 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import numpy as np 

2import pandas as pd 

3from typing import Text, Tuple, Dict, List, Optional, Any 

4 

5from mindsdb.utilities import log 

6 

7from mindsdb.integrations.libs.base import BaseMLEngine 

8from mindsdb.integrations.libs.llm.utils import get_completed_prompts 

9from mindsdb.integrations.libs.api_handler_exceptions import MissingConnectionParams 

10from mindsdb.integrations.handlers.bedrock_handler.utilities import create_amazon_bedrock_client 

11from mindsdb.integrations.handlers.bedrock_handler.settings import AmazonBedrockHandlerEngineConfig, AmazonBedrockHandlerModelConfig 

12 

13 

# Module-level logger named after this module, per the MindsDB logging convention.
logger = log.getLogger(__name__)

15 

16 

class AmazonBedrockHandler(BaseMLEngine):
    """
    This handler handles connection and inference with the Amazon Bedrock API.
    """

    name = 'bedrock'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bedrock models produce generated text rather than conventional tabular predictions.
        self.generative = True

    def create_engine(self, connection_args: Dict) -> None:
        """
        Validates the AWS credentials provided on creation of an engine.

        Args:
            connection_args (Dict): The parameters of the engine.

        Raises:
            Exception: If the handler is not configured with valid API credentials.
        """
        # Normalize keys to lower case so validation is case-insensitive.
        connection_args = {k.lower(): v for k, v in connection_args.items()}
        # Instantiating the Pydantic settings model runs validation; it raises on invalid credentials.
        AmazonBedrockHandlerEngineConfig(**connection_args)

    def create(self, target: Text, args: Dict = None, **kwargs: Any) -> None:
        """
        Creates a model by validating the model configuration and saving it to the storage.

        Args:
            target (Text): The target column name.
            args (Dict): The parameters of the model. Must contain a 'using' key.
            kwargs (Any): Other keyword arguments.

        Raises:
            MissingConnectionParams: If no USING clause was provided.
            Exception: If the model is not configured with valid parameters.

        Returns:
            None
        """
        # Guard against args being None as well as a missing USING clause: the bare
        # membership test `'using' not in args` would raise a TypeError on None.
        if not args or 'using' not in args:
            raise MissingConnectionParams("Amazon Bedrock engine requires a USING clause! Refer to its documentation for more details.")

        model_args = args['using']
        # Replace 'model_id' with 'id' to match the Amazon Bedrock handler model configuration.
        # This is done to avoid the Pydantic warning regarding conflicts with the protected 'model_' namespace.
        if 'model_id' in model_args:
            model_args['id'] = model_args.pop('model_id')

        handler_model_config = AmazonBedrockHandlerModelConfig(**model_args, connection_args=self.engine_storage.get_connection_args())

        # Save the validated model configuration to the storage.
        handler_model_params = handler_model_config.model_dump()
        logger.info(f"Saving model configuration to storage: {handler_model_params}")

        args['target'] = target
        args['handler_model_params'] = handler_model_params
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Makes predictions using a model by invoking the Amazon Bedrock API.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters passed when making predictions.

        Raises:
            ValueError: If the input data does not match the configuration of the model,
                        or if the stored mode is not supported.

        Returns:
            pd.DataFrame: The input data with the predicted values in a new column.
        """
        # Prediction is driven by the args saved at create time, not the call-time args.
        args = self.model_storage.json_get('args')
        handler_model_params = args['handler_model_params']
        mode = handler_model_params['mode']
        model_id = handler_model_params['id']
        inference_config = handler_model_params.get('inference_config')
        target = args['target']

        if mode == 'default':
            prompts, empty_prompt_ids = self._prepare_data_for_default_mode(df, args)
            predictions = self._predict_for_default_mode(model_id, prompts, inference_config)

            # Re-insert None at the positions of rows whose prompts were empty so that
            # the predictions line up with the input rows again.
            for i in sorted(empty_prompt_ids):
                predictions.insert(i, None)

        elif mode == 'conversational':
            prompt, total_questions = self._prepare_data_for_conversational_mode(df, args)
            prediction = self._predict_for_conversational_mode(model_id, prompt, inference_config)

            # Only the final question receives an answer; pad the earlier rows with None.
            predictions = [None] * total_questions
            predictions[-1] = prediction

        else:
            # Previously an unrecognized mode fell through to an unbound-variable NameError.
            raise ValueError(f"Unsupported mode '{mode}'! Supported modes are 'default' and 'conversational'.")

        pred_df = pd.DataFrame(predictions, columns=[target])
        return pred_df

    def _prepare_data_for_default_mode(self, df: pd.DataFrame, args: Dict) -> Tuple[List[Dict], List[int]]:
        """
        Prepares the input data for the default mode of the Amazon Bedrock handler.
        A separate prompt is prepared for each question.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters of the model.

        Raises:
            ValueError: If the model is configured with neither a question column nor a prompt template.

        Returns:
            Tuple[List[Dict], List[int]]: The prepared prompts for invoking the Amazon Bedrock API
            (the model will be invoked once per prompt) and the IDs of the empty prompts.
        """
        handler_model_params = args['handler_model_params']
        question_column = handler_model_params.get('question_column')
        context_column = handler_model_params.get('context_column')
        prompt_template = handler_model_params.get('prompt_template')

        if question_column is not None:
            questions, empty_prompt_ids = self._prepare_data_with_question_and_context_columns(
                df,
                question_column,
                context_column
            )

        elif prompt_template is not None:
            questions, empty_prompt_ids = self._prepare_data_with_prompt_template(df, prompt_template)

        else:
            # Previously this fell through to an unbound-variable NameError.
            raise ValueError("The model must be configured with either a question column or a prompt template!")

        # Drop the empty questions and wrap the rest as Converse API user messages.
        questions = [question for i, question in enumerate(questions) if i not in empty_prompt_ids]
        prompts = [{"role": "user", "content": [{"text": question}]} for question in questions]

        return prompts, empty_prompt_ids

    def _prepare_data_for_conversational_mode(self, df: pd.DataFrame, args: Dict) -> Tuple[List[Dict], int]:
        """
        Prepares the input data for the conversational mode of the Amazon Bedrock handler.
        A single prompt is prepared for all the questions.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters of the model.

        Raises:
            ValueError: If the model is configured with neither a question column nor a prompt template.

        Returns:
            Tuple[List[Dict], int]: The prepared prompt for invoking the Amazon Bedrock API and the total number of questions.
            The model will be invoked once using this prompt which contains all the questions.
            The total number of questions is used to produce the final list of predictions.
        """
        handler_model_params = args['handler_model_params']
        question_column = handler_model_params.get('question_column')
        context_column = handler_model_params.get('context_column')
        prompt_template = handler_model_params.get('prompt_template')

        if question_column is not None:
            questions, empty_prompt_ids = self._prepare_data_with_question_and_context_columns(
                df,
                question_column,
                context_column
            )

        # `elif` (not `if`) so that the precedence between question column and prompt
        # template matches the default mode above.
        elif prompt_template is not None:
            questions, empty_prompt_ids = self._prepare_data_with_prompt_template(df, prompt_template)

        else:
            # Previously this fell through to an unbound-variable NameError.
            raise ValueError("The model must be configured with either a question column or a prompt template!")

        # Drop the empty questions and pack all remaining questions into a single user message.
        questions = [question for i, question in enumerate(questions) if i not in empty_prompt_ids]
        prompt = [{"role": "user", "content": [{"text": question} for question in questions]}]

        # Get the total number of questions; including the empty ones.
        total_questions = len(df)

        return prompt, total_questions

    def _prepare_data_with_question_and_context_columns(self, df: pd.DataFrame, question_column: Text, context_column: Text = None) -> Tuple[List[Text], List[int]]:
        """
        Prepares the input data with question and context columns.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            question_column (Text): The column containing the questions.
            context_column (Text): The column containing the context.

        Raises:
            ValueError: If the configured columns are missing from the dataframe.

        Returns:
            Tuple[List[Text], List[int]]: The questions to build the prompts for invoking the Amazon Bedrock API and the empty prompt IDs.
        """
        if question_column not in df.columns:
            raise ValueError(f"Column {question_column} not found in the dataframe!")

        if context_column and context_column not in df.columns:
            raise ValueError(f"Column {context_column} not found in the dataframe!")

        if context_column:
            # A row counts as empty only when both the context and the question are NaN.
            empty_prompt_ids = np.where(
                df[[context_column, question_column]]
                .isna()
                .all(axis=1)
                .values
            )[0]
            contexts = list(df[context_column].apply(lambda x: str(x)))
            questions_without_context = list(df[question_column].apply(lambda x: str(x)))

            questions = [
                f'Context: {c}\nQuestion: {q}\nAnswer: '
                for c, q in zip(contexts, questions_without_context)
            ]

        else:
            questions = list(df[question_column].apply(lambda x: str(x)))
            empty_prompt_ids = np.where(
                df[[question_column]].isna().all(axis=1).values
            )[0]

        return questions, empty_prompt_ids

    def _prepare_data_with_prompt_template(self, df: pd.DataFrame, prompt_template: Text) -> Tuple[List[Text], List[int]]:
        """
        Prepares the input data with a prompt template.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            prompt_template (Text): The base prompt template to use.

        Returns:
            Tuple[List[Text], List[int]]: The questions to build the prompts for invoking the Amazon Bedrock API and the empty prompt IDs.
        """
        questions, empty_prompt_ids = get_completed_prompts(prompt_template, df)

        return questions, empty_prompt_ids

    def _predict_for_default_mode(self, model_id: Text, prompts: List[Dict], inference_config: Dict) -> List[Text]:
        """
        Makes predictions for the default mode of the Amazon Bedrock handler using the prepared prompts.

        Args:
            model_id (Text): The ID of the model in Amazon Bedrock.
            prompts (List[Dict]): The prepared prompts for invoking the Amazon Bedrock API.
            inference_config (Dict): The inference configuration supported by the Amazon Bedrock API.

        Returns:
            List[Text]: The predictions made by the Amazon Bedrock API, one per prompt.
        """
        predictions = []
        # Create the runtime client once and reuse it for every prompt.
        bedrock_runtime_client = create_amazon_bedrock_client(
            'bedrock-runtime',
            **self.engine_storage.get_connection_args()
        )

        for prompt in prompts:
            response = bedrock_runtime_client.converse(
                modelId=model_id,
                messages=[prompt],
                inferenceConfig=inference_config
            )
            predictions.append(
                response["output"]["message"]["content"][0]["text"]
            )

        return predictions

    def _predict_for_conversational_mode(self, model_id: Text, prompt: List[Dict], inference_config: Dict) -> Text:
        """
        Makes a prediction for the conversational mode of the Amazon Bedrock handler using the prepared prompt.

        Args:
            model_id (Text): The ID of the model in Amazon Bedrock.
            prompt (List[Dict]): The prepared prompt (a single user message containing all questions) for invoking the Amazon Bedrock API.
            inference_config (Dict): Inference configuration supported by the Amazon Bedrock API.

        Returns:
            Text: The prediction made by the Amazon Bedrock API.
        """
        bedrock_runtime_client = create_amazon_bedrock_client(
            'bedrock-runtime',
            **self.engine_storage.get_connection_args()
        )

        response = bedrock_runtime_client.converse(
            modelId=model_id,
            messages=prompt,
            inferenceConfig=inference_config
        )

        return response["output"]["message"]["content"][0]["text"]

    def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
        """
        Get the metadata or arguments of a model.

        Args:
            attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'.

        Returns:
            pd.DataFrame: Model metadata or model arguments; if no attribute is given,
            a dataframe listing the describable attributes.
        """
        args = self.model_storage.json_get('args')

        if attribute == 'args':
            # The internal handler configuration is not part of the user-facing args.
            args.pop('handler_model_params', None)
            return pd.DataFrame(args.items(), columns=['key', 'value'])

        elif attribute == 'metadata':
            model_id = args['handler_model_params']['id']
            try:
                bedrock_client = create_amazon_bedrock_client(
                    'bedrock',
                    **self.engine_storage.get_connection_args()
                )
                meta = bedrock_client.get_foundation_model(modelIdentifier=model_id)['modelDetails']
            except Exception as e:
                # Surface the API error as metadata rather than failing the DESCRIBE.
                meta = {'error': str(e)}
            return pd.DataFrame(dict(meta).items(), columns=['key', 'value'])

        else:
            tables = ['args', 'metadata']
            return pd.DataFrame(tables, columns=['tables'])