Skip to content

Commit 06a003a

Browse files
committed
changes
1 parent 9e682a3 commit 06a003a

File tree

4 files changed

+83
-140
lines changed

4 files changed

+83
-140
lines changed

Diff for: .gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ notification.mp3
3838
/package-lock.json
3939
/.coverage*
4040

41-
sd-data
41+
sd-data
42+
modules/api/raypi2.py

Diff for: modules/api/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ async def http_exception_handler(request: Request, e: HTTPException):
198198

199199

200200
class Api:
201-
def __init__(self, app: FastAPI):
201+
def __init__(self, app: FastAPI, queue_lock: Lock):
202202
if shared.cmd_opts.api_auth:
203203
self.credentials = {}
204204
for auth in shared.cmd_opts.api_auth.split(","):
@@ -207,7 +207,7 @@ def __init__(self, app: FastAPI):
207207

208208
self.router = APIRouter()
209209
self.app = app
210-
#self.queue_lock = queue_lock
210+
self.queue_lock = queue_lock
211211
api_middleware(self.app)
212212
print("API initialized")
213213
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)

Diff for: modules/api/ray.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from ray import serve
22
import ray
33
from fastapi import FastAPI
4-
from modules.api.raypi import Api
4+
from modules.api.raypi import Raypi
55
from modules import initialize_util
66
from modules import script_callbacks
77
from modules import initialize
88
import time
99

1010
from ray.serve.handle import DeploymentHandle
11-
11+
from modules.call_queue import queue_lock
1212
from modules.shared_cmd_options import cmd_opts
1313

1414
ray.init()
@@ -24,38 +24,25 @@
2424
"per worker. Ignore this if your cluster auto-scales."
2525
)
2626

27-
#initialize.initialize()
28-
#app = FastAPI()
29-
#initialize_util.setup_middleware(app)
30-
#api = Api(app)
31-
#app.include_router(api.router)
32-
#script_callbacks.before_ui_callback()
33-
#script_callbacks.app_started_callback(None, app)
27+
initialize.initialize()
28+
app = FastAPI()
29+
#app.include_router(Raypi(app).router)
30+
initialize_util.setup_middleware(app)
31+
script_callbacks.before_ui_callback()
32+
script_callbacks.app_started_callback(None, app)
33+
34+
3435

3536

36-
def ray_only():
37-
from fastapi import FastAPI
38-
from modules.shared_cmd_options import cmd_opts
39-
from modules import script_callbacks
40-
# Shutdown any existing Serve replicas, if they're still around.
41-
serve.shutdown()
42-
serve.start()
4337

44-
initialize.initialize()
4538

46-
app = FastAPI()
47-
initialize_util.setup_middleware(app)
4839

49-
script_callbacks.before_ui_callback()
50-
script_callbacks.app_started_callback(None, app)
51-
#Api.deploy()
52-
#api = Api(app) # Create an instance of the Api class
53-
serve.run(Api.bind() , port=8000) #route_prefix="/sdapi/v1" # Call the launch_ray method to get the FastAPI app
40+
def ray_only():
41+
serve.shutdown()
42+
serve.start()
43+
serve.run(Raypi.bind(), port=8000) #route_prefix="/sdapi/v1" # Call the launch_ray method to get the FastAPI app
5444

5545

5646
print("Done setting up replicas! Now accepting requests...")
5747
while True:
5848
time.sleep(1000)
59-
60-
61-

Diff for: modules/api/raypi.py

+65-110
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
import piexif.helper
3535
from contextlib import closing
3636

37+
from modules import initialize_util
38+
from modules import script_callbacks
39+
from modules import initialize
40+
3741
from ray import serve
3842

3943
app = FastAPI()
@@ -199,35 +203,26 @@ async def fastapi_exception_handler(request: Request, e: Exception):
199203
async def http_exception_handler(request: Request, e: HTTPException):
200204
return handle_exception(request, e)
201205

202-
from ray import serve
203-
import ray
204-
205-
206-
@serve.deployment(
207-
ray_actor_options={"num_gpus": 1},
208-
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
209-
#route_prefix="/sdapi/v1",
210-
)
206+
@serve.deployment(
207+
ray_actor_options={"num_gpus": 1},
208+
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
209+
)
211210
@serve.ingress(app)
212-
class Api:
213-
def __init__(self, app: FastAPI):
211+
class Raypi:
212+
def __init__(self):
214213
if shared.cmd_opts.api_auth:
215214
self.credentials = {}
216215
for auth in shared.cmd_opts.api_auth.split(","):
217216
user, password = auth.split(":")
218217
self.credentials[user] = password
219-
220-
self.app = app
221-
#self.queue_lock = queue_lock
222-
api_middleware(self.app)
223-
self.launch_ray()
218+
219+
224220
print("API initialized")
225221

226222
self.default_script_arg_txt2img = []
227223
self.default_script_arg_img2img = []
228224

229225

230-
231226
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
232227
if credentials.username in self.credentials:
233228
if compare_digest(credentials.password, self.credentials[credentials.username]):
@@ -337,24 +332,23 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
337332
send_images = args.pop('send_images', True)
338333
args.pop('save_images', None)
339334

340-
with self.queue_lock:
341-
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
342-
p.is_api = True
343-
p.scripts = script_runner
344-
p.outpath_grids = opts.outdir_txt2img_grids
345-
p.outpath_samples = opts.outdir_txt2img_samples
346-
347-
try:
348-
shared.state.begin(job="scripts_txt2img")
349-
if selectable_scripts is not None:
350-
p.script_args = script_args
351-
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
352-
else:
353-
p.script_args = tuple(script_args) # Need to pass args as tuple here
354-
processed = process_images(p)
355-
finally:
356-
shared.state.end()
357-
shared.total_tqdm.clear()
335+
336+
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
337+
p.is_api = True
338+
p.scripts = script_runner
339+
p.outpath_grids = opts.outdir_txt2img_grids
340+
p.outpath_samples = opts.outdir_txt2img_samples
341+
try:
342+
shared.state.begin(job="scripts_txt2img")
343+
if selectable_scripts is not None:
344+
p.script_args = script_args
345+
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
346+
else:
347+
p.script_args = tuple(script_args) # Need to pass args as tuple here
348+
processed = process_images(p)
349+
finally:
350+
shared.state.end()
351+
shared.total_tqdm.clear()
358352

359353
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
360354

@@ -398,25 +392,24 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
398392
send_images = args.pop('send_images', True)
399393
args.pop('save_images', None)
400394

401-
with self.queue_lock:
402-
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
403-
p.init_images = [decode_base64_to_image(x) for x in init_images]
404-
p.is_api = True
405-
p.scripts = script_runner
406-
p.outpath_grids = opts.outdir_img2img_grids
407-
p.outpath_samples = opts.outdir_img2img_samples
408-
409-
try:
410-
shared.state.begin(job="scripts_img2img")
411-
if selectable_scripts is not None:
412-
p.script_args = script_args
413-
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
414-
else:
415-
p.script_args = tuple(script_args) # Need to pass args as tuple here
416-
processed = process_images(p)
417-
finally:
418-
shared.state.end()
419-
shared.total_tqdm.clear()
395+
396+
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
397+
p.init_images = [decode_base64_to_image(x) for x in init_images]
398+
p.is_api = True
399+
p.scripts = script_runner
400+
p.outpath_grids = opts.outdir_img2img_grids
401+
p.outpath_samples = opts.outdir_img2img_samples
402+
try:
403+
shared.state.begin(job="scripts_img2img")
404+
if selectable_scripts is not None:
405+
p.script_args = script_args
406+
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
407+
else:
408+
p.script_args = tuple(script_args) # Need to pass args as tuple here
409+
processed = process_images(p)
410+
finally:
411+
shared.state.end()
412+
shared.total_tqdm.clear()
420413

421414
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
422415

@@ -432,8 +425,8 @@ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
432425

433426
reqDict['image'] = decode_base64_to_image(reqDict['image'])
434427

435-
with self.queue_lock:
436-
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
428+
429+
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
437430

438431
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
439432

@@ -444,8 +437,8 @@ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
444437
image_list = reqDict.pop('imageList', [])
445438
image_folder = [decode_base64_to_image(x.data) for x in image_list]
446439

447-
with self.queue_lock:
448-
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
440+
441+
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
449442

450443
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
451444

@@ -505,14 +498,13 @@ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
505498
img = img.convert('RGB')
506499

507500
# Override object param
508-
with self.queue_lock:
509-
if interrogatereq.model == "clip":
510-
processed = shared.interrogator.interrogate(img)
511-
elif interrogatereq.model == "deepdanbooru":
512-
processed = deepbooru.model.tag(img)
513-
else:
514-
raise HTTPException(status_code=404, detail="Model not found")
515-
501+
if interrogatereq.model == "clip":
502+
processed = shared.interrogator.interrogate(img)
503+
elif interrogatereq.model == "deepdanbooru":
504+
processed = deepbooru.model.tag(img)
505+
else:
506+
raise HTTPException(status_code=404, detail="Model not found")
507+
516508
return models.InterrogateResponse(caption=processed)
517509

518510
@app.post("/interrupt")
@@ -646,13 +638,13 @@ def convert_embeddings(embeddings):
646638

647639
@app.post("/refresh-checkpoints")
648640
def refresh_checkpoints(self):
649-
with self.queue_lock:
650-
shared.refresh_checkpoints()
641+
642+
shared.refresh_checkpoints()
651643

652644
@app.post("/refresh-vae")
653645
def refresh_vae(self):
654-
with self.queue_lock:
655-
shared_items.refresh_vae_list()
646+
647+
shared_items.refresh_vae_list()
656648

657649
@app.post("/create/embedding", response_model=models.CreateResponse)
658650
def create_embedding(self, args: dict):
@@ -794,43 +786,6 @@ def stop_webui(request):
794786
shared.state.server_command = "stop"
795787
return Response("Stopping.")
796788

797-
def launch_ray(self):
798-
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
799-
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
800-
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
801-
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
802-
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
803-
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
804-
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
805-
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
806-
self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
807-
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
808-
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
809-
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
810-
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
811-
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
812-
self.app.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
813-
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
814-
self.app.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
815-
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
816-
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
817-
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
818-
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
819-
self.app.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
820-
self.app.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
821-
self.app.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
822-
self.app.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
823-
self.app.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
824-
self.app.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
825-
self.app.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
826-
self.app.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
827-
self.app.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
828-
self.app.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
829-
self.app.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
830-
self.app.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
831-
self.app.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
832-
833-
if shared.cmd_opts.api_server_stop:
834-
self.app.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
835-
self.app.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
836-
self.app.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
789+
def launch(self, server_name, port, root_path):
790+
self.app.include_router(self.router)
791+
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)

0 commit comments

Comments
 (0)