Skip to content

Commit 85c9947

Browse files
committed
Made the QImage and the CPU tensor share the same memory, so that updating one also updates the other. Also made the full preview update only the dirty region.
1 parent fdcdba8 commit 85c9947

File tree

2 files changed

+97
-20
lines changed

2 files changed

+97
-20
lines changed

common.py

+5
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,11 @@ def transform_bounds(self, other: 'Bounds2D'):
108108
)
109109

110110
def __eq__(self, other):
111+
if not hasattr(other, "x_bounds"):
112+
return False
113+
if not hasattr(other, "y_bounds"):
114+
return False
115+
111116
return self.x_bounds == other.x_bounds and self.y_bounds == other.y_bounds
112117

113118
def _get_span(self):

ui.py

+92-20
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,9 @@
4545

4646

4747
class DiffusionCanvasWindow(QMainWindow):
48-
canvas_image_tensor: torch.Tensor | None
48+
gpu_canvas_image_tensor: torch.Tensor | None
49+
cpu_canvas_image_tensor: torch.Tensor
50+
cpu_canvas_q_image: QImage
4951
show_noisy: bool
5052
dirty_region_full: Bounds2D | None
5153
dirty_region_quick: Bounds2D | None
@@ -172,13 +174,74 @@ def initialize_canvas(self, layer: Layer):
172174
self.show_noisy = False
173175
self.dirty_region_full: Bounds2D | None = None
174176
self.dirty_region_quick: Bounds2D | None = None
175-
self.canvas_image_tensor = None
177+
self.gpu_canvas_image_tensor = None
178+
179+
image_size = (
180+
layer.clean_latent.shape[2] * latent_size_in_pixels,
181+
layer.clean_latent.shape[3] * latent_size_in_pixels
182+
)
183+
184+
# Create a numpy array as the backing store
185+
numpy_buffer = np.zeros((image_size[0], image_size[1], 4), dtype=np.uint8) # 4 bytes per pixel, written as BGRA to match the QImage Format_RGB32 buffer
186+
187+
# Create a QImage using the numpy buffer
188+
self.cpu_canvas_q_image = QImage(
189+
numpy_buffer.data, # Pointer to the data
190+
image_size[1], # width
191+
image_size[0], # height
192+
QImage.Format.Format_RGB32 # Format
193+
)
194+
195+
# Create a PyTorch tensor that shares the same memory
196+
self.cpu_canvas_image_tensor = torch.from_numpy(numpy_buffer)
197+
176198
self.history = History(layer)
177199
self.create_undo = True
178200

179201
# Update the display with the new canvas
180202
self.update_canvas_view(noisy=False, full=True)
181203

204+
@staticmethod
205+
@torch.no_grad()
206+
def _get_cpu_image_tensor(tensor: torch.Tensor, add_alpha: bool = True):
207+
"""
208+
Convert a tensor from the format used by stable diffusion VAE decoders
209+
to a format used by QImage.
210+
211+
Args:
212+
tensor (torch.Tensor): Input tensor with shape (1, 3, height, width).
213+
add_alpha (bool): Whether to add a dummy alpha channel for QImage Format_RGB32.
214+
215+
Returns:
216+
torch.Tensor: Tensor with shape (height, width, 4) or (height, width, 3).
217+
"""
218+
# Ensure batch size is 1 and remove it
219+
assert tensor.shape[0] == 1, "Tensor batch size must be 1."
220+
221+
tensor = tensor.squeeze(0) # Shape: (RGB, height, width)
222+
223+
# Rearrange channels to BGR if needed for QImage
224+
tensor = tensor[[2, 1, 0], :, :] # Shape: (BGR, height, width)
225+
226+
# Permute to (height, width, channels)
227+
tensor = tensor.permute(1, 2, 0) # Shape: (height, width, BGR)
228+
229+
# Map and clamp range (0, 1) to (0, 255)
230+
tensor = (tensor * 255).clamp(0, 255)
231+
232+
# Add a dummy alpha channel if required
233+
if add_alpha:
234+
alpha_channel = torch.full(
235+
(tensor.shape[0], tensor.shape[1], 1),
236+
255,
237+
dtype=tensor.dtype, # Match dtype of tensor (still likely float16 or float32)
238+
device=tensor.device
239+
)
240+
tensor = torch.cat((tensor, alpha_channel), dim=2) # Shape: (height, width, BGRA)
241+
242+
# Convert to uint8 on the CPU as the final step
243+
return tensor.to(dtype=torch.uint8, device='cpu')
244+
182245
def closeEvent(self, event):
183246
with ExceptionCatcher(self, "Failed to handle close event"):
184247
"""
@@ -327,7 +390,7 @@ def update_frame(self):
327390
self.full_preview_timer = 0
328391

329392
if self.showing_quick_preview:
330-
self.update_canvas_view(full=True)
393+
self.update_canvas_view(full=True, region=None)
331394

332395
def canvas_mousePressEvent(self, event):
333396
with ExceptionCatcher(self, "Failed to handle mouse event"):
@@ -374,7 +437,7 @@ def apply_brush(self, event: QMouseEvent):
374437
normalized_mouse_coord=normalized_position
375438
)
376439

377-
self.update_canvas_view(noisy=result.show_noisy, modified_bounds=result.modified_bounds, full=False)
440+
self.update_canvas_view(noisy=result.show_noisy, region=result.modified_bounds, full=False)
378441

379442
def keyPressEvent(self, event: QKeyEvent):
380443
with ExceptionCatcher(self, "Failed to handle key press event"):
@@ -409,10 +472,9 @@ def redo(self):
409472
self.history.redo(1)
410473
self.update_canvas_view(full=False)
411474

412-
def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2D | None = None, full: bool = True):
475+
def update_canvas_view(self, noisy: bool | None = None, region: Bounds2D | str | None = 'all', full: bool = True):
413476

414477
from utils.time_utils import Timer
415-
import utils.texture_convert as conv
416478

417479
if isinstance(noisy, bool):
418480
self.show_noisy = noisy
@@ -434,19 +496,26 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
434496
else:
435497
self.showing_quick_preview = False
436498

437-
if modified_bounds is None:
438-
self.dirty_region_quick = full_bounds
439-
self.dirty_region_full = full_bounds
440-
else:
499+
if region is not None:
500+
if region == 'all':
501+
region = full_bounds
502+
503+
if isinstance(region, str):
504+
region = None
505+
506+
if not isinstance(region, Bounds2D):
507+
region = None
508+
509+
if isinstance(region, Bounds2D):
441510
self.dirty_region_quick = (
442-
modified_bounds
511+
region
443512
if self.dirty_region_quick is None
444-
else self.dirty_region_quick.get_encapsulated(modified_bounds)
513+
else self.dirty_region_quick.get_encapsulated(region)
445514
)
446515
self.dirty_region_full = (
447-
modified_bounds
516+
region
448517
if self.dirty_region_full is None
449-
else self.dirty_region_full.get_encapsulated(modified_bounds)
518+
else self.dirty_region_full.get_encapsulated(region)
450519
)
451520

452521
region_to_redraw = (
@@ -462,8 +531,8 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
462531
region_to_redraw = region_to_redraw.get_clipped(full_bounds)
463532
region_to_redraw_with_padding = region_to_redraw_with_padding.get_clipped(full_bounds)
464533

465-
if region_to_redraw_with_padding == full_bounds or self.canvas_image_tensor is None:
466-
self.canvas_image_tensor = self.api.latent_to_image_tiled(
534+
if region_to_redraw_with_padding == full_bounds or self.gpu_canvas_image_tensor is None:
535+
self.gpu_canvas_image_tensor = self.api.latent_to_image_tiled(
467536
latent_to_show,
468537
max_tile_size_latents=64,
469538
overlap_size_latents=8,
@@ -504,19 +573,22 @@ def update_canvas_view(self, noisy: bool | None = None, modified_bounds: Bounds2
504573
]
505574

506575
with Timer("Write to Image Tensor"):
507-
self.canvas_image_tensor[
576+
self.gpu_canvas_image_tensor[
508577
:, :,
509578
region_to_redraw.y_bounds[0] * latent_size_in_pixels:
510579
region_to_redraw.y_bounds[1] * latent_size_in_pixels,
511580
region_to_redraw.x_bounds[0] * latent_size_in_pixels:
512581
region_to_redraw.x_bounds[1] * latent_size_in_pixels
513582
] = decoded_trimmed
514583

515-
with Timer("Convert to QImage"):
516-
q_image = conv.convert(self.canvas_image_tensor, QImage)
584+
with Timer("Convert and Pass to CPU"):
585+
cpu_image_tensor = self._get_cpu_image_tensor(self.gpu_canvas_image_tensor)
586+
587+
with Timer("Write to CPU buffer"):
588+
self.cpu_canvas_image_tensor[:, :, :] = cpu_image_tensor
517589

518590
with Timer("Update Canvas"):
519-
self.canvas_view.update_image(q_image)
591+
self.canvas_view.update_image(self.cpu_canvas_q_image)
520592

521593
if full:
522594
self.dirty_region_full = None

0 commit comments

Comments
 (0)