4
4
using SixLabors . ImageSharp ;
5
5
using SixLabors . ImageSharp . Processing ;
6
6
using SwarmUI . Accounts ;
7
- using SwarmUI . Backends ;
8
7
using SwarmUI . Core ;
9
8
using SwarmUI . Text2Image ;
10
9
using SwarmUI . Utils ;
@@ -80,7 +79,68 @@ public static async Task<JObject> GenerateText2ImageWS(WebSocket socket, Session
80
79
[ API . APIParameter ( "The number of images to generate." ) ] int images ,
81
80
[ API . APIParameter ( "Raw mapping of input should contain general T2I parameters (see listing on Generate tab of main interface) to values, eg `{ \" prompt\" : \" a photo of a cat\" , \" model\" : \" OfficialStableDiffusion/sd_xl_base_1.0\" , \" steps\" : 20, ... }`. Note that this is the root raw map, ie all params go on the same level as `images`, `session_id`, etc." ) ] JObject rawInput )
82
81
{
83
- await API . RunWebsocketHandlerCallWS ( GenT2I_Internal , session , ( images , rawInput ) , socket ) ;
82
+ using CancellationTokenSource cancelTok = new ( ) ;
83
+ bool retain = false , ended = false ;
84
+ using CancellationTokenSource linked = CancellationTokenSource . CreateLinkedTokenSource ( Program . GlobalProgramCancel , cancelTok . Token ) ;
85
+ SharedGenT2IData data = new ( ) ;
86
+ ConcurrentDictionary < Task , Task > tasks = [ ] ;
87
+ static int guessBatchSize ( JObject input )
88
+ {
89
+ if ( input . TryGetValue ( "batchsize" , out JToken batch ) )
90
+ {
91
+ return batch . Value < int > ( ) ;
92
+ }
93
+ return 1 ;
94
+ }
95
+ _ = Utilities . RunCheckedTask ( async ( ) =>
96
+ {
97
+ try
98
+ {
99
+ int batchOffset = images * guessBatchSize ( rawInput ) ;
100
+ while ( ! cancelTok . IsCancellationRequested )
101
+ {
102
+ byte [ ] rec = await socket . ReceiveData ( 1024 * 1024 * 256 , linked . Token ) ;
103
+ Volatile . Write ( ref retain , true ) ;
104
+ if ( socket . State != WebSocketState . Open || cancelTok . IsCancellationRequested || Volatile . Read ( ref ended ) )
105
+ {
106
+ return ;
107
+ }
108
+ JObject newInput = StringConversionHelper . UTF8Encoding . GetString ( rec ) . ParseToJson ( ) ;
109
+ int newImages = newInput . Value < int > ( "images" ) ;
110
+ Task handleMore = API . RunWebsocketHandlerCallWS ( GenT2I_Internal , session , ( newImages , newInput , data , batchOffset ) , socket ) ;
111
+ tasks . TryAdd ( handleMore , handleMore ) ;
112
+ Volatile . Write ( ref retain , false ) ;
113
+ batchOffset += newImages * guessBatchSize ( newInput ) ;
114
+ }
115
+ }
116
+ catch ( TaskCanceledException )
117
+ {
118
+ return ;
119
+ }
120
+ finally
121
+ {
122
+ Volatile . Write ( ref retain , false ) ;
123
+ }
124
+ } ) ;
125
+ Task handle = API . RunWebsocketHandlerCallWS ( GenT2I_Internal , session , ( images , rawInput , data , 0 ) , socket ) ;
126
+ tasks . TryAdd ( handle , handle ) ;
127
+ while ( Volatile . Read ( ref retain ) || tasks . Any ( ) )
128
+ {
129
+ await Task . WhenAny ( tasks . Keys . ToList ( ) ) ;
130
+ foreach ( Task t in tasks . Keys . Where ( t => t . IsCompleted ) . ToList ( ) )
131
+ {
132
+ tasks . TryRemove ( t , out _ ) ;
133
+ }
134
+ if ( tasks . IsEmpty ( ) )
135
+ {
136
+ await socket . SendJson ( new JObject ( ) { [ "socket_intention" ] = "close" } , API . WebsocketTimeout ) ;
137
+ await Task . Delay ( TimeSpan . FromSeconds ( 2 ) ) ; // Give 2 seconds to allow a new gen request before actually closing
138
+ if ( tasks . IsEmpty ( ) )
139
+ {
140
+ Volatile . Write ( ref ended , true ) ;
141
+ }
142
+ }
143
+ }
84
144
await socket . SendJson ( BasicAPIFeatures . GetCurrentStatusRaw ( session ) , API . WebsocketTimeout ) ;
85
145
return null ;
86
146
}
@@ -100,7 +160,7 @@ public static async Task<JObject> GenerateText2Image(Session session,
100
160
[ API . APIParameter ( "The number of images to generate." ) ] int images ,
101
161
[ API . APIParameter ( "Raw mapping of input should contain general T2I parameters (see listing on Generate tab of main interface) to values, eg `{ \" prompt\" : \" a photo of a cat\" , \" model\" : \" OfficialStableDiffusion/sd_xl_base_1.0\" , \" steps\" : 20, ... }`. Note that this is the root raw map, ie all params go on the same level as `images`, `session_id`, etc." ) ] JObject rawInput )
102
162
{
103
- List < JObject > outputs = await API . RunWebsocketHandlerCallDirect ( GenT2I_Internal , session , ( images , rawInput ) ) ;
163
+ List < JObject > outputs = await API . RunWebsocketHandlerCallDirect ( GenT2I_Internal , session , ( images , rawInput , new SharedGenT2IData ( ) , 0 ) ) ;
104
164
Dictionary < int , string > imageOutputs = [ ] ;
105
165
int [ ] discards = null ;
106
166
foreach ( JObject obj in outputs )
@@ -169,10 +229,15 @@ public static T2IParamInput RequestToParams(Session session, JObject rawInput)
169
229
return user_input ;
170
230
}
171
231
232
+ public class SharedGenT2IData
233
+ {
234
+ public int NumExtra , NumNonReal ;
235
+ }
236
+
172
237
/// <summary>Internal route for generating images.</summary>
173
- public static async Task GenT2I_Internal ( Session session , ( int , JObject ) input , Action < JObject > output , bool isWS )
238
+ public static async Task GenT2I_Internal ( Session session , ( int , JObject , SharedGenT2IData , int ) input , Action < JObject > output , bool isWS )
174
239
{
175
- ( int images , JObject rawInput ) = input ;
240
+ ( int images , JObject rawInput , SharedGenT2IData data , int batchOffset ) = input ;
176
241
using Session . GenClaim claim = session . Claim ( gens : images ) ;
177
242
void setError ( string message )
178
243
{
@@ -214,7 +279,6 @@ void removeDoneTasks()
214
279
}
215
280
int max_degrees = session . User . Restrictions . CalcMaxT2ISimultaneous ;
216
281
List < int > discard = [ ] ;
217
- int numExtra = 0 , numNonReal = 0 ;
218
282
int batchSizeExpected = user_input . Get ( T2IParamTypes . BatchSize , 1 ) ;
219
283
void saveImage ( T2IEngine . ImageOutput image , int actualIndex , T2IParamInput thisParams , string metadata )
220
284
{
@@ -255,7 +319,7 @@ void saveImage(T2IEngine.ImageOutput image, int actualIndex, T2IParamInput thisP
255
319
{
256
320
break ;
257
321
}
258
- int imageIndex = i * batchSizeExpected ;
322
+ int imageIndex = i * batchSizeExpected + batchOffset ;
259
323
T2IParamInput thisParams = user_input . Clone ( ) ;
260
324
if ( ! thisParams . Get ( T2IParamTypes . NoSeedIncrement , false ) )
261
325
{
@@ -278,12 +342,12 @@ void saveImage(T2IEngine.ImageOutput image, int actualIndex, T2IParamInput thisP
278
342
numCalls ++ ;
279
343
if ( numCalls > batchSizeExpected )
280
344
{
281
- actualIndex = images * batchSizeExpected + Interlocked . Increment ( ref numExtra ) ;
345
+ actualIndex = images * batchSizeExpected + Interlocked . Increment ( ref data . NumExtra ) ;
282
346
}
283
347
}
284
348
else
285
349
{
286
- actualIndex = - 10 - Interlocked . Increment ( ref numNonReal ) ;
350
+ actualIndex = - 10 - Interlocked . Increment ( ref data . NumNonReal ) ;
287
351
}
288
352
saveImage ( image , actualIndex , thisParams , metadata ) ;
289
353
} ) ) ) ;
0 commit comments