Skip to content

Commit 4ff3fe0

Browse files
committed
[Experimental] reuse established websockets for multiple image gens
for #380
1 parent ea670ab commit 4ff3fe0

File tree

4 files changed

+100
-19
lines changed

4 files changed

+100
-19
lines changed

src/WebAPI/T2IAPI.cs

+73-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using SixLabors.ImageSharp;
55
using SixLabors.ImageSharp.Processing;
66
using SwarmUI.Accounts;
7-
using SwarmUI.Backends;
87
using SwarmUI.Core;
98
using SwarmUI.Text2Image;
109
using SwarmUI.Utils;
@@ -80,7 +79,68 @@ public static async Task<JObject> GenerateText2ImageWS(WebSocket socket, Session
8079
[API.APIParameter("The number of images to generate.")] int images,
8180
[API.APIParameter("Raw mapping of input should contain general T2I parameters (see listing on Generate tab of main interface) to values, eg `{ \"prompt\": \"a photo of a cat\", \"model\": \"OfficialStableDiffusion/sd_xl_base_1.0\", \"steps\": 20, ... }`. Note that this is the root raw map, ie all params go on the same level as `images`, `session_id`, etc.")] JObject rawInput)
8281
{
83-
await API.RunWebsocketHandlerCallWS(GenT2I_Internal, session, (images, rawInput), socket);
82+
using CancellationTokenSource cancelTok = new();
83+
bool retain = false, ended = false;
84+
using CancellationTokenSource linked = CancellationTokenSource.CreateLinkedTokenSource(Program.GlobalProgramCancel, cancelTok.Token);
85+
SharedGenT2IData data = new();
86+
ConcurrentDictionary<Task, Task> tasks = [];
87+
static int guessBatchSize(JObject input)
88+
{
89+
if (input.TryGetValue("batchsize", out JToken batch))
90+
{
91+
return batch.Value<int>();
92+
}
93+
return 1;
94+
}
95+
_ = Utilities.RunCheckedTask(async () =>
96+
{
97+
try
98+
{
99+
int batchOffset = images * guessBatchSize(rawInput);
100+
while (!cancelTok.IsCancellationRequested)
101+
{
102+
byte[] rec = await socket.ReceiveData(1024 * 1024 * 256, linked.Token);
103+
Volatile.Write(ref retain, true);
104+
if (socket.State != WebSocketState.Open || cancelTok.IsCancellationRequested || Volatile.Read(ref ended))
105+
{
106+
return;
107+
}
108+
JObject newInput = StringConversionHelper.UTF8Encoding.GetString(rec).ParseToJson();
109+
int newImages = newInput.Value<int>("images");
110+
Task handleMore = API.RunWebsocketHandlerCallWS(GenT2I_Internal, session, (newImages, newInput, data, batchOffset), socket);
111+
tasks.TryAdd(handleMore, handleMore);
112+
Volatile.Write(ref retain, false);
113+
batchOffset += newImages * guessBatchSize(newInput);
114+
}
115+
}
116+
catch (TaskCanceledException)
117+
{
118+
return;
119+
}
120+
finally
121+
{
122+
Volatile.Write(ref retain, false);
123+
}
124+
});
125+
Task handle = API.RunWebsocketHandlerCallWS(GenT2I_Internal, session, (images, rawInput, data, 0), socket);
126+
tasks.TryAdd(handle, handle);
127+
while (Volatile.Read(ref retain) || tasks.Any())
128+
{
129+
await Task.WhenAny(tasks.Keys.ToList());
130+
foreach (Task t in tasks.Keys.Where(t => t.IsCompleted).ToList())
131+
{
132+
tasks.TryRemove(t, out _);
133+
}
134+
if (tasks.IsEmpty())
135+
{
136+
await socket.SendJson(new JObject() { ["socket_intention"] = "close" }, API.WebsocketTimeout);
137+
await Task.Delay(TimeSpan.FromSeconds(2)); // Give 2 seconds to allow a new gen request before actually closing
138+
if (tasks.IsEmpty())
139+
{
140+
Volatile.Write(ref ended, true);
141+
}
142+
}
143+
}
84144
await socket.SendJson(BasicAPIFeatures.GetCurrentStatusRaw(session), API.WebsocketTimeout);
85145
return null;
86146
}
@@ -100,7 +160,7 @@ public static async Task<JObject> GenerateText2Image(Session session,
100160
[API.APIParameter("The number of images to generate.")] int images,
101161
[API.APIParameter("Raw mapping of input should contain general T2I parameters (see listing on Generate tab of main interface) to values, eg `{ \"prompt\": \"a photo of a cat\", \"model\": \"OfficialStableDiffusion/sd_xl_base_1.0\", \"steps\": 20, ... }`. Note that this is the root raw map, ie all params go on the same level as `images`, `session_id`, etc.")] JObject rawInput)
102162
{
103-
List<JObject> outputs = await API.RunWebsocketHandlerCallDirect(GenT2I_Internal, session, (images, rawInput));
163+
List<JObject> outputs = await API.RunWebsocketHandlerCallDirect(GenT2I_Internal, session, (images, rawInput, new SharedGenT2IData(), 0));
104164
Dictionary<int, string> imageOutputs = [];
105165
int[] discards = null;
106166
foreach (JObject obj in outputs)
@@ -169,10 +229,15 @@ public static T2IParamInput RequestToParams(Session session, JObject rawInput)
169229
return user_input;
170230
}
171231

232+
public class SharedGenT2IData
233+
{
234+
public int NumExtra, NumNonReal;
235+
}
236+
172237
/// <summary>Internal route for generating images.</summary>
173-
public static async Task GenT2I_Internal(Session session, (int, JObject) input, Action<JObject> output, bool isWS)
238+
public static async Task GenT2I_Internal(Session session, (int, JObject, SharedGenT2IData, int) input, Action<JObject> output, bool isWS)
174239
{
175-
(int images, JObject rawInput) = input;
240+
(int images, JObject rawInput, SharedGenT2IData data, int batchOffset) = input;
176241
using Session.GenClaim claim = session.Claim(gens: images);
177242
void setError(string message)
178243
{
@@ -214,7 +279,6 @@ void removeDoneTasks()
214279
}
215280
int max_degrees = session.User.Restrictions.CalcMaxT2ISimultaneous;
216281
List<int> discard = [];
217-
int numExtra = 0, numNonReal = 0;
218282
int batchSizeExpected = user_input.Get(T2IParamTypes.BatchSize, 1);
219283
void saveImage(T2IEngine.ImageOutput image, int actualIndex, T2IParamInput thisParams, string metadata)
220284
{
@@ -255,7 +319,7 @@ void saveImage(T2IEngine.ImageOutput image, int actualIndex, T2IParamInput thisP
255319
{
256320
break;
257321
}
258-
int imageIndex = i * batchSizeExpected;
322+
int imageIndex = i * batchSizeExpected + batchOffset;
259323
T2IParamInput thisParams = user_input.Clone();
260324
if (!thisParams.Get(T2IParamTypes.NoSeedIncrement, false))
261325
{
@@ -278,12 +342,12 @@ void saveImage(T2IEngine.ImageOutput image, int actualIndex, T2IParamInput thisP
278342
numCalls++;
279343
if (numCalls > batchSizeExpected)
280344
{
281-
actualIndex = images * batchSizeExpected + Interlocked.Increment(ref numExtra);
345+
actualIndex = images * batchSizeExpected + Interlocked.Increment(ref data.NumExtra);
282346
}
283347
}
284348
else
285349
{
286-
actualIndex = -10 - Interlocked.Increment(ref numNonReal);
350+
actualIndex = -10 - Interlocked.Increment(ref data.NumNonReal);
287351
}
288352
saveImage(image, actualIndex, thisParams, metadata);
289353
})));

src/wwwroot/js/genpage/generatehandler.js

+25-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class GenerateHandler {
66
this.totalGenRunTime = 0;
77
this.validateModel = true;
88
this.interrupted = -1;
9+
this.socket = null;
910
this.imageContainerDivId = 'current_image';
1011
this.imageId = 'current_image_img';
1112
this.progressBarHtml = `<div class="image-preview-progress-inner"><div class="image-preview-progress-overall"></div><div class="image-preview-progress-current"></div></div>`;
@@ -95,7 +96,20 @@ class GenerateHandler {
9596
let discardable = {};
9697
let timeLastGenHit = Date.now();
9798
let actualInput = this.getGenInput(input_overrides, input_preoverrides);
98-
makeWSRequestT2I('GenerateText2ImageWS', actualInput, data => {
99+
let socket = null;
100+
let handleData = data => {
101+
if ('socket_intention' in data && data.socket_intention == 'close') {
102+
if (this.socket == socket) {
103+
this.socket = null;
104+
}
105+
if (Object.keys(discardable).length > 0) {
106+
// clear any lingering previews
107+
for (let img of Object.values(images)) {
108+
img.div.remove();
109+
}
110+
}
111+
return;
112+
}
99113
if (isPreview) {
100114
if (data.image) {
101115
this.setCurrentImage(data.image, data.metadata, `${batch_id}_${data.batch_index}`, false, true);
@@ -197,20 +211,22 @@ class GenerateHandler {
197211
this.setCurrentImage(imgs[0].image, imgs[0].metadata);
198212
}
199213
}
200-
if (Object.keys(discardable).length > 0) {
201-
// clear any lingering previews
202-
for (let img of Object.values(images)) {
203-
img.div.remove();
204-
}
205-
}
206214
}
207-
}, e => {
215+
};
216+
let handleError = e => {
208217
console.log(`Error in GenerateText2ImageWS: ${e}, ${this.interrupted}, ${batch_id}`);
209218
if (this.interrupted >= batch_id) {
210219
return;
211220
}
212221
this.hadError(e);
213-
});
222+
};
223+
if (this.socket && this.socket.readyState == WebSocket.OPEN) {
224+
this.socket.send(JSON.stringify(actualInput));
225+
}
226+
else {
227+
socket = makeWSRequestT2I('GenerateText2ImageWS', actualInput, handleData, handleError);
228+
this.socket = socket;
229+
}
214230
};
215231
if (this.validateModel) {
216232
if (getRequiredElementById('current_model').value == '') {

src/wwwroot/js/genpage/main.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ function updateGenCount() {
895895
}
896896

897897
function makeWSRequestT2I(url, in_data, callback, errorHandle = null) {
898-
makeWSRequest(url, in_data, data => {
898+
return makeWSRequest(url, in_data, data => {
899899
if (data.status) {
900900
updateCurrentStatusDirect(data.status);
901901
}

src/wwwroot/js/site.js

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ function makeWSRequest(url, in_data, callback, depth = 0, errorHandle = null, on
129129
callback(data);
130130
}
131131
socket.onerror = errorHandle ? () => errorHandle(genericServerErrorMsg.get()) : genericServerError;
132+
return socket;
132133
}
133134

134135
let failedCrash = translatable(`Failed to send request to server. Did the server crash?`);

0 commit comments

Comments
 (0)