Coverage for mindsdb / integrations / handlers / replicate_handler / replicate_handler.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import replicate 

2import pandas as pd 

3from mindsdb.integrations.libs.base import BaseMLEngine 

4from typing import Dict, Optional 

5import os 

6import types 

7from mindsdb.utilities.config import Config 

8 

9 

class ReplicateHandler(BaseMLEngine):
    """ML engine that runs models hosted on Replicate.

    A model is identified by the ``model_name`` and ``version`` given in the
    ``USING`` clause. Predictions are produced by calling ``replicate.run``
    once per input row.
    """

    name = "replicate"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate the creation arguments before a model is created.

        Args:
            target: Name of the target column (not used by this check).
            args: Creation arguments; must contain a ``'using'`` dict with
                ``model_name`` and ``version`` (and optionally ``api_key``).
            **kwargs: Extra arguments from the caller (not used here).

        Raises:
            Exception: if ``USING`` is missing, if ``model_name``/``version``
                are missing, if the version cannot be found, if the API key
                is rejected, or if the lookup fails for any other reason.
        """
        if not args or 'using' not in args:
            raise Exception("Replicate engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        if 'model_name' not in args or 'version' not in args:
            raise Exception('Both model_name and version must be provided.')

        # Only override the client token when a key was passed at model level.
        # The key may legitimately come from the engine, the environment or
        # config instead (see _get_replicate_api_key), so a missing model-level
        # key must not fail validation by itself.
        if 'api_key' in args:
            replicate.default_client.api_token = args['api_key']

        # Checking if passed model_name and version are correct or not
        try:
            replicate.models.get(args['model_name']).versions.get(args['version'])
        except Exception as e:
            # Read the first exception argument defensively (fall back to str(e))
            # so an exception with no args or a non-string arg cannot raise
            # IndexError/AttributeError here and mask the real failure.
            message = str(e.args[0]) if e.args else str(e)

            if message.startswith('Not found'):
                raise Exception(f"Could not retrieve version {args['version']} of model {args['model_name']}. Verify values are correct and try again.", e) from e

            elif message.startswith('Incorrect authentication token'):
                raise Exception("Provided api_key is Incorrect. Get your api_key here: https://replicate.com/account/api-tokens") from e

            else:
                raise Exception("Error occured.", e) from e

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Save the model's ``USING`` arguments, plus the target column, to model storage.

        Args:
            target: Name of the column that predictions will be written to.
            df: Training data (unused; Replicate models are pre-trained).
            args: Creation arguments; ``args['using']`` is what gets stored.
        """
        # Copy so the caller's ``using`` dict is not mutated by adding 'target'.
        model_args = dict(args['using'])
        model_args['target'] = target
        self.model_storage.json_set('args', model_args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Run the stored Replicate model once per row of ``df``.

        Each row's columns, merged with ``args['predict_params']`` (which
        take precedence on key clashes), form the model input.

        Args:
            df (pd.DataFrame): The input DataFrame containing data to predict.
                It is not modified.
            args (Optional[Dict]): Additional arguments; the optional
                ``'predict_params'`` dict is sent with every row.

        Returns:
            pd.DataFrame: One column named after the model's target, aligned
            to ``df``'s index. It is empty when ``df`` has no rows.

        Raises:
            Exception: if any column or prediction parameter is not an input
                of the model version's schema, or if no API key can be found.
        """
        # Extracting prediction parameters from input arguments
        pred_args = (args or {}).get('predict_params') or {}
        args = self.model_storage.json_get('args')
        model_name, version, target_col = args['model_name'], args['version'], args['target']

        # Drop MindsDB's internal row id without mutating the caller's DataFrame.
        if '__mindsdb_row_id' in df.columns:
            df = df.drop(columns=['__mindsdb_row_id'])

        # Check if any wrong parameters are given and raise an exception if necessary
        params_names = set(df.columns) | set(pred_args)
        available_params = self._get_schema(only_keys=True)
        wrong_params = [param for param in params_names if param not in available_params]

        if wrong_params:
            raise Exception(f"""'{wrong_params}' is/are not supported parameter for this model.
Use DESCRIBE PREDICTOR mindsdb.<model_name>.features; to know about available parameters. OR
Visit https://replicate.com/{model_name}/versions/{version} to check parameters.""")

        # Set the Replicate API token for communication with the server
        replicate.default_client.api_token = self._get_replicate_api_key(args)

        # Build the result row by row. Unlike DataFrame.apply(axis=1), this
        # makes no probing call (and hence no network request) when df is empty.
        model_ref = f"{model_name}:{version}"
        model_type = args.get('model_type')
        predictions = [
            self._normalize_output(
                replicate.run(model_ref, input={**row.to_dict(), **pred_args}),
                model_type,
            )
            for _, row in df.iterrows()
        ]
        return pd.DataFrame({target_col: predictions}, index=df.index)

    @staticmethod
    def _normalize_output(output, model_type=None):
        """Reduce a raw ``replicate.run`` result to a single cell value.

        - A generator with ``model_type == 'LLM'`` has its streamed chunks
          joined into one string.
        - Any other generator yields its last item (e.g. the final frame
          URL), or None if it yielded nothing.
        - A non-empty list yields its last element. ControlNet models output
          the filter image(s) followed by the generated image.
        - Anything else is returned unchanged.
        """
        if isinstance(output, types.GeneratorType):
            if model_type == 'LLM':
                return ''.join(output)
            items = list(output)
            return items[-1] if items else None
        if isinstance(output, list) and output:
            return output[-1]
        return output

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the model.

        Returns the input-parameter schema when ``attribute == "features"``.
        Otherwise it returns the list of describable tables, which is just
        ``features``.
        """
        if attribute == "features":
            return self._get_schema()

        else:
            return pd.DataFrame(['features'], columns=['tables'])

    def _get_replicate_api_key(self, args, strict=True):
        """Resolve the Replicate API key.

        Preference order:
            1. ``api_key`` provided at model creation (in ``args``)
            2. ``api_key`` provided at engine creation
            3. ``REPLICATE_API_TOKEN`` environment variable
            4. ``replicate.api_key`` setting in config.json

        Returns:
            The key, or None when no key is found and ``strict`` is False.

        Raises:
            Exception: when no key is found and ``strict`` is True.
        """
        # 1
        if 'api_key' in args:
            return args['api_key']
        # 2
        connection_args = self.engine_storage.get_connection_args() or {}
        if 'api_key' in connection_args:
            return connection_args['api_key']
        # 3
        api_key = os.getenv('REPLICATE_API_TOKEN')
        if api_key is not None:
            return api_key
        # 4
        config = Config()
        replicate_cfg = config.get('replicate', {}) or {}
        if 'api_key' in replicate_cfg:
            return replicate_cfg['api_key']

        if strict:
            raise Exception(
                'Missing API key "api_key". Either re-create this ML_ENGINE specifying the `api_key` parameter, '
                'or re-create this model and pass the API key with `USING` syntax.'
            )
        return None

    def _get_schema(self, only_keys=False):
        """Return the model version's input parameters.

        Args:
            only_keys: If True, return just the parameter names (a dict keys
                view from the version's OpenAPI ``Input`` schema).

        Returns:
            The parameter names when ``only_keys`` is True. Otherwise a
            DataFrame with one row per parameter: an ``inputs`` column plus
            whichever of ``default``/``description``/``type`` the schema
            provides, with missing values shown as ``'-'``.
        """
        args = self.model_storage.json_get('args')
        api_key = self._get_replicate_api_key(args)
        # NOTE(review): exporting the key to the process environment persists
        # after this call, so _get_replicate_api_key's env-var fallback (step 3)
        # can later return this model's key for another model. Kept because the
        # Replicate client may read the token from the environment; confirm
        # against the installed client version before removing.
        os.environ['REPLICATE_API_TOKEN'] = api_key
        replicate.default_client.api_token = api_key
        model = replicate.models.get(args['model_name'])
        version = model.versions.get(args['version'])
        schema = version.openapi_schema['components']['schemas']['Input']['properties']

        # returns only list of parameter
        if only_keys:
            return schema.keys()

        # Keep only the user-relevant fields of each parameter spec (order preserved).
        kept_fields = ('default', 'description', 'type')
        filtered = {
            param: {field: value for field, value in spec.items() if field in kept_fields}
            for param, spec in schema.items()
        }

        df = pd.DataFrame(filtered).T
        df = df.reset_index().rename(columns={'index': 'inputs'})
        return df.fillna('-')