Skip to content

Commit e8cfd4b

Browse files

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed
 

‎library/config_util.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,6 @@ class BaseSubsetParams:
7878
caption_tag_dropout_rate: float = 0.0
7979
token_warmup_min: int = 1
8080
token_warmup_step: float = 0
81-
alpha_mask: bool = False
8281

8382

8483
@dataclass
@@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
8786
class_tokens: Optional[str] = None
8887
caption_extension: str = ".caption"
8988
cache_info: bool = False
89+
alpha_mask: bool = False
9090

9191

9292
@dataclass
9393
class FineTuningSubsetParams(BaseSubsetParams):
9494
metadata_file: Optional[str] = None
95+
alpha_mask: bool = False
9596

9697

9798
@dataclass

‎library/custom_train_functions.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch):
484484
# conditioning image is -1 to 1. we need to convert it to 0 to 1
485485
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
486486
mask_image = mask_image / 2 + 0.5
487+
# print(f"conditioning_image: {mask_image.shape}")
487488
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
488489
# alpha mask is 0 to 1
489-
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
490+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
491+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
490492
else:
491493
return loss
492494

‎library/train_util.py

+12
Original file line number | Diff line number | Diff line change
@@ -561,6 +561,7 @@ def __init__(
561561

562562
super().__init__(
563563
image_dir,
564+
False, # alpha_mask
564565
num_repeats,
565566
shuffle_caption,
566567
caption_separator,
@@ -1947,6 +1948,7 @@ def __init__(
19471948
None,
19481949
subset.caption_extension,
19491950
subset.cache_info,
1951+
False,
19501952
subset.num_repeats,
19511953
subset.shuffle_caption,
19521954
subset.caption_separator,
@@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
21962198
return False
21972199
if npz["alpha_mask"].shape[0:2] != reso: # HxW
21982200
return False
2201+
else:
2202+
if "alpha_mask" in npz:
2203+
return False
21992204
except Exception as e:
22002205
logger.error(f"Error loading file: {npz_path}")
22012206
raise e
@@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False):
22962301
if os.name == "nt":
22972302
cv2.imshow("cond_img", cond_img)
22982303

2304+
if "alpha_masks" in example and example["alpha_masks"] is not None:
2305+
alpha_mask = example["alpha_masks"][j]
2306+
logger.info(f"alpha mask size: {alpha_mask.size()}")
2307+
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
2308+
if os.name == "nt":
2309+
cv2.imshow("alpha_mask", alpha_mask)
2310+
22992311
if os.name == "nt": # only windows
23002312
cv2.imshow("img", im)
23012313
k = cv2.waitKey()

0 commit comments

Comments
 (0)
Please sign in to comment.