Coverage for mindsdb / integrations / handlers / stabilityai_handler / stabilityai_handler.py: 0%
112 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional, Dict
2import pandas as pd
4from mindsdb.integrations.handlers.stabilityai_handler.stabilityai import StabilityAPIClient
6from mindsdb.integrations.libs.base import BaseMLEngine
8from mindsdb.utilities import log
10from mindsdb.integrations.utilities.handler_utils import get_api_key
# Module-level logger for this handler module, named after the module.
logger = log.getLogger(__name__)
class StabilityAIHandler(BaseMLEngine):
    """ML engine that generates or transforms images through the Stability AI API.

    A model is configured with a ``USING`` clause whose ``task`` selects one of
    ``text-to-image``, ``image-to-image``, ``image-upscaling`` or
    ``image-masking``. At prediction time every input row is forwarded to the
    matching ``StabilityAPIClient`` method and its return value becomes the
    value of the model's target column.
    """

    name = "stabilityai"

    # Engines used when the USING clause does not name one explicitly.
    _DEFAULT_ENGINE_ID = "stable-diffusion-xl-1024-v1-0"
    _DEFAULT_UPSCALE_ENGINE_ID = "esrgan-v1-x2plus"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate the ``USING`` arguments of a ``CREATE MODEL`` statement.

        Besides validating, this fills in ``engine_id`` and ``upscale_engine_id``
        with defaults when they are absent; the ``using`` dict is mutated in place.

        Args:
            target: name of the predicted column (not used here).
            args: statement arguments; must contain a ``using`` dict.
            **kwargs: may carry ``handler_storage``, used as a fallback source
                for the API key (same lookup as ``_get_stability_client``).

        Raises:
            Exception: if the USING clause, ``task`` or ``local_directory_path``
                is missing, the task is unknown, or an engine id is rejected by
                the client.
        """
        if args is None or 'using' not in args:
            raise Exception("Stability AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        available_tasks = ["text-to-image", "image-to-image", "image-upscaling", "image-masking"]
        # Joined once: concatenating a str with the list itself raised TypeError.
        tasks_str = ', '.join(available_tasks)

        if 'task' not in args:
            raise Exception('task has to be specified. Available tasks are - ' + tasks_str)

        if args['task'] not in available_tasks:
            raise Exception('Unknown task specified. Available tasks are - ' + tasks_str)

        if 'local_directory_path' not in args:
            raise Exception('local_directory_path has to be specified')

        # Resolve the key the same way prediction does, so a key stored on the
        # ML engine (not only one passed in USING) is accepted.
        api_key = get_api_key('stabilityai', args, kwargs.get('handler_storage'), strict=False)
        client = StabilityAPIClient(api_key, args["local_directory_path"])

        engine_defaults = (
            ("engine_id", StabilityAIHandler._DEFAULT_ENGINE_ID),
            ("upscale_engine_id", StabilityAIHandler._DEFAULT_UPSCALE_ENGINE_ID),
        )
        for key, default in engine_defaults:
            if key not in args:
                args[key] = default
            elif not client._is_valid_engine(args[key]):
                engines_str = ', '.join(str(engine) for engine in client.available_engines.keys())
                raise Exception("Unknown engine. The available engines are - " + engines_str)

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Persist the USING arguments, plus the target column name, in model storage.

        Raises:
            Exception: if ``args`` is missing or has no ``using`` entry.
        """
        if args is None or 'using' not in args:
            raise Exception("Stability AI Inference engine requires a USING clause! Refer to its documentation for more details.")
        self.generative = True

        args = args['using']
        args['target'] = target
        self.model_storage.json_set('args', args)

    def _get_stability_client(self, args):
        """Build a ``StabilityAPIClient`` from the stored model arguments.

        Args:
            args: the flattened USING dict saved by ``create``.
        """
        # ``args`` is already the USING dict (``create`` stores ``args['using']``),
        # so it is passed as-is; indexing ``args["using"]`` raised KeyError.
        api_key = get_api_key('stabilityai', args, self.engine_storage, strict=False)
        local_directory_path = args["local_directory_path"]
        engine_id = args.get('engine_id', self._DEFAULT_ENGINE_ID)
        return StabilityAPIClient(api_key=api_key, dir_to_save=local_directory_path, engine=engine_id)

    def _run_task(self, df, args, task_label, supported_params, required_params, call):
        """Validate ``df``'s columns, then apply ``call(row_dict, client)`` row by row.

        Args:
            df: input rows; every column must be in ``supported_params``.
            args: stored model arguments used to build the client.
            task_label: human-readable task name used in error messages.
            supported_params: set of accepted column names.
            required_params: column names that must be present, checked in order.
            call: function ``(row_dict, client) -> result`` producing one prediction.

        Returns:
            The result of ``DataFrame.apply(..., axis=1)`` over the rows.

        Raises:
            Exception: if a required column is missing or an unknown column is given.
        """
        for param in required_params:
            if param not in df.columns:
                raise Exception(f"`{param}` column has to be given in the query.")

        for col in df.columns:
            if col not in supported_params:
                raise Exception(f"Unknown column {col}. Currently supported parameters for {task_label} - {supported_params}")

        client = self._get_stability_client(args)

        return df[df.columns.intersection(supported_params)].apply(lambda row: call(row.to_dict(), client), axis=1)

    def _process_text_image(self, df, args):
        """Generate one image per row from ``text`` (optional ``height``/``width``, default 1024)."""
        def generate_text_image(conds, client):
            return client.text_to_image(prompt=conds.get("text"), height=conds.get("height", 1024), width=conds.get("width", 1024))

        return self._run_task(df, args, "text to image", {"text", "height", "width"}, ("text",), generate_text_image)

    def _process_image_image(self, df, args):
        """Transform ``image_url`` per row, guided by optional ``text`` (size defaults to 1024)."""
        def generate_image_image(conds, client):
            return client.image_to_image(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height", 1024), width=conds.get("width", 1024))

        return self._run_task(df, args, "image to image", {"image_url", "text", "height", "width"}, ("image_url",), generate_image_image)

    def _process_image_upscaling(self, df, args):
        """Upscale ``image_url`` per row; ``text``/``height``/``width`` are passed through when given."""
        def generate_image_upscaling(conds, client):
            return client.image_upscaling(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height"), width=conds.get("width"))

        return self._run_task(df, args, "image scaling", {"image_url", "text", "height", "width"}, ("image_url",), generate_image_upscaling)

    def _process_image_masking(self, df, args):
        """Edit ``image_url`` inside ``mask_image_url`` per row via ``client.image_to_image``."""
        def generate_image_mask(conds, client):
            return client.image_to_image(image_url=conds.get("image_url"), prompt=conds.get("text"), height=conds.get("height"), width=conds.get("width"), mask_image_url=conds.get("mask_image_url"))

        return self._run_task(
            df, args, "image masking",
            {"image_url", "text", "height", "width", "mask_image_url"},
            ("image_url", "mask_image_url"),
            generate_image_mask,
        )

    def predict(self, df, args=None):
        """Run the model's configured task on every row of ``df``.

        The ``args`` parameter is ignored; the arguments saved by ``create`` are used.

        Returns:
            DataFrame with a single column named after the model's target.

        Raises:
            Exception: if the stored task is not supported, or if row validation fails.
        """
        args = self.model_storage.json_get('args')

        processors = {
            "text-to-image": self._process_text_image,
            "image-to-image": self._process_image_image,
            "image-upscaling": self._process_image_upscaling,
            "image-masking": self._process_image_masking,
        }
        task = args.get("task")
        # Previously an unknown task left ``preds`` unbound (UnboundLocalError).
        if task not in processors:
            raise Exception(f"Unknown task {task}. Available tasks are - {', '.join(processors)}")

        preds = processors[task](df, args)

        result_df = pd.DataFrame()
        result_df[args['target']] = preds
        return result_df

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the configured generation and upscaling engines.

        ``attribute`` is accepted for interface compatibility but not used.

        Returns:
            One normalized row per engine, taken from ``client.available_engines``.
        """
        args = self.model_storage.json_get('args')
        # Same key lookup as prediction, so an engine-level key also works here.
        api_key = get_api_key('stabilityai', args, self.engine_storage, strict=False)
        client = StabilityAPIClient(api_key, "")

        # Fall back to the defaults in case validation did not persist them.
        engine_id = args.get("engine_id", self._DEFAULT_ENGINE_ID)
        upscale_engine_id = args.get("upscale_engine_id", self._DEFAULT_UPSCALE_ENGINE_ID)

        engine_id_res = client.available_engines.get(engine_id)
        upscale_engine_id_res = client.available_engines.get(upscale_engine_id)
        return pd.json_normalize([engine_id_res, upscale_engine_id_res])