Skip to content

Commit fe269dc

Browse files
committed
Add support to run scripts from API
For more context see this pr AUTOMATIC1111/stable-diffusion-webui#6469
1 parent a2cee02 commit fe269dc

File tree

3 files changed

+203
-44
lines changed

3 files changed

+203
-44
lines changed

README.md

+77
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,83 @@ result4.images[1]
126126
```
127127
![extra_batch_images_2](https://user-images.githubusercontent.com/1288793/200459542-aa8547a0-f6db-436b-bec1-031a93a7b1d4.jpg)
128128

129+
### Scripts support
130+
Scripts from AUTOMATIC1111's Web UI are supported, but there aren't official models that define a script's interface.
131+
132+
To find out the list of arguments that are accepted by a particular script look up the associated python file from
133+
AUTOMATIC1111's repo `scripts/[script_name].py`. Search for its `run(p, **args)` function and the arguments that come
134+
after 'p' is the list of accepted arguments
135+
136+
#### Example for X/Y Plot script:
137+
```
138+
(scripts/xy_grid.py file from AUTOMATIC1111's repo)
139+
140+
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
141+
...
142+
```
143+
List of accepted arguments:
144+
* _x_type_: Index of the axis for X axis. Indexes start from [0: Nothing]
145+
* _x_values_: String of comma-separated values for the X axis
146+
* _y_type_: Index of the axis type for Y axis. As the X axis, indexes start from [0: Nothing]
147+
* _y_values_: String of comma-separated values for the Y axis
148+
* _draw_legend_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
149+
* _include_lone_images_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
150+
* _no_fixed_seeds_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
151+
```
152+
# Available Axis options
153+
XYPlotAvailableScripts = [
154+
"Nothing",
155+
"Seed",
156+
"Var. seed",
157+
"Var. strength",
158+
"Steps",
159+
"CFG Scale",
160+
"Prompt S/R",
161+
"Prompt order",
162+
"Sampler",
163+
"Checkpoint Name",
164+
"Hypernetwork",
165+
"Hypernet str.",
166+
"Sigma Churn",
167+
"Sigma min",
168+
"Sigma max",
169+
"Sigma noise",
170+
"Eta",
171+
"Clip skip",
172+
"Denoising",
173+
"Hires upscaler",
174+
"Cond. Image Mask Weight",
175+
"VAE",
176+
"Styles"
177+
]
178+
179+
# Example call
180+
XAxisType = "Steps"
181+
XAxisValues = "8,16,32,64"
182+
YAxisType = "Sampler"
183+
YAxisValues = "k_euler_a, k_euler, k_lms, plms, k_heun, ddim, k_dpm_2, k_dpm_2_a"
184+
drawLegend = "True"
185+
includeSeparateImages = "False"
186+
keepRandomSeed = "False"
187+
188+
result = api.txt2img(
189+
prompt="cute squirrel",
190+
negative_prompt="ugly, out of frame",
191+
seed=1003,
192+
styles=["anime"],
193+
cfg_scale=7,
194+
script_name="X/Y Plot",
195+
script_args=[
196+
XYPlotAvailableScripts.index(XAxisType),
197+
XAxisValues,
198+
XYPlotAvailableScripts.index(YAxisType),
199+
YAxisValues,
200+
drawLegend,
201+
includeSeparateImages,
202+
keepRandomSeed
203+
]
204+
)
205+
```
129206

130207
### Configuration APIs
131208
```

webuiapi/webuiapi.py

+49-38
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class WebUIApiResult:
1111
images: list
1212
parameters: dict
1313
info: dict
14-
14+
1515
@property
1616
def image(self):
1717
return self.images[0]
@@ -35,28 +35,28 @@ def __init__(self,
3535
baseurl = f'https://{host}:{port}/sdapi/v1'
3636
else:
3737
baseurl = f'http://{host}:{port}/sdapi/v1'
38-
38+
3939
self.baseurl = baseurl
4040
self.default_sampler = sampler
4141
self.default_steps = steps
42-
42+
4343
self.session = requests.Session()
44-
44+
4545
def set_auth(self, username, password):
4646
self.session.auth = (username, password)
47-
47+
4848
def _to_api_result(self, response):
49-
49+
5050
if response.status_code != 200:
5151
raise RuntimeError(response.status_code, response.text)
52-
52+
5353
r = response.json()
5454
images = []
5555
if 'images' in r.keys():
5656
images = [Image.open(io.BytesIO(base64.b64decode(i))) for i in r['images']]
5757
elif 'image' in r.keys():
5858
images = [Image.open(io.BytesIO(base64.b64decode(r['image'])))]
59-
59+
6060
info = ''
6161
if 'info' in r.keys():
6262
try:
@@ -71,8 +71,8 @@ def _to_api_result(self, response):
7171
parameters = r['parameters']
7272

7373
return WebUIApiResult(images, parameters, info)
74-
75-
def txt2img(self,
74+
75+
def txt2img(self,
7676
enable_hr=False,
7777
denoising_strength=0.0,
7878
firstphase_width=0,
@@ -99,18 +99,22 @@ def txt2img(self,
9999
s_noise=1,
100100
override_settings={},
101101
override_settings_restore_afterwards=True,
102-
sampler_name=None, # use this instead of sampler_index
102+
sampler_name=None, # use this instead of sampler_index
103103
sampler_index=None,
104104
steps=None,
105-
):
105+
script_name=None,
106+
script_args=None # List of arguments for the script "script_name"
107+
):
106108
if sampler_index is None:
107109
sampler_index = self.default_sampler
108110
if sampler_name is None:
109111
sampler_name = self.default_sampler
110112
if steps is None:
111113
steps = self.default_steps
114+
if script_args is None:
115+
script_args = []
112116

113-
payload = {
117+
payload = {
114118
"enable_hr": enable_hr,
115119
"denoising_strength": denoising_strength,
116120
"firstphase_width": firstphase_width,
@@ -140,14 +144,15 @@ def txt2img(self,
140144
"override_settings_restore_afterwards": override_settings_restore_afterwards,
141145
"sampler_name": sampler_name,
142146
"sampler_index": sampler_index,
147+
"script_name": script_name,
148+
"script_args": script_args
143149
}
144150
response = self.session.post(url=f'{self.baseurl}/txt2img', json=payload)
145151
return self._to_api_result(response)
146152

147-
148153
def img2img(self,
149-
images=[], # list of PIL Image
150-
mask_image=None, # PIL Image mask
154+
images=[], # list of PIL Image
155+
mask_image=None, # PIL Image mask
151156
resize_mode=0,
152157
denoising_strength=0.75,
153158
mask_blur=4,
@@ -180,15 +185,19 @@ def img2img(self,
180185
override_settings_restore_afterwards=True,
181186
include_init_images=False,
182187
steps=None,
183-
sampler_name=None, # use this instead of sampler_index
188+
sampler_name=None, # use this instead of sampler_index
184189
sampler_index=None,
185-
):
190+
script_name=None,
191+
script_args=None # List of arguments for the script "script_name"
192+
):
186193
if sampler_name is None:
187194
sampler_name = self.default_sampler
188195
if sampler_index is None:
189196
sampler_index = self.default_sampler
190197
if steps is None:
191198
steps = self.default_steps
199+
if script_args is None:
200+
script_args = []
192201

193202
payload = {
194203
"init_images": [b64_img(x) for x in images],
@@ -226,15 +235,17 @@ def img2img(self,
226235
"sampler_name": sampler_name,
227236
"sampler_index": sampler_index,
228237
"include_init_images": include_init_images,
238+
"script_name": script_name,
239+
"script_args": script_args
229240
}
230241
if mask_image is not None:
231-
payload['mask']= b64_img(mask_image)
232-
242+
payload['mask'] = b64_img(mask_image)
243+
233244
response = self.session.post(url=f'{self.baseurl}/img2img', json=payload)
234245
return self._to_api_result(response)
235246

236247
def extra_single_image(self,
237-
image, # PIL Image
248+
image, # PIL Image
238249
resize_mode=0,
239250
show_extras_results=True,
240251
gfpgan_visibility=0,
@@ -248,7 +259,7 @@ def extra_single_image(self,
248259
upscaler_2="None",
249260
extras_upscaler_2_visibility=0,
250261
upscale_first=False,
251-
):
262+
):
252263
payload = {
253264
"resize_mode": resize_mode,
254265
"show_extras_results": show_extras_results,
@@ -265,13 +276,13 @@ def extra_single_image(self,
265276
"upscale_first": upscale_first,
266277
"image": b64_img(image),
267278
}
268-
279+
269280
response = self.session.post(url=f'{self.baseurl}/extra-single-image', json=payload)
270281
return self._to_api_result(response)
271282

272283
def extra_batch_images(self,
273-
images, # list of PIL images
274-
name_list=None, # list of image names
284+
images, # list of PIL images
285+
name_list=None, # list of image names
275286
resize_mode=0,
276287
show_extras_results=True,
277288
gfpgan_visibility=0,
@@ -285,21 +296,21 @@ def extra_batch_images(self,
285296
upscaler_2="None",
286297
extras_upscaler_2_visibility=0,
287298
upscale_first=False,
288-
):
299+
):
289300
if name_list is not None:
290301
if len(name_list) != len(images):
291302
raise RuntimeError('len(images) != len(name_list)')
292303
else:
293-
name_list = [f'image{i+1:05}' for i in range(len(images))]
304+
name_list = [f'image{i + 1:05}' for i in range(len(images))]
294305
images = [b64_img(x) for x in images]
295-
306+
296307
image_list = []
297308
for name, image in zip(name_list, images):
298309
image_list.append({
299310
"data": image,
300311
"name": name
301312
})
302-
313+
303314
payload = {
304315
"resize_mode": resize_mode,
305316
"show_extras_results": show_extras_results,
@@ -316,16 +327,16 @@ def extra_batch_images(self,
316327
"upscale_first": upscale_first,
317328
"imageList": image_list,
318329
}
319-
330+
320331
response = self.session.post(url=f'{self.baseurl}/extra-batch-images', json=payload)
321332
return self._to_api_result(response)
322-
333+
323334
# XXX 500 error (2022/12/26)
324335
def png_info(self, image):
325336
payload = {
326337
"image": b64_img(image),
327338
}
328-
339+
329340
response = self.session.post(url=f'{self.baseurl}/png-info', json=payload)
330341
return self._to_api_result(response)
331342

@@ -334,20 +345,20 @@ def interrogate(self, image):
334345
payload = {
335346
"image": b64_img(image),
336347
}
337-
348+
338349
response = self.session.post(url=f'{self.baseurl}/interrogate', json=payload)
339350
return self._to_api_result(response)
340351

341-
def get_options(self):
352+
def get_options(self):
342353
response = self.session.get(url=f'{self.baseurl}/options')
343354
return response.json()
344355

345356
# working (2022/11/21)
346-
def set_options(self, options):
357+
def set_options(self, options):
347358
response = self.session.post(url=f'{self.baseurl}/options', json=options)
348359
return response.json()
349360

350-
def get_cmd_flags(self):
361+
def get_cmd_flags(self):
351362
response = self.session.get(url=f'{self.baseurl}/cmd-flags')
352363
return response.json()
353364
def get_samplers(self):
@@ -380,7 +391,7 @@ def get_artists(self):
380391
def refresh_checkpoints(self):
381392
response = self.session.post(url=f'{self.baseurl}/refresh-checkpoints')
382393
return response.json()
383-
394+
384395
def get_endpoint(self, endpoint, baseurl):
385396
if baseurl:
386397
return f'{self.baseurl}/{endpoint}'
@@ -431,7 +442,7 @@ def util_get_current_model(self):
431442
return self.get_options()['sd_model_checkpoint']
432443

433444

434-
class Upscaler(str, Enum):
445+
class Upscaler(str, Enum):
435446
none = 'None'
436447
Lanczos = 'Lanczos'
437448
Nearest = 'Nearest'

webuiapi_demo.ipynb

+77-6
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)