Coverage for mindsdb / integrations / handlers / clipdrop_handler / clipdrop_handler.py: 0%

124 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional, Dict 

2import pandas as pd 

3 

4from mindsdb.integrations.handlers.clipdrop_handler.clipdrop import ClipdropClient 

5 

6from mindsdb.integrations.libs.base import BaseMLEngine 

7 

8from mindsdb.utilities import log 

9 

10from mindsdb.integrations.utilities.handler_utils import get_api_key 

11 

12 

# Module-level logger for this handler, named after the module via __name__.
13logger = log.getLogger(__name__) 

14 

15 

class ClipdropHandler(BaseMLEngine):
    """ML engine that routes prediction rows to a ``ClipdropClient``.

    The task to run and the local directory handed to the client are taken
    from the ``USING`` arguments supplied at model creation time and are
    persisted in ``model_storage`` under the ``'args'`` key.
    """

    name = "clipdrop"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate the ``USING`` arguments before a model is created.

        Args:
            target: Name of the target column (not inspected here).
            args: Creation arguments; must contain a ``'using'`` mapping with
                a supported ``'task'`` and a ``'local_directory_path'``.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            Exception: If the ``USING`` clause is missing, the task is
                missing or unknown, or ``local_directory_path`` is missing.
        """
        # Guard against args=None / no USING clause so the caller gets a
        # readable message instead of a TypeError or KeyError.
        if args is None or 'using' not in args:
            raise Exception("Clipdrop AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        available_tasks = ["remove_text", "remove_background", "sketch_to_image", "text_to_image", "replace_background", "reimagine"]

        if 'task' not in args:
            raise Exception(f'task has to be specified. Available tasks are - {available_tasks}')

        if args['task'] not in available_tasks:
            raise Exception(f'Unknown task specified. Available tasks are - {available_tasks}')

        if 'local_directory_path' not in args:
            raise Exception('local_directory_path has to be specified')

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Persist the ``USING`` arguments (plus the target name) for later predictions.

        Args:
            target: Name of the output column; stored under ``'target'``.
            df: Unused; accepted for interface compatibility.
            args: Creation arguments; must contain a ``'using'`` mapping.

        Raises:
            Exception: If ``args`` is ``None`` or has no ``'using'`` entry.
        """
        # ``args`` defaults to None, so test for that before the membership check.
        if args is None or 'using' not in args:
            raise Exception("Clipdrop AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        self.generative = True

        args = args['using']
        args['target'] = target
        self.model_storage.json_set('args', args)

    def _get_clipdrop_client(self, args):
        """Build a ``ClipdropClient`` from the stored model arguments.

        The API key is obtained through ``get_api_key`` (called with
        ``strict=False``) and ``args['local_directory_path']`` is passed to
        the client as ``local_dir``.
        """
        api_key = get_api_key('clipdrop', args, self.engine_storage, strict=False)

        local_directory_path = args["local_directory_path"]

        return ClipdropClient(api_key=api_key, local_dir=local_directory_path)

    def _run_task(self, df, args, task_label, required_columns, generate):
        """Validate the query columns for a task and apply it to every row.

        Args:
            df: Input rows of the prediction query.
            args: Stored model arguments, used to build the client.
            task_label: Human-readable task name used in error messages.
            required_columns: Ordered column names the task needs; these are
                also the only columns the task accepts. The order decides
                which "missing column" error is reported first.
            generate: Callable ``(row, client)`` invoked once per row.

        Returns:
            The result of ``DataFrame.apply(generate, axis=1)`` over the
            supported columns of ``df``.

        Raises:
            Exception: If a required column is absent or an unsupported
                column is present.
        """
        supported_params = set(required_columns)

        for column in required_columns:
            if column not in df.columns:
                raise Exception(f"`{column}` column has to be given in the query.")

        for col in df.columns:
            if col not in supported_params:
                raise Exception(f"Unknown column {col}. Currently supported parameters for {task_label} - {supported_params}")

        # Build the client only after validation so invalid queries fail early.
        client = self._get_clipdrop_client(args)

        return df[df.columns.intersection(supported_params)].apply(generate, client=client, axis=1)

    def _process_remove_text(self, df, args):
        """Run the ``remove_text`` task; requires an ``image_url`` column."""

        def generate_remove_text(conds, client):
            conds = conds.to_dict()
            return client.remove_text(conds.get("image_url"))

        return self._run_task(df, args, "remove text", ("image_url",), generate_remove_text)

    def _process_remove_background(self, df, args):
        """Run the ``remove_background`` task; requires an ``image_url`` column."""

        def generate_remove_background(conds, client):
            conds = conds.to_dict()
            return client.remove_background(conds.get("image_url"))

        return self._run_task(df, args, "remove background", ("image_url",), generate_remove_background)

    def _process_sketch_to_image(self, df, args):
        """Run the ``sketch_to_image`` task; requires ``image_url`` and ``text`` columns."""

        def generate_sketch_to_image(conds, client):
            conds = conds.to_dict()
            return client.sketch_to_image(conds.get("image_url"), conds.get("text"))

        # The label previously read "remove background" (copy-paste slip).
        return self._run_task(df, args, "sketch to image", ("image_url", "text"), generate_sketch_to_image)

    def _process_text_to_image(self, df, args):
        """Run the ``text_to_image`` task; requires a ``text`` column."""

        def generate_text_to_image(conds, client):
            conds = conds.to_dict()
            return client.text_to_image(conds.get("text"))

        # The label previously read "remove background" (copy-paste slip).
        return self._run_task(df, args, "text to image", ("text",), generate_text_to_image)

    def _process_replace_background(self, df, args):
        """Run the ``replace_background`` task; requires ``image_url`` and ``text`` columns."""

        def generate_replace_background(conds, client):
            conds = conds.to_dict()
            return client.replace_background(conds.get("image_url"), conds.get("text"))

        return self._run_task(df, args, "replace background", ("image_url", "text"), generate_replace_background)

    def _process_reimagine(self, df, args):
        """Run the ``reimagine`` task; requires an ``image_url`` column."""

        def generate_reimagine(conds, client):
            conds = conds.to_dict()
            return client.reimagine(conds.get("image_url"))

        return self._run_task(df, args, "reimagine", ("image_url",), generate_reimagine)

    def predict(self, df, args=None):
        """Run the stored task on ``df`` and return the results in the target column.

        Args:
            df: Input rows; the accepted columns depend on the stored task.
            args: Ignored; the arguments saved by ``create`` are loaded from
                ``model_storage`` instead.

        Returns:
            A DataFrame with a single column named after the stored target.

        Raises:
            Exception: If the stored task is not one of the supported tasks,
                or if the query columns do not match the task.
        """
        args = self.model_storage.json_get('args')

        # Dispatch table in place of an if/elif chain; an unknown task now
        # raises a clear error instead of leaving the result unbound.
        processors = {
            "remove_text": self._process_remove_text,
            "remove_background": self._process_remove_background,
            "sketch_to_image": self._process_sketch_to_image,
            "text_to_image": self._process_text_to_image,
            "replace_background": self._process_replace_background,
            "reimagine": self._process_reimagine,
        }

        process = processors.get(args["task"])
        if process is None:
            raise Exception(f'Unknown task specified. Available tasks are - {list(processors)}')

        preds = process(df, args)

        result_df = pd.DataFrame()
        result_df[args['target']] = preds

        return result_df

192 

193 return result_df