Coverage for mindsdb / integrations / handlers / bedrock_handler / bedrock_handler.py: 0%
117 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import numpy as np
2import pandas as pd
3from typing import Text, Tuple, Dict, List, Optional, Any
5from mindsdb.utilities import log
7from mindsdb.integrations.libs.base import BaseMLEngine
8from mindsdb.integrations.libs.llm.utils import get_completed_prompts
9from mindsdb.integrations.libs.api_handler_exceptions import MissingConnectionParams
10from mindsdb.integrations.handlers.bedrock_handler.utilities import create_amazon_bedrock_client
11from mindsdb.integrations.handlers.bedrock_handler.settings import AmazonBedrockHandlerEngineConfig, AmazonBedrockHandlerModelConfig
# Module-level logger via mindsdb's logging wrapper; named after this module.
logger = log.getLogger(__name__)
class AmazonBedrockHandler(BaseMLEngine):
    """
    This handler handles connection and inference with the Amazon Bedrock API.
    """

    name = 'bedrock'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bedrock models are generative (text in, text out).
        self.generative = True

    def create_engine(self, connection_args: Dict) -> None:
        """
        Validates the AWS credentials provided on creation of an engine.

        Args:
            connection_args (Dict): The parameters of the engine.

        Raises:
            Exception: If the handler is not configured with valid API credentials.
        """
        # Normalize keys to lower case so validation is case-insensitive.
        connection_args = {k.lower(): v for k, v in connection_args.items()}
        # Instantiation runs the Pydantic validation; failures propagate to the caller.
        AmazonBedrockHandlerEngineConfig(**connection_args)

    def create(self, target: Text, args: Dict = None, **kwargs: Any) -> None:
        """
        Creates a model by validating the model configuration and saving it to the storage.

        Args:
            target (Text): The target column name.
            args (Dict): The parameters of the model.
            kwargs (Any): Other keyword arguments.

        Raises:
            MissingConnectionParams: If no USING clause was provided.
            Exception: If the model is not configured with valid parameters.

        Returns:
            None
        """
        # Guard against args being None as well as a missing USING clause; the original
        # membership test would raise TypeError on None instead of a helpful error.
        if not args or 'using' not in args:
            raise MissingConnectionParams("Amazon Bedrock engine requires a USING clause! Refer to its documentation for more details.")

        model_args = args['using']
        # Replace 'model_id' with 'id' to match the Amazon Bedrock handler model configuration.
        # This is done to avoid the Pydantic warning regarding conflicts with the protected 'model_' namespace.
        if 'model_id' in model_args:
            model_args['id'] = model_args.pop('model_id')

        handler_model_config = AmazonBedrockHandlerModelConfig(
            **model_args,
            connection_args=self.engine_storage.get_connection_args()
        )

        # Save the validated model configuration to the storage.
        handler_model_params = handler_model_config.model_dump()
        logger.info(f"Saving model configuration to storage: {handler_model_params}")

        args['target'] = target
        args['handler_model_params'] = handler_model_params
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Makes predictions using a model by invoking the Amazon Bedrock API.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters passed when making predictions.

        Raises:
            ValueError: If the input data does not match the configuration of the model,
                or if the stored mode is not supported.

        Returns:
            pd.DataFrame: The input data with the predicted values in a new column.
        """
        # NOTE: the args stored at model-creation time deliberately take precedence
        # over anything passed in at prediction time.
        args = self.model_storage.json_get('args')
        handler_model_params = args['handler_model_params']
        mode = handler_model_params['mode']
        model_id = handler_model_params['id']
        inference_config = handler_model_params.get('inference_config')
        target = args['target']

        if mode == 'default':
            prompts, empty_prompt_ids = self._prepare_data_for_default_mode(df, args)
            predictions = self._predict_for_default_mode(model_id, prompts, inference_config)

            # Re-insert None placeholders at the positions of rows whose prompt was empty,
            # so the output aligns row-for-row with the input DataFrame.
            for i in sorted(empty_prompt_ids):
                predictions.insert(i, None)

        elif mode == 'conversational':
            prompt, total_questions = self._prepare_data_for_conversational_mode(df, args)
            prediction = self._predict_for_conversational_mode(model_id, prompt, inference_config)

            # Only the final question receives an answer; earlier rows get None.
            predictions = [None] * total_questions
            predictions[-1] = prediction

        else:
            # Previously an unknown mode fell through to a NameError on 'predictions';
            # fail fast with a clear message instead.
            raise ValueError(f"Unsupported mode '{mode}' for the Amazon Bedrock handler!")

        return pd.DataFrame(predictions, columns=[target])

    def _prepare_data_for_default_mode(self, df: pd.DataFrame, args: Dict) -> Tuple[List[Dict], List[int]]:
        """
        Prepares the input data for the default mode of the Amazon Bedrock handler.
        A separate prompt is prepared for each question.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters of the model.

        Raises:
            ValueError: If neither a question column nor a prompt template is configured.

        Returns:
            Tuple[List[Dict], List[int]]: The prepared prompts for invoking the Amazon Bedrock API
            (the model will be invoked once per prompt) and the IDs of rows with empty prompts.
        """
        handler_model_params = args['handler_model_params']
        question_column = handler_model_params.get('question_column')
        context_column = handler_model_params.get('context_column')
        prompt_template = handler_model_params.get('prompt_template')

        if question_column is not None:
            questions, empty_prompt_ids = self._prepare_data_with_question_and_context_columns(
                df,
                question_column,
                context_column
            )

        elif prompt_template is not None:
            questions, empty_prompt_ids = self._prepare_data_with_prompt_template(df, prompt_template)

        else:
            # Previously this fell through to a NameError on 'questions'.
            raise ValueError("Either a question column or a prompt template is required to invoke the model!")

        # Drop empty questions and wrap the rest in the Converse API message format.
        questions = [question for i, question in enumerate(questions) if i not in empty_prompt_ids]
        prompts = [{"role": "user", "content": [{"text": question}]} for question in questions]

        return prompts, empty_prompt_ids

    def _prepare_data_for_conversational_mode(self, df: pd.DataFrame, args: Dict) -> Tuple[List[Dict], int]:
        """
        Prepares the input data for the conversational mode of the Amazon Bedrock handler.
        A single prompt is prepared for all the questions.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            args (Dict): The parameters of the model.

        Raises:
            ValueError: If neither a question column nor a prompt template is configured.

        Returns:
            Tuple[List[Dict], int]: The prepared prompt for invoking the Amazon Bedrock API and the total number of questions.
            The model will be invoked once using this prompt which contains all the questions.
            The total number of questions is used to produce the final list of predictions.
        """
        handler_model_params = args['handler_model_params']
        question_column = handler_model_params.get('question_column')
        context_column = handler_model_params.get('context_column')
        prompt_template = handler_model_params.get('prompt_template')

        # The original ran both branches unconditionally, so a configured prompt_template
        # overrode the question_column result; that precedence is preserved here without
        # doing the discarded work.
        if prompt_template is not None:
            questions, empty_prompt_ids = self._prepare_data_with_prompt_template(df, prompt_template)

        elif question_column is not None:
            questions, empty_prompt_ids = self._prepare_data_with_question_and_context_columns(
                df,
                question_column,
                context_column
            )

        else:
            # Previously this fell through to a NameError on 'questions'.
            raise ValueError("Either a question column or a prompt template is required to invoke the model!")

        # Drop empty questions and pack all remaining ones into a single user message.
        questions = [question for i, question in enumerate(questions) if i not in empty_prompt_ids]
        prompt = [{"role": "user", "content": [{"text": question} for question in questions]}]

        # Get the total number of questions; including the empty ones.
        total_questions = len(df)

        return prompt, total_questions

    def _prepare_data_with_question_and_context_columns(self, df: pd.DataFrame, question_column: Text, context_column: Text = None) -> Tuple[List[Text], List[int]]:
        """
        Prepares the input data with question and context columns.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            question_column (Text): The column containing the questions.
            context_column (Text): The column containing the context.

        Raises:
            ValueError: If the configured columns are missing from the dataframe.

        Returns:
            Tuple[List[Text], List[int]]: The questions to build the prompts for invoking the Amazon Bedrock API and the empty prompt IDs.
        """
        if question_column not in df.columns:
            raise ValueError(f"Column {question_column} not found in the dataframe!")

        if context_column and context_column not in df.columns:
            raise ValueError(f"Column {context_column} not found in the dataframe!")

        if context_column:
            # A row counts as empty only when BOTH context and question are NaN.
            empty_prompt_ids = np.where(
                df[[context_column, question_column]]
                .isna()
                .all(axis=1)
                .values
            )[0].tolist()
            contexts = list(df[context_column].apply(lambda x: str(x)))
            questions_without_context = list(df[question_column].apply(lambda x: str(x)))

            questions = [
                f'Context: {c}\nQuestion: {q}\nAnswer: '
                for c, q in zip(contexts, questions_without_context)
            ]

        else:
            questions = list(df[question_column].apply(lambda x: str(x)))
            empty_prompt_ids = np.where(
                df[[question_column]].isna().all(axis=1).values
            )[0].tolist()

        # .tolist() converts the numpy index array to a plain List[int], matching the annotation.
        return questions, empty_prompt_ids

    def _prepare_data_with_prompt_template(self, df: pd.DataFrame, prompt_template: Text) -> Tuple[List[Text], List[int]]:
        """
        Prepares the input data with a prompt template.

        Args:
            df (pd.DataFrame): The input data to invoke the model with.
            prompt_template (Text): The base prompt template to use.

        Returns:
            Tuple[List[Text], List[int]]: The questions to build the prompts for invoking the Amazon Bedrock API and the empty prompt IDs.
        """
        questions, empty_prompt_ids = get_completed_prompts(prompt_template, df)

        return questions, empty_prompt_ids

    def _predict_for_default_mode(self, model_id: Text, prompts: List[Dict], inference_config: Dict) -> List[Text]:
        """
        Makes predictions for the default mode of the Amazon Bedrock handler using the prepared prompts.

        Args:
            model_id (Text): The ID of the model in Amazon Bedrock.
            prompts (List[Dict]): The prepared prompts for invoking the Amazon Bedrock API.
            inference_config (Dict): The inference configuration supported by the Amazon Bedrock API.

        Returns:
            List[Text]: The predictions made by the Amazon Bedrock API.
        """
        predictions = []
        bedrock_runtime_client = create_amazon_bedrock_client(
            'bedrock-runtime',
            **self.engine_storage.get_connection_args()
        )

        # One Converse API call per prompt.
        for prompt in prompts:
            response = bedrock_runtime_client.converse(
                modelId=model_id,
                messages=[prompt],
                inferenceConfig=inference_config
            )
            predictions.append(
                response["output"]["message"]["content"][0]["text"]
            )

        return predictions

    def _predict_for_conversational_mode(self, model_id: Text, prompt: List[Dict], inference_config: Dict) -> Text:
        """
        Makes a prediction for the conversational mode of the Amazon Bedrock handler using the prepared prompt.

        Args:
            model_id (Text): The ID of the model in Amazon Bedrock.
            prompt (List[Dict]): The prepared prompt for invoking the Amazon Bedrock API.
            inference_config (Dict): The inference configuration supported by the Amazon Bedrock API.

        Returns:
            Text: The prediction made by the Amazon Bedrock API.
        """
        bedrock_runtime_client = create_amazon_bedrock_client(
            'bedrock-runtime',
            **self.engine_storage.get_connection_args()
        )

        # A single Converse API call carrying all questions in one message.
        response = bedrock_runtime_client.converse(
            modelId=model_id,
            messages=prompt,
            inferenceConfig=inference_config
        )

        return response["output"]["message"]["content"][0]["text"]

    def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
        """
        Get the metadata or arguments of a model.

        Args:
            attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'.

        Returns:
            pd.DataFrame: Model metadata or model arguments.
        """
        args = self.model_storage.json_get('args')

        if attribute == 'args':
            # Strip the internal handler configuration before exposing the user args;
            # pop() avoids a KeyError if the key is somehow absent.
            args.pop('handler_model_params', None)
            return pd.DataFrame(args.items(), columns=['key', 'value'])

        elif attribute == 'metadata':
            model_id = args['handler_model_params']['id']
            try:
                bedrock_client = create_amazon_bedrock_client(
                    'bedrock',
                    **self.engine_storage.get_connection_args()
                )
                meta = bedrock_client.get_foundation_model(modelIdentifier=model_id)['modelDetails']
            except Exception as e:
                # Best-effort: surface the failure as metadata rather than crashing describe.
                meta = {'error': str(e)}
            return pd.DataFrame(dict(meta).items(), columns=['key', 'value'])

        else:
            # Default: list the describable attributes.
            tables = ['args', 'metadata']
            return pd.DataFrame(tables, columns=['tables'])