
Commit 7d19af2

Merge branch 'main' into lstein/feat/simple-mm2-api
2 parents fde58ce + 0dbec3a commit 7d19af2

12 files changed, +819 -699 lines changed
+98
@@ -0,0 +1,98 @@
from typing import Any, Union

import numpy as np
import numpy.typing as npt
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "lblend",
    title="Blend Latents",
    tags=["latents", "blend"],
    category="latents",
    version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
    """Blend two latents using a given alpha. Latents must have same size."""

    latents_a: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    latents_b: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents_a = context.tensors.load(self.latents_a.latents_name)
        latents_b = context.tensors.load(self.latents_b.latents_name)

        if latents_a.shape != latents_b.shape:
            raise Exception("Latents to blend must be the same size.")

        device = TorchDevice.choose_torch_device()

        def slerp(
            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
            v0: Union[torch.Tensor, npt.NDArray[Any]],
            v1: Union[torch.Tensor, npt.NDArray[Any]],
            DOT_THRESHOLD: float = 0.9995,
        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
            """
            Spherical linear interpolation
            Args:
                t (float/np.ndarray): Float value between 0.0 and 1.0
                v0 (np.ndarray): Starting vector
                v1 (np.ndarray): Final vector
                DOT_THRESHOLD (float): Threshold for considering the two vectors as
                    collinear. Not recommended to alter this.
            Returns:
                v2 (np.ndarray): Interpolation vector between v0 and v1
            """
            inputs_are_torch = False
            if not isinstance(v0, np.ndarray):
                inputs_are_torch = True
                v0 = v0.detach().cpu().numpy()
            if not isinstance(v1, np.ndarray):
                inputs_are_torch = True
                v1 = v1.detach().cpu().numpy()

            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1

            if inputs_are_torch:
                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
                return v2_torch
            else:
                assert isinstance(v2, np.ndarray)
                return v2

        # blend
        bl = slerp(self.alpha, latents_a, latents_b)
        assert isinstance(bl, torch.Tensor)
        blended_latents: torch.Tensor = bl  # for type checking convenience

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        blended_latents = blended_latents.to("cpu")

        TorchDevice.empty_cache()

        name = context.tensors.save(tensor=blended_latents)
        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
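The slerp helper above interpolates along the arc between the two noise tensors rather than along the straight line between them, falling back to plain linear interpolation when the inputs are nearly collinear (|dot| > DOT_THRESHOLD). A minimal standalone sketch of the same math on 2-D NumPy vectors (illustrative only, not part of this diff; slerp_demo is a hypothetical name):

import numpy as np

def slerp_demo(t, v0, v1, dot_threshold=0.9995):
    # Normalized dot product decides between spherical and plain linear interpolation.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > dot_threshold:
        return (1 - t) * v0 + t * v1  # nearly collinear: fall back to lerp
    theta_0 = np.arccos(dot)          # angle between the vectors
    theta_t = theta_0 * t             # angle swept at interpolation fraction t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

print(slerp_demo(0.5, np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # ~[0.7071, 0.7071], stays on the unit circle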
@@ -0,0 +1,80 @@
from typing import Optional

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == "float32",
        description=FieldDescriptions.fp32,
        ui_order=4,
    )

    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")
        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        if mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.unsqueeze(0)
        # if shape is not None:
        #     mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
        return mask_tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        if self.image is not None:
            image = context.images.get_pil(self.image.image_name)
            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)
        else:
            image_tensor = None

        mask = self.prep_mask_tensor(
            context.images.get_pil(self.mask.image_name),
        )

        if image_tensor is not None:
            vae_info = context.models.load(self.vae.vae)

            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
            # TODO:
            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

            masked_latents_name = context.tensors.save(tensor=masked_latents)
        else:
            masked_latents_name = None

        mask_name = context.tensors.save(tensor=mask)

        return DenoiseMaskOutput.build(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
            gradient=False,
        )
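Before VAE-encoding, the invocation above binarizes the resized mask at 0.5 and multiplies it into the image tensor, so pixels falling below the threshold are zeroed out. A small illustration of that thresholding step (not from the diff, values chosen arbitrarily):

import torch

soft_mask = torch.tensor([[0.1, 0.4], [0.6, 0.9]])    # soft grayscale mask values in [0, 1]
hard_mask = torch.where(soft_mask < 0.5, 0.0, 1.0)    # below 0.5 -> 0, else 1
image = torch.full((2, 2), 0.5)                       # stand-in for an image tensor
masked_image = image * hard_mask                      # thresholded-out pixels are zeroed before encoding
print(hard_mask)      # tensor([[0., 0.], [1., 1.]])
print(masked_image)   # tensor([[0.0000, 0.0000], [0.5000, 0.5000]])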
@@ -0,0 +1,138 @@
from typing import Literal, Optional

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
    """Outputs a denoise mask and an image representing the total gradient of the mask."""

    denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
    expanded_mask_area: ImageField = OutputField(
        description="Image representing the total gradient area of the mask. For paste-back purposes."
    )


@invocation(
    "create_gradient_mask",
    title="Create Gradient Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
    edge_radius: int = InputField(
        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
    )
    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
    minimum_denoise: float = InputField(
        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
    )
    image: Optional[ImageField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] Image",
        ui_order=6,
    )
    unet: Optional[UNetField] = InputField(
        description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
        default=None,
        input=Input.Connection,
        title="[OPTIONAL] UNet",
        ui_order=5,
    )
    vae: Optional[VAEField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] VAE",
        input=Input.Connection,
        ui_order=7,
    )
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == "float32",
        description=FieldDescriptions.fp32,
        ui_order=9,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GradientMaskOutput:
        mask_image = context.images.get_pil(self.mask.image_name, mode="L")
        if self.edge_radius > 0:
            if self.coherence_mode == "Box Blur":
                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
            else:  # Gaussian Blur OR Staged
                # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))

            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

            # redistribute blur so that the original edges are 0 and blur outwards to 1
            blur_tensor = (blur_tensor - 0.5) * 2

            threshold = 1 - self.minimum_denoise

            if self.coherence_mode == "Staged":
                # wherever the blur_tensor is less than fully masked, convert it to threshold
                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
            else:
                # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)

        else:
            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))

        # compute a [0, 1] mask from the blur_tensor
        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
        expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
        expanded_image_dto = context.images.save(expanded_mask_image)

        masked_latents_name = None
        if self.unet is not None and self.vae is not None and self.image is not None:
            # all three fields must be present at the same time
            main_model_config = context.models.get_config(self.unet.unet.key)
            assert isinstance(main_model_config, MainConfigBase)
            if main_model_config.variant is ModelVariantType.Inpaint:
                mask = blur_tensor
                vae_info: LoadedModel = context.models.load(self.vae.vae)
                image = context.images.get_pil(self.image.image_name)
                image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
                if image_tensor.dim() == 3:
                    image_tensor = image_tensor.unsqueeze(0)
                img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
                masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
                masked_latents = ImageToLatentsInvocation.vae_encode(
                    vae_info, self.fp32, self.tiled, masked_image.clone()
                )
                masked_latents_name = context.tensors.save(tensor=masked_latents)

        return GradientMaskOutput(
            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
        )
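The edge-blur handling above remaps the blurred mask with (blur_tensor - 0.5) * 2 so the original mask edges land at 0, then limits the partially-blurred band by threshold = 1 - minimum_denoise: the blur modes cap values in that band at the threshold, while "Staged" snaps the whole band to it. A hypothetical numeric walk-through (not from the diff), assuming minimum_denoise = 0.25:

import torch

blur = torch.tensor([0.5, 0.6, 0.95, 1.0])   # blurred mask values straight out of the filter
blur = (blur - 0.5) * 2                      # remapped to roughly [0.0, 0.2, 0.9, 1.0]
threshold = 1 - 0.25                         # minimum_denoise = 0.25

# Gaussian/Box Blur branch: values above the threshold but not fully masked get capped.
capped = torch.where((blur > threshold) & (blur < 1), threshold, blur)
print(capped)   # ~[0.00, 0.20, 0.75, 1.00]

# "Staged" branch: everything strictly between 0 and 1 jumps straight to the threshold.
staged = torch.where((blur < 1) & (blur > 0), threshold, blur)
print(staged)   # ~[0.00, 0.75, 0.75, 1.00]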
+61
@@ -0,0 +1,61 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
    "crop_latents",
    title="Crop Latents",
    tags=["latents", "crop"],
    category="latents",
    version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
    """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
    divisible by the latent scale factor of 8.
    """

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    x: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    y: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    width: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    height: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.tensors.load(self.latents.latents_name)

        x1 = self.x // LATENT_SCALE_FACTOR
        y1 = self.y // LATENT_SCALE_FACTOR
        x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
        y2 = y1 + (self.height // LATENT_SCALE_FACTOR)

        cropped_latents = latents[..., y1:y2, x1:x2]

        name = context.tensors.save(tensor=cropped_latents)

        return LatentsOutput.build(latents_name=name, latents=cropped_latents)
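Crop Latents converts the image-space rectangle into latent coordinates by integer-dividing each value by LATENT_SCALE_FACTOR (8) and then slicing the last two tensor dimensions. A hypothetical numeric check (not from the diff): a 512x512 image corresponds to a 64x64 latent, so a crop of (x=64, y=128, width=256, height=256) becomes the latent slice [..., 16:48, 8:40]:

import torch

LATENT_SCALE_FACTOR = 8
latents = torch.zeros(1, 4, 64, 64)    # (batch, channels, height/8, width/8) for a 512x512 image

x, y, width, height = 64, 128, 256, 256                                          # all multiples of 8, as the node requires
x1, y1 = x // LATENT_SCALE_FACTOR, y // LATENT_SCALE_FACTOR                      # 8, 16
x2, y2 = x1 + width // LATENT_SCALE_FACTOR, y1 + height // LATENT_SCALE_FACTOR   # 40, 48

cropped = latents[..., y1:y2, x1:x2]
print(cropped.shape)   # torch.Size([1, 4, 32, 32])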
