11
11
12
12
import torch
13
13
from library .device_utils import init_ipex , get_preferred_device
14
+
14
15
init_ipex ()
15
16
16
17
from torchvision import transforms
17
18
18
19
import library .model_util as model_util
19
20
import library .train_util as train_util
20
21
from library .utils import setup_logging
22
+
21
23
setup_logging ()
22
24
import logging
25
+
23
26
logger = logging .getLogger (__name__ )
24
27
25
28
DEVICE = get_preferred_device ()
@@ -89,7 +92,9 @@ def main(args):
89
92
90
93
# bucketのサイズを計算する
91
94
max_reso = tuple ([int (t ) for t in args .max_resolution .split ("," )])
92
- assert len (max_reso ) == 2 , f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: { args .max_resolution } "
95
+ assert (
96
+ len (max_reso ) == 2
97
+ ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: { args .max_resolution } "
93
98
94
99
bucket_manager = train_util .BucketManager (
95
100
args .bucket_no_upscale , max_reso , args .min_bucket_reso , args .max_bucket_reso , args .bucket_reso_steps
@@ -107,7 +112,7 @@ def main(args):
107
112
def process_batch (is_last ):
108
113
for bucket in bucket_manager .buckets :
109
114
if (is_last and len (bucket ) > 0 ) or len (bucket ) >= args .batch_size :
110
- train_util .cache_batch_latents (vae , True , bucket , args .flip_aug , False )
115
+ train_util .cache_batch_latents (vae , True , bucket , args .flip_aug , args . alpha_mask , False )
111
116
bucket .clear ()
112
117
113
118
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
208
213
parser .add_argument ("in_json" , type = str , help = "metadata file to input / 読み込むメタデータファイル" )
209
214
parser .add_argument ("out_json" , type = str , help = "metadata file to output / メタデータファイル書き出し先" )
210
215
parser .add_argument ("model_name_or_path" , type = str , help = "model name or path to encode latents / latentを取得するためのモデル" )
211
- parser .add_argument ("--v2" , action = "store_true" , help = "not used (for backward compatibility) / 使用されません(互換性のため残してあります)" )
216
+ parser .add_argument (
217
+ "--v2" , action = "store_true" , help = "not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
218
+ )
212
219
parser .add_argument ("--batch_size" , type = int , default = 1 , help = "batch size in inference / 推論時のバッチサイズ" )
213
220
parser .add_argument (
214
221
"--max_data_loader_n_workers" ,
@@ -231,18 +238,32 @@ def setup_parser() -> argparse.ArgumentParser:
231
238
help = "steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します" ,
232
239
)
233
240
parser .add_argument (
234
- "--bucket_no_upscale" , action = "store_true" , help = "make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
241
+ "--bucket_no_upscale" ,
242
+ action = "store_true" ,
243
+ help = "make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ,
235
244
)
236
245
parser .add_argument (
237
- "--mixed_precision" , type = str , default = "no" , choices = ["no" , "fp16" , "bf16" ], help = "use mixed precision / 混合精度を使う場合、その精度"
246
+ "--mixed_precision" ,
247
+ type = str ,
248
+ default = "no" ,
249
+ choices = ["no" , "fp16" , "bf16" ],
250
+ help = "use mixed precision / 混合精度を使う場合、その精度" ,
238
251
)
239
252
parser .add_argument (
240
253
"--full_path" ,
241
254
action = "store_true" ,
242
255
help = "use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)" ,
243
256
)
244
257
parser .add_argument (
245
- "--flip_aug" , action = "store_true" , help = "flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
258
+ "--flip_aug" ,
259
+ action = "store_true" ,
260
+ help = "flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" ,
261
+ )
262
+ parser .add_argument (
263
+ "--alpha_mask" ,
264
+ type = str ,
265
+ default = "" ,
266
+ help = "save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する" ,
246
267
)
247
268
parser .add_argument (
248
269
"--skip_existing" ,
0 commit comments