Coverage for mindsdb / integrations / handlers / stabilityai_handler / stabilityai.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from stability_sdk import client 

2from stability_sdk.client import generation 

3from PIL import Image 

4import io 

5import requests 

6 

7 

class StabilityAPIClient:
    """Wrapper around the Stability AI SDK client used by the handler.

    Generated images are written as PNG files into the directory given at
    construction time. The public generation methods return a list containing
    the local paths of the saved images, plus a notice string for every
    artifact that the API's safety filter flagged.
    """

    # Timeout (seconds) for the plain-HTTP calls made with ``requests``
    # (engine listing and image download). Without a timeout ``requests``
    # can block forever on an unresponsive server.
    _HTTP_TIMEOUT = 60

    # Notice appended to the results for artifacts flagged by the safety filter.
    # Implicit concatenation keeps the message free of source indentation.
    _FILTER_MESSAGE = (
        "Your request activated the API's safety filters "
        "and could not be processed. Please modify the prompt and try again."
    )

    def __init__(self, api_key, dir_to_save, engine="stable-diffusion-xl-1024-v1-0",
                 upscale_engine="esrgan-v1-x2plus"):
        """Initialize the stability wrapper api client.

        Args:
            api_key: Stability AI API Key
            dir_to_save: The local directory path to save the response images from the API
            engine (str, optional): Stability AI engine to use. Defaults to "stable-diffusion-xl-1024-v1-0".
            upscale_engine (str, optional): Stability AI upscaling engine to use. Defaults to "esrgan-v1-x2plus".

        Raises:
            ValueError: If ``engine`` or ``upscale_engine`` is not among the
                engine ids returned by :meth:`get_existing_engines`.
            Exception: Propagated from :meth:`get_existing_engines` when the
                API key is missing or the engine listing request fails.
        """
        self.api_key = api_key
        self.STABILITY_HOST = 'grpc.stability.ai:443'
        # Normalise to a trailing "/" so file names can be appended directly.
        self.save_dir = dir_to_save if dir_to_save.endswith("/") else dir_to_save + "/"
        self.available_engines = self.get_existing_engines()

        # available_engines is a dict keyed by engine id; join the ids so the
        # message can be built (str + dict would raise TypeError instead).
        known_engines = ", ".join(self.available_engines)
        if not self._is_valid_engine(engine):
            raise ValueError("Unknown engine. The available engines are - " + known_engines)
        if not self._is_valid_engine(upscale_engine):
            raise ValueError("Unknown upscale engine. The available engines are - " + known_engines)

        self.stability_api = client.StabilityInference(host=self.STABILITY_HOST,
                                                       key=self.api_key,
                                                       engine=engine,
                                                       upscale_engine=upscale_engine,
                                                       verbose=True)

    def save_image(self, artifact):
        """Save the binary image in the artifact to the local directory.

        The file is named after the artifact's seed, so artifacts sharing a
        seed overwrite each other.

        Args:
            artifact: Artifact returned by the API (must expose ``binary`` and ``seed``).

        Returns:
            The local image path.
        """
        path = self.save_dir + str(artifact.seed) + ".png"
        with Image.open(io.BytesIO(artifact.binary)) as img:
            img.save(path)
        return path

    def _process_artifacts(self, artifacts):
        """Process the artifacts returned by the API.

        Args:
            artifacts: Iterable of API responses; each response exposes an
                ``artifacts`` sequence that is inspected here.

        Returns:
            A list with the saved image paths, interleaved with the
            safety-filter notice for every filtered artifact.
        """
        saved_image_paths = []
        for resp in artifacts:
            for artifact in resp.artifacts:
                # The two checks are independent: a filtered artifact whose
                # type is ARTIFACT_IMAGE yields the notice AND is saved.
                if artifact.finish_reason == generation.FILTER:
                    saved_image_paths.append(self._FILTER_MESSAGE)
                if artifact.type == generation.ARTIFACT_IMAGE:
                    saved_image_paths.append(self.save_image(artifact))
        return saved_image_paths

    def get_existing_engines(self):
        """Fetch the engines available to this API key.

        Returns:
            A dict mapping each engine's ``id`` to the engine record returned
            by ``/v1/engines/list``.

        Raises:
            Exception: If the API key is missing or the response status is not 200.
        """
        url = "https://api.stability.ai/v1/engines/list"
        if self.api_key is None:
            raise Exception("Missing Stability API key.")

        response = requests.get(url, headers={
            "Authorization": f"Bearer {self.api_key}"
        }, timeout=self._HTTP_TIMEOUT)

        if response.status_code != 200:
            raise Exception("Non-200 response: " + str(response.text))

        payload = response.json()

        return {engine["id"]: engine for engine in payload}

    def _is_valid_engine(self, engine_id):
        """Validate the given engine id against the engines supported by Stability.

        Args:
            engine_id: The engine id to check

        Returns:
            True if valid engine else False
        """
        return engine_id in self.available_engines

    def _read_image_url(self, image_url):
        """Download the image at the given url.

        Args:
            image_url: The image url to download

        Returns:
            The downloaded image as a PIL image.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.get(image_url, timeout=self._HTTP_TIMEOUT)
        # Fail with a clear HTTP error instead of letting PIL try to decode
        # an error page.
        response.raise_for_status()
        # ``content`` (unlike the raw stream) has any Content-Encoding
        # (e.g. gzip) already decoded.
        return Image.open(io.BytesIO(response.content))

    def text_to_image(self, prompt, height=1024, width=1024):
        """Convert the given text to images using the Stability API.

        Args:
            prompt: The given text
            height (int, optional): Height of the image. Defaults to 1024.
            width (int, optional): Width of the image. Defaults to 1024.

        Returns:
            The local saved paths of the generated images
        """
        answers = self.stability_api.generate(
            prompt=prompt,
            height=height,
            width=width
        )
        return self._process_artifacts(answers)

    def image_to_image(self, image_url, prompt=None,
                       height=1024, width=1024,
                       mask_image_url=None):
        """Image to Image inpainting + masking.

        Args:
            image_url: The image url
            prompt: The given text; ``None`` is sent as an empty prompt.
            height (int, optional): Height of the image. Defaults to 1024.
            width (int, optional): Width of the image. Defaults to 1024.
            mask_image_url (string, optional): The mask image url for masking. Defaults to None.

        Returns:
            The local saved paths of the generated images
        """
        img = self._read_image_url(image_url)
        mask_img = None if mask_image_url is None else self._read_image_url(mask_image_url)
        if prompt is None:
            prompt = ""
        answers = self.stability_api.generate(prompt=prompt, init_image=img, mask_image=mask_img,
                                              width=width, height=height)
        return self._process_artifacts(answers)

    def image_upscaling(self, image_url, height=None, width=None, prompt=None):
        """Image upscaling.

        At most one of ``height`` and ``width`` may be supplied.

        Args:
            image_url: The image url
            height (int, optional): Height of the image. Defaults to None.
            width (int, optional): Width of the image. Defaults to None.
            prompt: The given text

        Returns:
            The local saved paths of the generated images

        Raises:
            ValueError: If both ``height`` and ``width`` are given.
        """
        # Validate before downloading so bad arguments fail fast and offline.
        if height is not None and width is not None:
            raise ValueError("Either height or width can be given. Refer - https://platform.stability.ai/docs/features/image-upscaling#initial-generation-parameters")

        img = self._read_image_url(image_url)

        # Forward only the dimension that was supplied (if any).
        size_kwargs = {}
        if height is not None:
            size_kwargs["height"] = height
        elif width is not None:
            size_kwargs["width"] = width

        answers = self.stability_api.upscale(init_image=img, prompt=prompt, **size_kwargs)
        return self._process_artifacts(answers)