Skip to content

Commit da6fea3

Browse files
committed
simplify and update alpha mask to work with various cases
1 parent f2dd43e commit da6fea3

10 files changed

+139
-104
lines changed

finetune/prepare_buckets_latents.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111

1212
import torch
1313
from library.device_utils import init_ipex, get_preferred_device
14+
1415
init_ipex()
1516

1617
from torchvision import transforms
1718

1819
import library.model_util as model_util
1920
import library.train_util as train_util
2021
from library.utils import setup_logging
22+
2123
setup_logging()
2224
import logging
25+
2326
logger = logging.getLogger(__name__)
2427

2528
DEVICE = get_preferred_device()
@@ -89,7 +92,9 @@ def main(args):
8992

9093
# bucketのサイズを計算する
9194
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
92-
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
95+
assert (
96+
len(max_reso) == 2
97+
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
9398

9499
bucket_manager = train_util.BucketManager(
95100
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -107,7 +112,7 @@ def main(args):
107112
def process_batch(is_last):
108113
for bucket in bucket_manager.buckets:
109114
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
110-
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
115+
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
111116
bucket.clear()
112117

113118
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
208213
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
209214
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
210215
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
211-
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
216+
parser.add_argument(
217+
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
218+
)
212219
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
213220
parser.add_argument(
214221
"--max_data_loader_n_workers",
@@ -231,18 +238,32 @@ def setup_parser() -> argparse.ArgumentParser:
231238
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
232239
)
233240
parser.add_argument(
234-
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
241+
"--bucket_no_upscale",
242+
action="store_true",
243+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
235244
)
236245
parser.add_argument(
237-
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
246+
"--mixed_precision",
247+
type=str,
248+
default="no",
249+
choices=["no", "fp16", "bf16"],
250+
help="use mixed precision / 混合精度を使う場合、その精度",
238251
)
239252
parser.add_argument(
240253
"--full_path",
241254
action="store_true",
242255
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
243256
)
244257
parser.add_argument(
245-
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
258+
"--flip_aug",
259+
action="store_true",
260+
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
261+
)
262+
parser.add_argument(
263+
"--alpha_mask",
264+
type=str,
265+
default="",
266+
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
246267
)
247268
parser.add_argument(
248269
"--skip_existing",

library/config_util.py

+2
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
214214
DB_SUBSET_DISTINCT_SCHEMA = {
215215
Required("image_dir"): str,
216216
"is_reg": bool,
217+
"alpha_mask": bool,
217218
}
218219
# FT means FineTuning
219220
FT_SUBSET_DISTINCT_SCHEMA = {
220221
Required("metadata_file"): str,
221222
"image_dir": str,
223+
"alpha_mask": bool,
222224
}
223225
CN_SUBSET_ASCENDABLE_SCHEMA = {
224226
"caption_extension": str,

library/custom_train_functions.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
479479
return noise
480480

481481

482-
def apply_masked_loss(loss, mask_image):
483-
# mask image is -1 to 1. we need to convert it to 0 to 1
484-
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
485-
mask_image = mask_image.to(dtype=loss.dtype)
482+
def apply_masked_loss(loss, batch):
    """Weight an element-wise loss by a mask taken from the batch.

    The mask source is chosen in this order:

    1. ``batch["conditioning_images"]`` -- values are expected in -1..1; the
       first (R) channel is used and rescaled to 0..1.
    2. ``batch["alpha_masks"]`` -- values are expected in 0..1 already.
    3. Neither present (or present but ``None``): ``loss`` is returned
       unchanged.

    Args:
        loss: loss tensor whose trailing dimensions (``loss.shape[2:]``) are
            the spatial size the mask is resized to.
        batch: mapping that may contain ``"conditioning_images"`` and/or
            ``"alpha_masks"``.

    Returns:
        ``loss`` multiplied by the resized mask, or ``loss`` itself when no
        mask is available.
    """
    if "conditioning_images" in batch and batch["conditioning_images"] is not None:
        # conditioning image is -1 to 1. we need to convert it to 0 to 1
        mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
        mask_image = mask_image / 2 + 0.5
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # alpha mask is 0 to 1
        mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
        if mask_image.dim() == 3:
            # A mask without a channel axis cannot be resized with a
            # 2-element size; add a singleton channel so it broadcasts over
            # the loss channels.
            mask_image = mask_image.unsqueeze(1)
    else:
        return loss

    # resize to the same size as the loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
    return loss * mask_image
492497

0 commit comments

Comments
 (0)