34
34
import piexif .helper
35
35
from contextlib import closing
36
36
37
+ from modules import initialize_util
38
+ from modules import script_callbacks
39
+ from modules import initialize
40
+
37
41
from ray import serve
38
42
39
43
app = FastAPI ()
@@ -199,35 +203,26 @@ async def fastapi_exception_handler(request: Request, e: Exception):
199
203
async def http_exception_handler (request : Request , e : HTTPException ):
200
204
return handle_exception (request , e )
201
205
202
- from ray import serve
203
- import ray
204
-
205
-
206
- @serve .deployment (
207
- ray_actor_options = {"num_gpus" : 1 },
208
- autoscaling_config = {"min_replicas" : 0 , "max_replicas" : 2 },
209
- #route_prefix="/sdapi/v1",
210
- )
206
+ @serve .deployment (
207
+ ray_actor_options = {"num_gpus" : 1 },
208
+ autoscaling_config = {"min_replicas" : 0 , "max_replicas" : 2 },
209
+ )
211
210
@serve .ingress (app )
212
- class Api :
213
- def __init__ (self , app : FastAPI ):
211
+ class Raypi :
212
+ def __init__ (self ):
214
213
if shared .cmd_opts .api_auth :
215
214
self .credentials = {}
216
215
for auth in shared .cmd_opts .api_auth .split ("," ):
217
216
user , password = auth .split (":" )
218
217
self .credentials [user ] = password
219
-
220
- self .app = app
221
- #self.queue_lock = queue_lock
222
- api_middleware (self .app )
223
- self .launch_ray ()
218
+
219
+
224
220
print ("API initialized" )
225
221
226
222
self .default_script_arg_txt2img = []
227
223
self .default_script_arg_img2img = []
228
224
229
225
230
-
231
226
def auth (self , credentials : HTTPBasicCredentials = Depends (HTTPBasic ())):
232
227
if credentials .username in self .credentials :
233
228
if compare_digest (credentials .password , self .credentials [credentials .username ]):
@@ -337,24 +332,23 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
337
332
send_images = args .pop ('send_images' , True )
338
333
args .pop ('save_images' , None )
339
334
340
- with self .queue_lock :
341
- with closing (StableDiffusionProcessingTxt2Img (sd_model = shared .sd_model , ** args )) as p :
342
- p .is_api = True
343
- p .scripts = script_runner
344
- p .outpath_grids = opts .outdir_txt2img_grids
345
- p .outpath_samples = opts .outdir_txt2img_samples
346
-
347
- try :
348
- shared .state .begin (job = "scripts_txt2img" )
349
- if selectable_scripts is not None :
350
- p .script_args = script_args
351
- processed = scripts .scripts_txt2img .run (p , * p .script_args ) # Need to pass args as list here
352
- else :
353
- p .script_args = tuple (script_args ) # Need to pass args as tuple here
354
- processed = process_images (p )
355
- finally :
356
- shared .state .end ()
357
- shared .total_tqdm .clear ()
335
+
336
+ with closing (StableDiffusionProcessingTxt2Img (sd_model = shared .sd_model , ** args )) as p :
337
+ p .is_api = True
338
+ p .scripts = script_runner
339
+ p .outpath_grids = opts .outdir_txt2img_grids
340
+ p .outpath_samples = opts .outdir_txt2img_samples
341
+ try :
342
+ shared .state .begin (job = "scripts_txt2img" )
343
+ if selectable_scripts is not None :
344
+ p .script_args = script_args
345
+ processed = scripts .scripts_txt2img .run (p , * p .script_args ) # Need to pass args as list here
346
+ else :
347
+ p .script_args = tuple (script_args ) # Need to pass args as tuple here
348
+ processed = process_images (p )
349
+ finally :
350
+ shared .state .end ()
351
+ shared .total_tqdm .clear ()
358
352
359
353
b64images = list (map (encode_pil_to_base64 , processed .images )) if send_images else []
360
354
@@ -398,25 +392,24 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
398
392
send_images = args .pop ('send_images' , True )
399
393
args .pop ('save_images' , None )
400
394
401
- with self .queue_lock :
402
- with closing (StableDiffusionProcessingImg2Img (sd_model = shared .sd_model , ** args )) as p :
403
- p .init_images = [decode_base64_to_image (x ) for x in init_images ]
404
- p .is_api = True
405
- p .scripts = script_runner
406
- p .outpath_grids = opts .outdir_img2img_grids
407
- p .outpath_samples = opts .outdir_img2img_samples
408
-
409
- try :
410
- shared .state .begin (job = "scripts_img2img" )
411
- if selectable_scripts is not None :
412
- p .script_args = script_args
413
- processed = scripts .scripts_img2img .run (p , * p .script_args ) # Need to pass args as list here
414
- else :
415
- p .script_args = tuple (script_args ) # Need to pass args as tuple here
416
- processed = process_images (p )
417
- finally :
418
- shared .state .end ()
419
- shared .total_tqdm .clear ()
395
+
396
+ with closing (StableDiffusionProcessingImg2Img (sd_model = shared .sd_model , ** args )) as p :
397
+ p .init_images = [decode_base64_to_image (x ) for x in init_images ]
398
+ p .is_api = True
399
+ p .scripts = script_runner
400
+ p .outpath_grids = opts .outdir_img2img_grids
401
+ p .outpath_samples = opts .outdir_img2img_samples
402
+ try :
403
+ shared .state .begin (job = "scripts_img2img" )
404
+ if selectable_scripts is not None :
405
+ p .script_args = script_args
406
+ processed = scripts .scripts_img2img .run (p , * p .script_args ) # Need to pass args as list here
407
+ else :
408
+ p .script_args = tuple (script_args ) # Need to pass args as tuple here
409
+ processed = process_images (p )
410
+ finally :
411
+ shared .state .end ()
412
+ shared .total_tqdm .clear ()
420
413
421
414
b64images = list (map (encode_pil_to_base64 , processed .images )) if send_images else []
422
415
@@ -432,8 +425,8 @@ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
432
425
433
426
reqDict ['image' ] = decode_base64_to_image (reqDict ['image' ])
434
427
435
- with self . queue_lock :
436
- result = postprocessing .run_extras (extras_mode = 0 , image_folder = "" , input_dir = "" , output_dir = "" , save_output = False , ** reqDict )
428
+
429
+ result = postprocessing .run_extras (extras_mode = 0 , image_folder = "" , input_dir = "" , output_dir = "" , save_output = False , ** reqDict )
437
430
438
431
return models .ExtrasSingleImageResponse (image = encode_pil_to_base64 (result [0 ][0 ]), html_info = result [1 ])
439
432
@@ -444,8 +437,8 @@ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
444
437
image_list = reqDict .pop ('imageList' , [])
445
438
image_folder = [decode_base64_to_image (x .data ) for x in image_list ]
446
439
447
- with self . queue_lock :
448
- result = postprocessing .run_extras (extras_mode = 1 , image_folder = image_folder , image = "" , input_dir = "" , output_dir = "" , save_output = False , ** reqDict )
440
+
441
+ result = postprocessing .run_extras (extras_mode = 1 , image_folder = image_folder , image = "" , input_dir = "" , output_dir = "" , save_output = False , ** reqDict )
449
442
450
443
return models .ExtrasBatchImagesResponse (images = list (map (encode_pil_to_base64 , result [0 ])), html_info = result [1 ])
451
444
@@ -505,14 +498,13 @@ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
505
498
img = img .convert ('RGB' )
506
499
507
500
# Override object param
508
- with self .queue_lock :
509
- if interrogatereq .model == "clip" :
510
- processed = shared .interrogator .interrogate (img )
511
- elif interrogatereq .model == "deepdanbooru" :
512
- processed = deepbooru .model .tag (img )
513
- else :
514
- raise HTTPException (status_code = 404 , detail = "Model not found" )
515
-
501
+ if interrogatereq .model == "clip" :
502
+ processed = shared .interrogator .interrogate (img )
503
+ elif interrogatereq .model == "deepdanbooru" :
504
+ processed = deepbooru .model .tag (img )
505
+ else :
506
+ raise HTTPException (status_code = 404 , detail = "Model not found" )
507
+
516
508
return models .InterrogateResponse (caption = processed )
517
509
518
510
@app .post ("/interrupt" )
@@ -646,13 +638,13 @@ def convert_embeddings(embeddings):
646
638
647
639
@app .post ("/refresh-checkpoints" )
648
640
def refresh_checkpoints (self ):
649
- with self . queue_lock :
650
- shared .refresh_checkpoints ()
641
+
642
+ shared .refresh_checkpoints ()
651
643
652
644
@app .post ("/refresh-vae" )
653
645
def refresh_vae (self ):
654
- with self . queue_lock :
655
- shared_items .refresh_vae_list ()
646
+
647
+ shared_items .refresh_vae_list ()
656
648
657
649
@app .post ("/create/embedding" , response_model = models .CreateResponse )
658
650
def create_embedding (self , args : dict ):
@@ -794,43 +786,6 @@ def stop_webui(request):
794
786
shared .state .server_command = "stop"
795
787
return Response ("Stopping." )
796
788
797
- def launch_ray (self ):
798
- self .app .add_api_route ("/sdapi/v1/txt2img" , self .text2imgapi , methods = ["POST" ], response_model = models .TextToImageResponse )
799
- self .app .add_api_route ("/sdapi/v1/img2img" , self .img2imgapi , methods = ["POST" ], response_model = models .ImageToImageResponse )
800
- self .app .add_api_route ("/sdapi/v1/extra-single-image" , self .extras_single_image_api , methods = ["POST" ], response_model = models .ExtrasSingleImageResponse )
801
- self .app .add_api_route ("/sdapi/v1/extra-batch-images" , self .extras_batch_images_api , methods = ["POST" ], response_model = models .ExtrasBatchImagesResponse )
802
- self .app .add_api_route ("/sdapi/v1/png-info" , self .pnginfoapi , methods = ["POST" ], response_model = models .PNGInfoResponse )
803
- self .app .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = models .ProgressResponse )
804
- self .app .add_api_route ("/sdapi/v1/interrogate" , self .interrogateapi , methods = ["POST" ])
805
- self .app .add_api_route ("/sdapi/v1/interrupt" , self .interruptapi , methods = ["POST" ])
806
- self .app .add_api_route ("/sdapi/v1/skip" , self .skip , methods = ["POST" ])
807
- self .app .add_api_route ("/sdapi/v1/options" , self .get_config , methods = ["GET" ], response_model = models .OptionsModel )
808
- self .app .add_api_route ("/sdapi/v1/options" , self .set_config , methods = ["POST" ])
809
- self .app .add_api_route ("/sdapi/v1/cmd-flags" , self .get_cmd_flags , methods = ["GET" ], response_model = models .FlagsModel )
810
- self .app .add_api_route ("/sdapi/v1/samplers" , self .get_samplers , methods = ["GET" ], response_model = List [models .SamplerItem ])
811
- self .app .add_api_route ("/sdapi/v1/upscalers" , self .get_upscalers , methods = ["GET" ], response_model = List [models .UpscalerItem ])
812
- self .app .add_api_route ("/sdapi/v1/latent-upscale-modes" , self .get_latent_upscale_modes , methods = ["GET" ], response_model = List [models .LatentUpscalerModeItem ])
813
- self .app .add_api_route ("/sdapi/v1/sd-models" , self .get_sd_models , methods = ["GET" ], response_model = List [models .SDModelItem ])
814
- self .app .add_api_route ("/sdapi/v1/sd-vae" , self .get_sd_vaes , methods = ["GET" ], response_model = List [models .SDVaeItem ])
815
- self .app .add_api_route ("/sdapi/v1/hypernetworks" , self .get_hypernetworks , methods = ["GET" ], response_model = List [models .HypernetworkItem ])
816
- self .app .add_api_route ("/sdapi/v1/face-restorers" , self .get_face_restorers , methods = ["GET" ], response_model = List [models .FaceRestorerItem ])
817
- self .app .add_api_route ("/sdapi/v1/realesrgan-models" , self .get_realesrgan_models , methods = ["GET" ], response_model = List [models .RealesrganItem ])
818
- self .app .add_api_route ("/sdapi/v1/prompt-styles" , self .get_prompt_styles , methods = ["GET" ], response_model = List [models .PromptStyleItem ])
819
- self .app .add_api_route ("/sdapi/v1/embeddings" , self .get_embeddings , methods = ["GET" ], response_model = models .EmbeddingsResponse )
820
- self .app .add_api_route ("/sdapi/v1/refresh-checkpoints" , self .refresh_checkpoints , methods = ["POST" ])
821
- self .app .add_api_route ("/sdapi/v1/refresh-vae" , self .refresh_vae , methods = ["POST" ])
822
- self .app .add_api_route ("/sdapi/v1/create/embedding" , self .create_embedding , methods = ["POST" ], response_model = models .CreateResponse )
823
- self .app .add_api_route ("/sdapi/v1/create/hypernetwork" , self .create_hypernetwork , methods = ["POST" ], response_model = models .CreateResponse )
824
- self .app .add_api_route ("/sdapi/v1/preprocess" , self .preprocess , methods = ["POST" ], response_model = models .PreprocessResponse )
825
- self .app .add_api_route ("/sdapi/v1/train/embedding" , self .train_embedding , methods = ["POST" ], response_model = models .TrainResponse )
826
- self .app .add_api_route ("/sdapi/v1/train/hypernetwork" , self .train_hypernetwork , methods = ["POST" ], response_model = models .TrainResponse )
827
- self .app .add_api_route ("/sdapi/v1/memory" , self .get_memory , methods = ["GET" ], response_model = models .MemoryResponse )
828
- self .app .add_api_route ("/sdapi/v1/unload-checkpoint" , self .unloadapi , methods = ["POST" ])
829
- self .app .add_api_route ("/sdapi/v1/reload-checkpoint" , self .reloadapi , methods = ["POST" ])
830
- self .app .add_api_route ("/sdapi/v1/scripts" , self .get_scripts_list , methods = ["GET" ], response_model = models .ScriptsList )
831
- self .app .add_api_route ("/sdapi/v1/script-info" , self .get_script_info , methods = ["GET" ], response_model = List [models .ScriptInfo ])
832
-
833
- if shared .cmd_opts .api_server_stop :
834
- self .app .add_api_route ("/sdapi/v1/server-kill" , self .kill_webui , methods = ["POST" ])
835
- self .app .add_api_route ("/sdapi/v1/server-restart" , self .restart_webui , methods = ["POST" ])
836
- self .app .add_api_route ("/sdapi/v1/server-stop" , self .stop_webui , methods = ["POST" ])
789
+ def launch (self , server_name , port , root_path ):
790
+ self .app .include_router (self .router )
791
+ uvicorn .run (self .app , host = server_name , port = port , timeout_keep_alive = shared .cmd_opts .timeout_keep_alive , root_path = root_path )
0 commit comments