Coverage for mindsdb / integrations / handlers / replicate_handler / replicate_handler.py: 0%
92 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import replicate
2import pandas as pd
3from mindsdb.integrations.libs.base import BaseMLEngine
4from typing import Dict, Optional
5import os
6import types
7from mindsdb.utilities.config import Config
class ReplicateHandler(BaseMLEngine):
    """ML engine that creates and queries models through the `replicate` client.

    Model settings given in the USING clause (`model_name`, `version`, optional
    `api_key` / `model_type`, ...) are persisted in `model_storage` under the
    'args' key by `create` and read back by `predict` and `_get_schema`.
    """

    name = "replicate"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate the USING arguments before a model is created.

        Args:
            target: Name of the target column (not used by the validation).
            args: Model arguments; must contain a 'using' dict holding
                'model_name' and 'version'. The 'api_key' entry of that dict
                is used to query Replicate for the requested version.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            Exception: If the USING clause is missing, if 'model_name' or
                'version' is missing, or if the version lookup fails.
        """
        # `args` defaults to None, so guard before the membership test.
        if not args or 'using' not in args:
            raise Exception("Replicate engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        if 'model_name' not in args or 'version' not in args:
            raise Exception('Both model_name and version must be provided.')

        # Checking if passed model_name and version are correct or not
        try:
            # A missing 'api_key' raises KeyError here and ends up in the
            # generic branch of the handler below.
            replicate.default_client.api_token = args['api_key']
            replicate.models.get(args['model_name']).versions.get(args['version'])
        except Exception as e:
            # Not every exception carries a string as its first argument (or
            # any argument at all); normalise so the handler itself cannot fail.
            message = str(e.args[0]) if e.args else str(e)

            if message.startswith('Not found'):
                raise Exception(f"Could not retrieve version {args['version']} of model {args['model_name']}. Verify values are correct and try again.", e) from e

            if message.startswith('Incorrect authentication token'):
                raise Exception("Provided api_key is Incorrect. Get your api_key here: https://replicate.com/account/api-tokens") from e

            raise Exception("Error occured.", e) from e

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Saves model details in storage to access it later

        Args:
            target: Name of the column that will hold the predictions.
            df: Not used by this engine.
            args: Model arguments; the 'using' dict is what gets stored.
        """
        # Copy so that adding 'target' does not modify the caller's USING dict.
        model_args = dict(args['using'])
        model_args['target'] = target
        self.model_storage.json_set('args', model_args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Using replicate makes the prediction according to your parameters

        Args:
            df (pd.DataFrame): The input DataFrame containing data to predict.
                It is not modified.
            args (Optional[Dict]): Additional arguments; the optional
                'predict_params' dict is merged into every model input.

        Returns:
            pd.DataFrame: One column named after the stored target, with one
            prediction per input row and the same index as `df`.

        Raises:
            Exception: If a column of `df` or a prediction parameter is not an
                input of the model version, or if no API key can be found.
        """
        # Extracting prediction parameters from input arguments; tolerate both
        # a missing `args` and an `args` without 'predict_params'.
        pred_args = (args or {}).get('predict_params') or {}
        args = self.model_storage.json_get('args')
        model_name, version, target_col = args['model_name'], args['version'], args['target']

        # Drop the '__mindsdb_row_id' column if present, without touching the
        # caller's DataFrame.
        if '__mindsdb_row_id' in df.columns:
            df = df.drop(columns=['__mindsdb_row_id'])

        def get_data(conditions):
            """Run the model for one row and reduce its output to a single value."""
            output = replicate.run(
                f"{args['model_name']}:{args['version']}",
                input={**conditions.to_dict(), **pred_args}  # Unpacking parameters inputted
            )
            # Process output based on the model type
            if isinstance(output, types.GeneratorType):
                chunks = list(output)
                if args.get('model_type') == 'LLM':
                    return ''.join(chunks)  # If model_type is LLM, make the stream a string
                # Otherwise treat it like a list output: keep the final item
                # (e.g. the final URL of a generator of frame URLs).
                output = chunks
            if isinstance(output, list) and len(output) > 0:
                return output[-1]  # Returns generated image for controlNet models as it outputs filter and generated image
            return output

        # Check if any wrong parameters are given and raise an exception if necessary
        params_names = set(df.columns) | set(pred_args)
        available_params = self._get_schema(only_keys=True)
        # Sorted (by string form) so the error message is deterministic.
        wrong_params = sorted((p for p in params_names if p not in available_params), key=str)

        if wrong_params:
            raise Exception(
                f"'{wrong_params}' is/are not supported parameter for this model.\n"
                "Use DESCRIBE PREDICTOR mindsdb.<model_name>.features; to know about available parameters. OR\n"
                f"Visit https://replicate.com/{model_name}/versions/{version} to check parameters."
            )

        # Set the Replicate API token for communication with the server
        replicate.default_client.api_token = self._get_replicate_api_key(args)

        # Run prediction row by row. Iterating explicitly yields exactly one
        # result per row, including for an empty frame.
        predictions = [get_data(row) for _, row in df.iterrows()]
        return pd.DataFrame({target_col: predictions}, index=df.index)

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the model.

        Args:
            attribute: 'features' returns the model input schema; any other
                value returns the list of available attribute tables.

        Returns:
            pd.DataFrame: The input schema, or a single-column 'tables' frame
            containing 'features'.
        """
        if attribute == "features":
            return self._get_schema()

        return pd.DataFrame(['features'], columns=['tables'])

    def _get_replicate_api_key(self, args, strict=True):
        """Resolve the Replicate API key.

        API_KEY preference order:
            1. provided at model creation
            2. provided at engine creation
            3. REPLICATE_API_TOKEN env variable
            4. replicate.api_key setting in config.json

        Args:
            args: Stored model arguments (may contain 'api_key').
            strict: When True, raise if no key is found; otherwise return None.

        Returns:
            The API key, or None when nothing is found and `strict` is False.

        Raises:
            Exception: If no key is found and `strict` is True.
        """  # noqa
        # 1
        if 'api_key' in args:
            return args['api_key']
        # 2
        connection_args = self.engine_storage.get_connection_args()
        if 'api_key' in connection_args:
            return connection_args['api_key']
        # 3
        api_key = os.getenv('REPLICATE_API_TOKEN')
        if api_key is not None:
            return api_key
        # 4
        config = Config()
        replicate_cfg = config.get('replicate', {})
        if 'api_key' in replicate_cfg:
            return replicate_cfg['api_key']

        if strict:
            raise Exception(
                'Missing API key "api_key". Either re-create this ML_ENGINE specifying the `api_key` parameter, '
                'or re-create this model and pass the API key with `USING` syntax.'
            )
        return None

    def _get_schema(self, only_keys=False):
        """Return parameters list with its description, default value and type,
        which helps user to customize their prediction

        Args:
            only_keys: When True, return only the parameter names.

        Returns:
            The keys view of the model's input properties when `only_keys` is
            True; otherwise a DataFrame with one row per parameter (column
            'inputs') and its 'default', 'description' and 'type' where
            present, with missing values shown as '-'.
        """
        args = self.model_storage.json_get('args')

        # Resolve the key once and hand it to the client both ways.
        # NOTE(review): exporting the key to the process environment makes it
        # the value found by step 3 of `_get_replicate_api_key` in every later
        # lookup within this process - confirm this is intended.
        api_key = self._get_replicate_api_key(args)
        os.environ['REPLICATE_API_TOKEN'] = api_key
        replicate.default_client.api_token = api_key

        model = replicate.models.get(args['model_name'])
        version = model.versions.get(args['version'])
        schema = version.openapi_schema['components']['schemas']['Input']['properties']

        # returns only list of parameter
        if only_keys:
            return schema.keys()

        # Build a trimmed copy rather than popping keys from the schema object
        # returned by the client.
        kept_fields = ('default', 'description', 'type')
        trimmed = {
            param: {field: value for field, value in spec.items() if field in kept_fields}
            for param, spec in schema.items()
        }

        df = pd.DataFrame(trimmed).T
        df = df.reset_index().rename(columns={'index': 'inputs'})
        return df.fillna('-')