Coverage for mindsdb / integrations / handlers / stabilityai_handler / stabilityai_handler.py: 0%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional, Dict 

2import pandas as pd 

3 

4from mindsdb.integrations.handlers.stabilityai_handler.stabilityai import StabilityAPIClient 

5 

6from mindsdb.integrations.libs.base import BaseMLEngine 

7 

8from mindsdb.utilities import log 

9 

10from mindsdb.integrations.utilities.handler_utils import get_api_key 

11 

12 

# Module-level logger named after this module (via mindsdb.utilities.log).
logger = log.getLogger(__name__)

14 

15 

class StabilityAIHandler(BaseMLEngine):
    """ML engine that generates and edits images with the Stability AI API.

    The model's USING clause must name a ``task`` (one of ``AVAILABLE_TASKS``) and a
    ``local_directory_path``, which is handed to ``StabilityAPIClient`` as the
    directory to save results in. ``engine_id`` and ``upscale_engine_id`` are
    optional and default to ``DEFAULT_ENGINE_ID`` / ``DEFAULT_UPSCALE_ENGINE_ID``.
    """

    name = "stabilityai"

    # Values accepted for the `task` parameter of the USING clause.
    AVAILABLE_TASKS = ("text-to-image", "image-to-image", "image-upscaling", "image-masking")
    # Engines used when the USING clause does not name one explicitly.
    DEFAULT_ENGINE_ID = "stable-diffusion-xl-1024-v1-0"
    DEFAULT_UPSCALE_ENGINE_ID = "esrgan-v1-x2plus"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate the USING clause of CREATE MODEL and fill in engine defaults.

        Args:
            target: Name of the prediction column (not used by the validation).
            args: Model creation arguments; options are read from ``args['using']``.
                Missing ``engine_id`` / ``upscale_engine_id`` entries are written
                back into that dict with their default values.
            **kwargs: Extra caller context. ``handler_storage``, if present, is passed
                to ``get_api_key`` so a key configured on the ML engine is accepted.

        Raises:
            Exception: if the USING clause, ``task`` or ``local_directory_path`` is
                missing, if ``task`` is unknown, or if a given engine id is not one
                the client reports as valid.
        """
        if args is None or 'using' not in args:
            raise Exception("Stability AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        available_tasks = StabilityAIHandler.AVAILABLE_TASKS
        # The task list must be joined into a string: `str + list` raises TypeError.
        tasks_text = ', '.join(available_tasks)

        if 'task' not in args:
            raise Exception('task has to be specified. Available tasks are - ' + tasks_text)

        if args['task'] not in available_tasks:
            raise Exception('Unknown task specified. Available tasks are - ' + tasks_text)

        if 'local_directory_path' not in args:
            raise Exception('local_directory_path has to be specified')

        # Resolve the key the same way prediction does, instead of requiring it in USING.
        api_key = get_api_key('stabilityai', args, kwargs.get('handler_storage'))
        client = StabilityAPIClient(api_key, args["local_directory_path"])

        engine_defaults = (
            ("engine_id", StabilityAIHandler.DEFAULT_ENGINE_ID),
            ("upscale_engine_id", StabilityAIHandler.DEFAULT_UPSCALE_ENGINE_ID),
        )
        for key, default in engine_defaults:
            if key not in args:
                args[key] = default
            elif not client._is_valid_engine(args[key]):
                engines_text = ', '.join(str(engine) for engine in client.available_engines.keys())
                raise Exception("Unknown engine. The available engines are - " + engines_text)

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Persist the USING clause (plus the target column name) in model storage.

        Raises:
            Exception: if ``args`` is missing or has no ``using`` entry.
        """
        if args is None or 'using' not in args:
            raise Exception("Stability AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        self.generative = True

        args = args['using']
        args['target'] = target
        self.model_storage.json_set('args', args)

    def _get_stability_client(self, args):
        """Build a ``StabilityAPIClient`` from the stored model arguments.

        Args:
            args: The dict saved by ``create``, i.e. the contents of the USING clause
                itself. It has no nested ``'using'`` key.
        """
        api_key = get_api_key('stabilityai', args, self.engine_storage, strict=False)

        return StabilityAPIClient(
            api_key=api_key,
            dir_to_save=args["local_directory_path"],
            engine=args.get('engine_id', self.DEFAULT_ENGINE_ID),
        )

    def _run_generation(self, df, args, generate, supported_params, required_params, task_label):
        """Validate ``df``'s columns, then call ``generate`` once per row.

        Args:
            df: Query input; each row holds the parameters for one request.
            args: Stored model arguments, used to build the API client.
            generate: ``generate(row_series, client)`` returning one prediction.
            supported_params: Set of column names accepted for this task.
            required_params: Column names that must be present, checked in order.
            task_label: Human-readable task name used in error messages.

        Returns:
            A Series of predictions aligned with ``df.index``.
        """
        for param in required_params:
            if param not in df.columns:
                raise Exception(f"`{param}` column has to be given in the query.")

        for col in df.columns:
            if col not in supported_params:
                raise Exception(f"Unknown column {col}. Currently supported parameters for {task_label} - {supported_params}")

        if df.empty:
            # DataFrame.apply probes `generate` on an empty frame, which would send
            # an API request with no inputs. Return an empty result instead.
            return pd.Series([], index=df.index, dtype=object)

        client = self._get_stability_client(args)

        return df[df.columns.intersection(supported_params)].apply(generate, client=client, axis=1)

    def _process_text_image(self, df, args):
        """Generate one image per row from its ``text`` prompt (``height``/``width`` default to 1024)."""

        def generate_text_image(conds, client):
            conds = conds.to_dict()
            return client.text_to_image(prompt=conds.get("text"), height=conds.get("height", 1024), width=conds.get("width", 1024))

        return self._run_generation(
            df, args, generate_text_image,
            supported_params={"text", "height", "width"},
            required_params=("text",),
            task_label="text to image",
        )

    def _process_image_image(self, df, args):
        """Transform the image at ``image_url`` using an optional ``text`` prompt (sizes default to 1024)."""

        def generate_image_image(conds, client):
            conds = conds.to_dict()
            return client.image_to_image(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height", 1024), width=conds.get("width", 1024))

        return self._run_generation(
            df, args, generate_image_image,
            supported_params={"image_url", "text", "height", "width"},
            required_params=("image_url",),
            task_label="image to image",
        )

    def _process_image_upscaling(self, df, args):
        """Upscale the image at ``image_url``; unset ``text``/``height``/``width`` are passed as None."""

        def generate_image_upscaling(conds, client):
            conds = conds.to_dict()
            return client.image_upscaling(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height"), width=conds.get("width"))

        return self._run_generation(
            df, args, generate_image_upscaling,
            supported_params={"image_url", "text", "height", "width"},
            required_params=("image_url",),
            task_label="image scaling",
        )

    def _process_image_masking(self, df, args):
        """Edit the image at ``image_url`` within the mask at ``mask_image_url`` via ``client.image_to_image``."""

        def generate_image_mask(conds, client):
            conds = conds.to_dict()
            return client.image_to_image(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height"), width=conds.get("width"), mask_image_url=conds.get("mask_image_url"))

        return self._run_generation(
            df, args, generate_image_mask,
            supported_params={"image_url", "text", "height", "width", "mask_image_url"},
            required_params=("image_url", "mask_image_url"),
            task_label="image masking",
        )

    def predict(self, df, args=None):
        """Run the stored task over ``df`` and return predictions in the target column.

        Args:
            df: Query input rows.
            args: Ignored. The arguments saved by ``create`` are used instead.

        Raises:
            Exception: if the stored task is not one this handler supports.
        """
        args = self.model_storage.json_get('args')

        processors = {
            "text-to-image": self._process_text_image,
            "image-to-image": self._process_image_image,
            "image-upscaling": self._process_image_upscaling,
            "image-masking": self._process_image_masking,
        }
        task = args.get("task")
        if task not in processors:
            # Previously this fell through and failed on an unbound `preds`.
            raise Exception(f"Unknown task {task}. Available tasks are - {', '.join(processors)}")

        preds = processors[task](df, args)

        return pd.DataFrame({args['target']: preds})

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Return the client's metadata for the configured generation and upscale engines."""
        args = self.model_storage.json_get('args')
        # Same key resolution as prediction, so engine-level keys work here too.
        api_key = get_api_key('stabilityai', args, self.engine_storage, strict=False)
        client = StabilityAPIClient(api_key, "")

        engine_id = args.get("engine_id", self.DEFAULT_ENGINE_ID)
        upscale_engine_id = args.get("upscale_engine_id", self.DEFAULT_UPSCALE_ENGINE_ID)

        engine_id_res = client.available_engines.get(engine_id)
        upscale_engine_id_res = client.available_engines.get(upscale_engine_id)
        return pd.json_normalize([engine_id_res, upscale_engine_id_res])