
Commit fe2aa32

adjust min/max bucket reso divisible by reso steps #1632
1 parent ce49ced

9 files changed: +48 −8 lines

README.md (+2)

@@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 - bitsandbytes, transformers, accelerate and huggingface_hub are updated.
 - If you encounter any issues, please report them.
 
+- Fixed a bug where `min_bucket_reso`/`max_bucket_reso` in the dataset configuration did not produce the correct resolution buckets when the values were not divisible by `bucket_reso_steps`. A warning is now shown for such values, and they are automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
+
 - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
   - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
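For illustration, here is a minimal standalone sketch of the rounding rule described in the changelog entry above (the helper name and values are hypothetical, not part of this commit): the minimum is rounded down and the maximum is rounded up, so the usable bucket range is never narrowed.

```python
def round_bucket_range(min_reso: int, max_reso: int, steps: int) -> tuple[int, int]:
    # round the minimum down and the maximum up to the nearest multiple of `steps`
    if min_reso % steps != 0:
        min_reso -= min_reso % steps
    if max_reso % steps != 0:
        max_reso += steps - max_reso % steps
    return min_reso, max_reso

print(round_bucket_range(300, 1000, 64))  # -> (256, 1024)
```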

docs/config_README-en.md (+2)

@@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d
 
 * `batch_size`
   * This corresponds to the command-line argument `--train_batch_size`.
+* `max_bucket_reso`, `min_bucket_reso`
+  * Specify the maximum and minimum resolutions of the bucket. They must be divisible by `bucket_reso_steps`.
 
 These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
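For context, a minimal sketch of where these keys sit in a dataset configuration, parsed here with Python's standard `tomllib`; the values and surrounding layout are illustrative assumptions, not taken from this commit:

```python
import tomllib  # Python 3.11+ standard library

config_text = """
[[datasets]]
resolution = 1024
batch_size = 4
min_bucket_reso = 256   # must be divisible by bucket_reso_steps (e.g. 64)
max_bucket_reso = 2048  # must be divisible by bucket_reso_steps (e.g. 64)
"""

# parse the TOML text and read the per-dataset bucket bounds
datasets = tomllib.loads(config_text)["datasets"]
print(datasets[0]["min_bucket_reso"], datasets[0]["max_bucket_reso"])  # 256 2048
```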

docs/config_README-ja.md (+2)

(The Japanese documentation receives the same addition; translated below.)

@@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
 
 * `batch_size`
   * Equivalent to the command-line argument `--train_batch_size`.
+* `max_bucket_reso`, `min_bucket_reso`
+  * Specify the maximum and minimum bucket resolutions. They must be divisible by `bucket_reso_steps`.
 
 These settings are fixed per dataset.
 That means that subsets belonging to the same dataset share these settings.

fine_tune.py (+2)

@@ -91,6 +91,8 @@ def train(args):
     ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
     collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
 
+    train_dataset_group.verify_bucket_reso_steps(64)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
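The training scripts now call `train_dataset_group.verify_bucket_reso_steps(64)` before the debug-dataset branch. The implementation is not part of this diff; a plausible sketch of such a check, assuming each dataset exposes its `bucket_reso_steps`, might look like:

```python
def verify_bucket_reso_steps(self, min_steps: int) -> None:
    # hedged sketch: ensure every dataset's bucket step size is compatible with
    # the model's required resolution granularity (64 px here)
    for dataset in self.datasets:
        steps = getattr(dataset, "bucket_reso_steps", None)
        assert steps is None or steps % min_steps == 0, (
            f"bucket_reso_steps must be a multiple of {min_steps}, got {steps}"
        )
```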

library/train_util.py (+34 −6)

@@ -653,6 +653,34 @@ def __init__(
         # caching
         self.caching_mode = None  # None, 'latents', 'text'
 
+    def adjust_min_max_bucket_reso_by_steps(
+        self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
+    ) -> Tuple[int, int]:
+        # make min/max bucket reso to be multiple of bucket_reso_steps
+        if min_bucket_reso % bucket_reso_steps != 0:
+            adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
+            logger.warning(
+                f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
+                f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
+            )
+            min_bucket_reso = adjusted_min_bucket_reso
+        if max_bucket_reso % bucket_reso_steps != 0:
+            adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
+            logger.warning(
+                f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
+                f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
+            )
+            max_bucket_reso = adjusted_max_bucket_reso
+
+        assert (
+            min(resolution) >= min_bucket_reso
+        ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
+        assert (
+            max(resolution) <= max_bucket_reso
+        ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
+
+        return min_bucket_reso, max_bucket_reso
+
     def set_seed(self, seed):
         self.seed = seed
 
@@ -1533,12 +1561,9 @@ def __init__(
 
         self.enable_bucket = enable_bucket
         if self.enable_bucket:
-            assert (
-                min(resolution) >= min_bucket_reso
-            ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
-            assert (
-                max(resolution) <= max_bucket_reso
-            ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
+            min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+            )
         self.min_bucket_reso = min_bucket_reso
         self.max_bucket_reso = max_bucket_reso
         self.bucket_reso_steps = bucket_reso_steps
 
@@ -1901,6 +1926,9 @@ def __init__(
 
         self.enable_bucket = enable_bucket
         if self.enable_bucket:
+            min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+            )
         self.min_bucket_reso = min_bucket_reso
         self.max_bucket_reso = max_bucket_reso
         self.bucket_reso_steps = bucket_reso_steps
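A quick worked example of the new method's behavior (the dataset instance `ds` and the values are illustrative only, not from this commit):

```python
# min_bucket_reso=300 is rounded down to 256 and max_bucket_reso=1000 up to 1024;
# two warnings are logged, then the assertions pass because 256 <= 1024 <= 1024
min_r, max_r = ds.adjust_min_max_bucket_reso_by_steps((1024, 1024), 300, 1000, 64)
print(min_r, max_r)  # 256 1024
```

Note the design choice: rounding min down and max up can only widen the bucket range, so previously valid configurations keep working after the adjustment.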

train_controlnet.py (+2)

@@ -107,6 +107,8 @@ def train(args):
     ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
     collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
 
+    train_dataset_group.verify_bucket_reso_steps(64)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return

train_db.py (+2)

@@ -93,6 +93,8 @@ def train(args):
     if args.no_token_padding:
         train_dataset_group.disable_token_padding()
 
+    train_dataset_group.verify_bucket_reso_steps(64)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return

train_network.py (+1 −1)

@@ -95,7 +95,7 @@ def generate_step_logs(
         return logs
 
     def assert_extra_args(self, args, train_dataset_group):
-        pass
+        train_dataset_group.verify_bucket_reso_steps(64)
 
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

train_textual_inversion.py (+1 −1)

@@ -99,7 +99,7 @@ def __init__(self):
         self.is_sdxl = False
 
     def assert_extra_args(self, args, train_dataset_group):
-        pass
+        train_dataset_group.verify_bucket_reso_steps(64)
 
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
