Coverage for mindsdb / integrations / handlers / stabilityai_handler / stabilityai.py: 0%
67 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from stability_sdk import client
2from stability_sdk.client import generation
3from PIL import Image
4import io
5import requests
class StabilityAPIClient:
    """Wrapper around the Stability AI gRPC client that saves generated images locally."""

    # Seconds to wait on plain HTTP requests (engine listing, image downloads).
    # Without a timeout, `requests` can block indefinitely on a stalled server.
    _REQUEST_TIMEOUT = 60

    # Entry returned in place of an image path when the API's safety filter
    # flags an artifact. Built with implicit string concatenation so that no
    # source indentation leaks into the user-facing text.
    _SAFETY_FILTER_MESSAGE = (
        "Your request activated the API's safety filters "
        "and could not be processed. Please modify the prompt and try again."
    )

    def __init__(self, api_key, dir_to_save, engine="stable-diffusion-xl-1024-v1-0", upscale_engine="esrgan-v1-x2plus"):
        """Initialize the stability wrapper api client.

        Args:
            api_key: Stability AI API Key
            dir_to_save: The local directory path to save the response images from the API
            engine (str, optional): Stability AI engine to use. Defaults to "stable-diffusion-xl-1024-v1-0".
            upscale_engine (str, optional): Stability AI upscaling engine to use. Defaults to "esrgan-v1-x2plus".

        Raises:
            ValueError: For unknown engine or upscale engine
        """
        self.api_key = api_key
        self.STABILITY_HOST = 'grpc.stability.ai:443'
        # Normalize to a trailing slash because save_image concatenates file names onto it.
        self.save_dir = dir_to_save if dir_to_save.endswith("/") else dir_to_save + "/"
        # Network call: fetches the engine catalogue used to validate both engine ids.
        self.available_engines = self.get_existing_engines()
        self._check_engine(engine, "engine")
        self._check_engine(upscale_engine, "upscale engine")

        self.stability_api = client.StabilityInference(host=self.STABILITY_HOST,
                                                       key=self.api_key,
                                                       engine=engine,
                                                       upscale_engine=upscale_engine,
                                                       verbose=True)

    def _check_engine(self, engine_id, label):
        """Raise ValueError if `engine_id` is not among the available engines.

        Args:
            engine_id: The engine id to validate.
            label: Human-readable name of the setting, used in the error message.

        Raises:
            ValueError: If the engine id is unknown.
        """
        if not self._is_valid_engine(engine_id):
            # available_engines is a dict keyed by engine id. Join the ids
            # rather than concatenating the dict, which would raise TypeError.
            raise ValueError(
                f"Unknown {label}. The available engines are - " + ", ".join(self.available_engines)
            )

    def save_image(self, artifact):
        """Save the binary image in the artifact to the local directory.

        The file is named after the artifact's seed. An existing file with the
        same seed in `save_dir` is overwritten.

        Args:
            artifact: Artifact returned by the API

        Returns:
            The local image path.
        """
        img = Image.open(io.BytesIO(artifact.binary))
        path = self.save_dir + str(artifact.seed) + ".png"
        img.save(path)
        return path

    def _process_artifacts(self, artifacts):
        """Process the artifacts returned by the API.

        Args:
            artifacts: Iterable of API responses, each exposing an `artifacts` sequence.

        Returns:
            A list of saved image paths, interleaved with a warning message for
            every artifact the safety filter flagged.
        """
        saved_image_paths = []
        for resp in artifacts:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    saved_image_paths.append(self._SAFETY_FILTER_MESSAGE)
                # NOTE(review): this check is independent of the filter check
                # above. A flagged artifact that is still of image type is
                # saved as well — confirm this is intended.
                if artifact.type == generation.ARTIFACT_IMAGE:
                    saved_image_paths.append(self.save_image(artifact))
        return saved_image_paths

    def get_existing_engines(self):
        """Get the existing engines.

        Returns:
            Dict mapping engine id to the engine description returned by the API.

        Raises:
            Exception: If the API key is missing or the API returns a non-200 status.
        """
        url = "https://api.stability.ai/v1/engines/list"
        if self.api_key is None:
            raise Exception("Missing Stability API key.")

        response = requests.get(
            url,
            headers={"Authorization": f"Bearer {self.api_key}"},
            timeout=self._REQUEST_TIMEOUT,
        )

        if response.status_code != 200:
            raise Exception("Non-200 response: " + str(response.text))

        payload = response.json()

        return {engine["id"]: engine for engine in payload}

    def _is_valid_engine(self, engine_id):
        """Validates the given engine id against the supported engines by Stability.

        Args:
            engine_id: The engine id to check

        Returns:
            True if valid engine else False
        """
        return engine_id in self.available_engines

    def _read_image_url(self, image_url):
        """Downloads the given image url.

        Args:
            image_url: The image url to download

        Returns:
            Downloaded image

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.get(image_url, timeout=self._REQUEST_TIMEOUT)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # `response.content` is already decoded for Content-Encoding (gzip/deflate).
        # `response.raw` is not, which can make PIL fail on compressed responses.
        return Image.open(io.BytesIO(response.content))

    def text_to_image(self, prompt, height=1024, width=1024):
        """Converts the given text to image using stability API.

        Args:
            prompt: The given text
            height (int, optional): Height of the image. Defaults to 1024.
            width (int, optional): Width of the image. Defaults to 1024.

        Returns:
            The local saved paths of the generated images
        """
        answers = self.stability_api.generate(
            prompt=prompt,
            height=height,
            width=width
        )

        return self._process_artifacts(answers)

    def image_to_image(self, image_url, prompt=None,
                       height=1024, width=1024,
                       mask_image_url=None):
        """Image to Image inpainting + masking.

        Args:
            image_url: The image url
            prompt: The given text. None is sent as an empty prompt.
            height (int, optional): Height of the image. Defaults to 1024.
            width (int, optional): Width of the image. Defaults to 1024.
            mask_image_url (string, optional): The mask image url for masking. Defaults to None.

        Returns:
            The local saved paths of the generated images
        """
        img = self._read_image_url(image_url)
        mask_img = None if mask_image_url is None else self._read_image_url(mask_image_url)
        if prompt is None:
            prompt = ""
        answers = self.stability_api.generate(prompt=prompt, init_image=img, mask_image=mask_img, width=width, height=height)

        return self._process_artifacts(answers)

    def image_upscaling(self, image_url, height=None, width=None, prompt=None):
        """Image upscaling.

        At most one of `height` or `width` may be given. If neither is given,
        the upscaler's default output size is used.

        Args:
            image_url: The image url
            height (int, optional): Height of the image. Defaults to None.
            width (int, optional): Width of the image. Defaults to None.
            prompt: The given text

        Returns:
            The local saved paths of the generated images

        Raises:
            ValueError: If both height and width are given.
        """
        # Validate the arguments before downloading anything.
        if height is not None and width is not None:
            raise ValueError("Either height or width can be given. Refer - https://platform.stability.ai/docs/features/image-upscaling#initial-generation-parameters")

        img = self._read_image_url(image_url)
        if height is not None:
            answers = self.stability_api.upscale(init_image=img, height=height, prompt=prompt)
        elif width is not None:
            answers = self.stability_api.upscale(init_image=img, width=width, prompt=prompt)
        else:
            answers = self.stability_api.upscale(init_image=img, prompt=prompt)

        return self._process_artifacts(answers)