Skip to content

Commit e09a6d0

Browse files
committed
raypi working
1 parent 06a003a commit e09a6d0

File tree

6 files changed

+80
-30
lines changed

6 files changed

+80
-30
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ notification.mp3
4040

4141
sd-data
4242
modules/api/raypi2.py
43+
modules/api/raypi3.py

Diff for: modules/api/api.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from fastapi.encoders import jsonable_encoder
1717
from secrets import compare_digest
1818

19+
from modules import initialize
20+
initialize.imports()
21+
1922
import modules.shared as shared
2023
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
2124
from modules.api import models
@@ -33,7 +36,9 @@
3336
import piexif
3437
import piexif.helper
3538
from contextlib import closing
39+
from ray import serve
3640

41+
app = FastAPI()
3742

3843
def script_name_to_index(name, scripts):
3944
try:
@@ -197,18 +202,28 @@ async def http_exception_handler(request: Request, e: HTTPException):
197202
return handle_exception(request, e)
198203

199204

205+
api_middleware(app)
206+
207+
208+
@serve.deployment(
209+
ray_actor_options={"num_gpus": 1},
210+
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
211+
#route_prefix="/sdapi/v1",
212+
)
213+
@serve.ingress(app)
200214
class Api:
201-
def __init__(self, app: FastAPI, queue_lock: Lock):
215+
def __init__(self):
216+
initialize.initialize()
202217
if shared.cmd_opts.api_auth:
203218
self.credentials = {}
204219
for auth in shared.cmd_opts.api_auth.split(","):
205220
user, password = auth.split(":")
206221
self.credentials[user] = password
207222

208223
self.router = APIRouter()
209-
self.app = app
210-
self.queue_lock = queue_lock
211-
api_middleware(self.app)
224+
225+
226+
212227
print("API initialized")
213228
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
214229
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)

Diff for: modules/api/ray.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22
import ray
33
from fastapi import FastAPI
44
from modules.api.raypi import Raypi
5-
from modules import initialize_util
6-
from modules import script_callbacks
7-
from modules import initialize
5+
from modules.api.api import Raypi
86
import time
97

10-
from ray.serve.handle import DeploymentHandle
11-
from modules.call_queue import queue_lock
12-
from modules.shared_cmd_options import cmd_opts
8+
139

1410
ray.init()
1511
#ray.init("ray://localhost:10001")
@@ -24,23 +20,31 @@
2420
"per worker. Ignore this if your cluster auto-scales."
2521
)
2622

27-
initialize.initialize()
2823
app = FastAPI()
29-
#app.include_router(Raypi(app).router)
30-
initialize_util.setup_middleware(app)
31-
script_callbacks.before_ui_callback()
32-
script_callbacks.app_started_callback(None, app)
3324

25+
@serve.deployment(
26+
ray_actor_options={"num_gpus": 1},
27+
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
28+
#route_prefix="/sdapi/v1",
29+
)
30+
@serve.ingress(app)
31+
class RayDeployment:
32+
def __init__(self):
33+
pass
3434

3535

36+
# 2: Deploy the deployment.
3637

3738

3839

3940

4041
def ray_only():
4142
serve.shutdown()
4243
serve.start()
43-
serve.run(Raypi.bind(), port=8000) #route_prefix="/sdapi/v1" # Call the launch_ray method to get the FastAPI app
44+
#Raypi.deploy()
45+
46+
47+
serve.run(Raypi.bind(), port=8000, route_prefix="/sdapi/v1") #route_prefix="/sdapi/v1" # Call the launch_ray method to get the FastAPI app
4448

4549

4650
print("Done setting up replicas! Now accepting requests...")

Diff for: modules/api/raypi.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from fastapi.encoders import jsonable_encoder
1717
from secrets import compare_digest
1818

19+
from modules import initialize
20+
initialize.imports()
21+
1922
import modules.shared as shared
2023
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
2124
from modules.api import models
@@ -36,8 +39,9 @@
3639

3740
from modules import initialize_util
3841
from modules import script_callbacks
39-
from modules import initialize
4042

43+
44+
import launch
4145
from ray import serve
4246

4347
app = FastAPI()
@@ -203,20 +207,31 @@ async def fastapi_exception_handler(request: Request, e: Exception):
203207
async def http_exception_handler(request: Request, e: HTTPException):
204208
return handle_exception(request, e)
205209

206-
@serve.deployment(
207-
ray_actor_options={"num_gpus": 1},
208-
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
209-
)
210+
211+
api_middleware(app)
212+
213+
@serve.deployment(
214+
ray_actor_options={"num_gpus": 1},
215+
autoscaling_config={"min_replicas": 0, "max_replicas": 2},
216+
#route_prefix="/sdapi/v1",
217+
)
210218
@serve.ingress(app)
211219
class Raypi:
212220
def __init__(self):
221+
print("Initializing API")
222+
initialize.initialize()
223+
print("preparing env")
224+
launch.prepare_environment()
225+
#app.include_router(Raypi(app).router)
226+
213227
if shared.cmd_opts.api_auth:
214228
self.credentials = {}
215229
for auth in shared.cmd_opts.api_auth.split(","):
216230
user, password = auth.split(":")
217231
self.credentials[user] = password
218-
219-
232+
print("preparing env")
233+
launch.prepare_environment()
234+
220235
print("API initialized")
221236

222237
self.default_script_arg_txt2img = []
@@ -785,7 +800,3 @@ def restart_webui(self):
785800
def stop_webui(request):
786801
shared.state.server_command = "stop"
787802
return Response("Stopping.")
788-
789-
def launch(self, server_name, port, root_path):
790-
self.app.include_router(self.router)
791-
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)

Diff for: modules/initialize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def initialize():
5151
initialize_util.fix_torch_version()
5252
initialize_util.fix_asyncio_event_loop_policy()
5353
initialize_util.validate_tls_options()
54-
initialize_util.configure_sigint_handler()
54+
#initialize_util.configure_sigint_handler()
5555
initialize_util.configure_opts_onchange()
5656

5757
from modules import modelloader

Diff for: webui.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,25 @@ def create_api(app):
2424
return api
2525

2626

27+
def ray_api():
28+
from modules.api. ray import ray_only
29+
30+
from modules.shared_cmd_options import cmd_opts
31+
32+
launch_api = cmd_opts.api
33+
initialize.initialize()
34+
35+
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
36+
37+
script_callbacks.before_ui_callback()
38+
startup_timer.record("scripts before_ui_callback")
39+
shared.demo = ui.create_ui()
40+
startup_timer.record("create ui")
41+
if not cmd_opts.no_gradio_queue:
42+
shared.demo.queue(64)
43+
ray_only()
44+
45+
2746
def api_only():
2847
from fastapi import FastAPI
2948
from modules.shared_cmd_options import cmd_opts
@@ -156,11 +175,11 @@ def webui():
156175

157176
if __name__ == "__main__":
158177
from modules.shared_cmd_options import cmd_opts
159-
from modules.api.ray import ray_only
178+
160179

161180
if cmd_opts.nowebui:
162181
api_only()
163182
elif cmd_opts.ray:
164-
ray_only()
183+
ray_api()
165184
else:
166185
webui()

0 commit comments

Comments
 (0)