45
45
46
46
47
47
class DiffusionCanvasWindow (QMainWindow ):
48
- canvas_image_tensor : torch .Tensor | None
48
+ gpu_canvas_image_tensor : torch .Tensor | None
49
+ cpu_canvas_image_tensor : torch .Tensor
50
+ cpu_canvas_q_image : QImage
49
51
show_noisy : bool
50
52
dirty_region_full : Bounds2D | None
51
53
dirty_region_quick : Bounds2D | None
@@ -172,13 +174,74 @@ def initialize_canvas(self, layer: Layer):
172
174
self .show_noisy = False
173
175
self .dirty_region_full : Bounds2D | None = None
174
176
self .dirty_region_quick : Bounds2D | None = None
175
- self .canvas_image_tensor = None
177
+ self .gpu_canvas_image_tensor = None
178
+
179
+ image_size = (
180
+ layer .clean_latent .shape [2 ] * latent_size_in_pixels ,
181
+ layer .clean_latent .shape [3 ] * latent_size_in_pixels
182
+ )
183
+
184
+ # Create a numpy array as the backing store
185
+ numpy_buffer = np .zeros ((image_size [0 ], image_size [1 ], 4 ), dtype = np .uint8 ) # RGBA format
186
+
187
+ # Create a QImage using the numpy buffer
188
+ self .cpu_canvas_q_image = QImage (
189
+ numpy_buffer .data , # Pointer to the data
190
+ image_size [1 ], # width
191
+ image_size [0 ], # height
192
+ QImage .Format .Format_RGB32 # Format
193
+ )
194
+
195
+ # Create a PyTorch tensor that shares the same memory
196
+ self .cpu_canvas_image_tensor = torch .from_numpy (numpy_buffer )
197
+
176
198
self .history = History (layer )
177
199
self .create_undo = True
178
200
179
201
# Update the display with the new canvas
180
202
self .update_canvas_view (noisy = False , full = True )
181
203
204
@staticmethod
@torch.no_grad()
def _get_cpu_image_tensor(tensor: torch.Tensor, add_alpha: bool = True) -> torch.Tensor:
    """
    Convert a decoded image tensor into the byte layout used by QImage.

    The input uses the layout produced by stable diffusion VAE decoders:
    (1, RGB, height, width) with values nominally in the range (0, 1).
    The output is channels-last, in B, G, R order, as unsigned bytes on the
    CPU, optionally followed by an opaque alpha byte so each pixel is 4 bytes
    wide (the layout this file pairs with QImage Format_RGB32).

    Args:
        tensor (torch.Tensor): Input tensor with shape (1, 3, height, width).
            Only the first three channels are used.
        add_alpha (bool): Whether to add a dummy alpha channel for QImage Format_RGB32.

    Returns:
        torch.Tensor: uint8 CPU tensor with shape (height, width, 4) when
        ``add_alpha`` is true, otherwise (height, width, 3).

    Raises:
        ValueError: If ``tensor`` is not 4-dimensional, its batch size is not 1,
            or it has fewer than three channels.
    """
    # Validate explicitly instead of with `assert`, which is removed under `python -O`.
    if tensor.ndim != 4 or tensor.shape[0] != 1 or tensor.shape[1] < 3:
        raise ValueError(
            "Expected a tensor of shape (1, 3, height, width), "
            f"got {tuple(tensor.shape)}."
        )

    # Drop the batch axis, reverse RGB -> BGR, and move channels last.
    # `flip` avoids building an index tensor on every call.
    bgr = tensor[0, :3].flip(0).permute(1, 2, 0)  # Shape: (height, width, BGR)

    # Map and clamp range (0, 1) to (0, 255), then truncate to bytes while still on
    # the source device so the alpha concatenation works on 1-byte elements.
    pixels = (bgr * 255).clamp(0, 255).to(torch.uint8)

    # Add a dummy, fully opaque alpha channel if required
    if add_alpha:
        alpha_channel = torch.full(
            (pixels.shape[0], pixels.shape[1], 1),
            255,
            dtype=torch.uint8,
            device=pixels.device
        )
        pixels = torch.cat((pixels, alpha_channel), dim=2)  # Shape: (height, width, BGRA)

    # Move to the CPU as the final step
    return pixels.to(device='cpu')
+
182
245
def closeEvent (self , event ):
183
246
with ExceptionCatcher (self , "Failed to handle close event" ):
184
247
"""
@@ -327,7 +390,7 @@ def update_frame(self):
327
390
self .full_preview_timer = 0
328
391
329
392
if self .showing_quick_preview :
330
- self .update_canvas_view (full = True )
393
+ self .update_canvas_view (full = True , region = None )
331
394
332
395
def canvas_mousePressEvent (self , event ):
333
396
with ExceptionCatcher (self , "Failed to handle mouse event" ):
@@ -374,7 +437,7 @@ def apply_brush(self, event: QMouseEvent):
374
437
normalized_mouse_coord = normalized_position
375
438
)
376
439
377
- self .update_canvas_view (noisy = result .show_noisy , modified_bounds = result .modified_bounds , full = False )
440
+ self .update_canvas_view (noisy = result .show_noisy , region = result .modified_bounds , full = False )
378
441
379
442
def keyPressEvent (self , event : QKeyEvent ):
380
443
with ExceptionCatcher (self , "Failed to handle key press event" ):
@@ -409,10 +472,9 @@ def redo(self):
409
472
self .history .redo (1 )
410
473
self .update_canvas_view (full = False )
411
474
412
- def update_canvas_view (self , noisy : bool | None = None , modified_bounds : Bounds2D | None = None , full : bool = True ):
475
+ def update_canvas_view (self , noisy : bool | None = None , region : Bounds2D | str | None = 'all' , full : bool = True ):
413
476
414
477
from utils .time_utils import Timer
415
- import utils .texture_convert as conv
416
478
417
479
if isinstance (noisy , bool ):
418
480
self .show_noisy = noisy
@@ -434,19 +496,26 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
434
496
else :
435
497
self .showing_quick_preview = False
436
498
437
- if modified_bounds is None :
438
- self .dirty_region_quick = full_bounds
439
- self .dirty_region_full = full_bounds
440
- else :
499
+ if region is not None :
500
+ if region == 'all' :
501
+ region = full_bounds
502
+
503
+ if isinstance (region , str ):
504
+ region = None
505
+
506
+ if not isinstance (region , Bounds2D ):
507
+ region = None
508
+
509
+ if isinstance (region , Bounds2D ):
441
510
self .dirty_region_quick = (
442
- modified_bounds
511
+ region
443
512
if self .dirty_region_quick is None
444
- else self .dirty_region_quick .get_encapsulated (modified_bounds )
513
+ else self .dirty_region_quick .get_encapsulated (region )
445
514
)
446
515
self .dirty_region_full = (
447
- modified_bounds
516
+ region
448
517
if self .dirty_region_full is None
449
- else self .dirty_region_full .get_encapsulated (modified_bounds )
518
+ else self .dirty_region_full .get_encapsulated (region )
450
519
)
451
520
452
521
region_to_redraw = (
@@ -462,8 +531,8 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
462
531
region_to_redraw = region_to_redraw .get_clipped (full_bounds )
463
532
region_to_redraw_with_padding = region_to_redraw_with_padding .get_clipped (full_bounds )
464
533
465
- if region_to_redraw_with_padding == full_bounds or self .canvas_image_tensor is None :
466
- self .canvas_image_tensor = self .api .latent_to_image_tiled (
534
+ if region_to_redraw_with_padding == full_bounds or self .gpu_canvas_image_tensor is None :
535
+ self .gpu_canvas_image_tensor = self .api .latent_to_image_tiled (
467
536
latent_to_show ,
468
537
max_tile_size_latents = 64 ,
469
538
overlap_size_latents = 8 ,
@@ -504,19 +573,22 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
504
573
]
505
574
506
575
with Timer ("Write to Image Tensor" ):
507
- self .canvas_image_tensor [
576
+ self .gpu_canvas_image_tensor [
508
577
:, :,
509
578
region_to_redraw .y_bounds [0 ] * latent_size_in_pixels :
510
579
region_to_redraw .y_bounds [1 ] * latent_size_in_pixels ,
511
580
region_to_redraw .x_bounds [0 ] * latent_size_in_pixels :
512
581
region_to_redraw .x_bounds [1 ] * latent_size_in_pixels
513
582
] = decoded_trimmed
514
583
515
- with Timer ("Convert to QImage" ):
516
- q_image = conv .convert (self .canvas_image_tensor , QImage )
584
+ with Timer ("Convert and Pass to CPU" ):
585
+ cpu_image_tensor = self ._get_cpu_image_tensor (self .gpu_canvas_image_tensor )
586
+
587
+ with Timer ("Write to CPU buffer" ):
588
+ self .cpu_canvas_image_tensor [:, :, :] = cpu_image_tensor
517
589
518
590
with Timer ("Update Canvas" ):
519
- self .canvas_view .update_image (q_image )
591
+ self .canvas_view .update_image (self . cpu_canvas_q_image )
520
592
521
593
if full :
522
594
self .dirty_region_full = None
0 commit comments