
Commit 56b4ea9

Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722
1 parent 9c757c2 commit 56b4ea9

File tree: 4 files changed (README.md, networks/resize_lora.py, networks/sdxl_merge_lora.py, networks/svd_merge_lora.py), +32 -21 lines changed


README.md (+8)

@@ -137,6 +137,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### Oct 26, 2024 / 2024-10-26:
+
+- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
+- It will be included in the next release.
+
+- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of the LoRA metadata was not calculated correctly when the precision at save time differed from the precision used for the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
+- The above will be included in the next release.
+
 ### Sep 13, 2024 / 2024-09-13:
 
 - `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
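
To make the failure mode concrete, here is a minimal sketch of why hashing before the precision cast produces a stale hash. It is illustrative only: it uses a plain SHA-256 over the serialized safetensors bytes rather than the model/legacy hashes computed by `train_util.precalculate_safetensors_hashes`, and the tensor name is made up.

```python
# Illustrative sketch only: a plain SHA-256 over serialized safetensors bytes,
# not the exact model/legacy hashes computed by train_util in this repository.
import hashlib

import torch
from safetensors.torch import save


def sketch_hash(sd: dict) -> str:
    # hash the bytes that would be written to disk for this state dict
    return hashlib.sha256(save(sd)).hexdigest()


state_dict = {"lora_up.weight": torch.randn(8, 8)}  # merged/resized in float32

# old behavior: hash first, cast to save_precision afterwards
hash_in_metadata = sketch_hash(state_dict)
saved = {k: v.to(torch.float16) for k, v in state_dict.items()}  # save_precision="fp16"
hash_of_saved_file = sketch_hash(saved)

print(hash_in_metadata == hash_of_saved_file)  # False: the stored hash no longer describes the saved tensors
```

The fix in this commit performs the cast before the hashes are computed, so the values stored in `sshs_model_hash` and `sshs_legacy_hash` describe the tensors that are actually written to disk.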

networks/resize_lora.py (+8 -7)

@@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype):
     return sd, metadata
 
 
-def save_to_file(file_name, state_dict, dtype, metadata):
-    if dtype is not None:
-        for key in list(state_dict.keys()):
-            if type(state_dict[key]) == torch.Tensor:
-                state_dict[key] = state_dict[key].to(dtype)
-
+def save_to_file(file_name, state_dict, metadata):
     if model_util.is_safetensors(file_name):
         save_file(state_dict, file_name, metadata)
     else:

@@ -349,12 +344,18 @@ def str_to_dtype(p):
         metadata["ss_network_dim"] = "Dynamic"
         metadata["ss_network_alpha"] = "Dynamic"
 
+    # cast to save_dtype before calculating hashes
+    for key in list(state_dict.keys()):
+        value = state_dict[key]
+        if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
+            state_dict[key] = value.to(save_dtype)
+
     model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
     metadata["sshs_model_hash"] = model_hash
     metadata["sshs_legacy_hash"] = legacy_hash
 
     logger.info(f"saving model to: {args.save_to}")
-    save_to_file(args.save_to, state_dict, save_dtype, metadata)
+    save_to_file(args.save_to, state_dict, metadata)
 
 
 def setup_parser() -> argparse.ArgumentParser:
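
All three scripts gain the same cast-before-hash loop. A standalone sketch of that pattern follows; the helper name `cast_state_dict_for_save` is illustrative and not part of the repository, but the body mirrors the loop added in these diffs.

```python
from typing import Optional

import torch


def cast_state_dict_for_save(state_dict: dict, save_dtype: Optional[torch.dtype]) -> dict:
    # Cast floating-point tensors to the save dtype *before* any hash is computed,
    # so the metadata hash describes exactly the tensors that end up on disk.
    # Integer tensors and non-tensor entries are left untouched, as in the scripts.
    if save_dtype is None:
        return state_dict
    for key in list(state_dict.keys()):
        value = state_dict[key]
        if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
            state_dict[key] = value.to(save_dtype)
    return state_dict
```

Casting in place (mutating the caller's dict) matches what the scripts do, and it is why `save_to_file` no longer needs a `dtype` argument: it now receives tensors that are already in the target precision.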

networks/sdxl_merge_lora.py (+8 -7)

@@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype):
     return sd, metadata
 
 
-def save_to_file(file_name, model, state_dict, dtype, metadata):
-    if dtype is not None:
-        for key in list(state_dict.keys()):
-            if type(state_dict[key]) == torch.Tensor:
-                state_dict[key] = state_dict[key].to(dtype)
-
+def save_to_file(file_name, model, metadata):
     if os.path.splitext(file_name)[1] == ".safetensors":
         save_file(model, file_name, metadata=metadata)
     else:

@@ -430,6 +425,12 @@ def str_to_dtype(p):
     else:
         state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
 
+        # cast to save_dtype before calculating hashes
+        for key in list(state_dict.keys()):
+            value = state_dict[key]
+            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
+                state_dict[key] = value.to(save_dtype)
+
         logger.info(f"calculating hashes and creating metadata...")
 
         model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)

@@ -445,7 +446,7 @@ def str_to_dtype(p):
            metadata.update(sai_metadata)
 
         logger.info(f"saving model to: {args.save_to}")
-        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
+        save_to_file(args.save_to, state_dict, metadata)
 
 
 def setup_parser() -> argparse.ArgumentParser:

networks/svd_merge_lora.py (+8 -7)

@@ -216,12 +216,7 @@ def load_state_dict(file_name, dtype):
     return sd, metadata
 
 
-def save_to_file(file_name, state_dict, dtype, metadata):
-    if dtype is not None:
-        for key in list(state_dict.keys()):
-            if type(state_dict[key]) == torch.Tensor:
-                state_dict[key] = state_dict[key].to(dtype)
-
+def save_to_file(file_name, state_dict, metadata):
     if os.path.splitext(file_name)[1] == ".safetensors":
         save_file(state_dict, file_name, metadata=metadata)
     else:

@@ -430,6 +425,12 @@ def str_to_dtype(p):
         args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
     )
 
+    # cast to save_dtype before calculating hashes
+    for key in list(state_dict.keys()):
+        value = state_dict[key]
+        if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
+            state_dict[key] = value.to(save_dtype)
+
     logger.info(f"calculating hashes and creating metadata...")
 
     model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)

@@ -451,7 +452,7 @@ def str_to_dtype(p):
         metadata.update(sai_metadata)
 
     logger.info(f"saving model to: {args.save_to}")
-    save_to_file(args.save_to, state_dict, save_dtype, metadata)
+    save_to_file(args.save_to, state_dict, metadata)
 
 
 def setup_parser() -> argparse.ArgumentParser:
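
A quick way to check the effect of this commit is to recompute the hash of a saved LoRA and compare it to the value stored in its metadata. The sketch below is hypothetical: it assumes a safetensors file produced by one of the fixed scripts, an illustrative file name, and that `train_util` is importable from the `library` package as in the rest of the repository.

```python
from safetensors import safe_open
from safetensors.torch import load_file

from library import train_util  # assumes the sd-scripts root is on the import path

path = "resized_lora.safetensors"  # illustrative: any LoRA saved by one of the fixed scripts

state_dict = load_file(path)
with safe_open(path, framework="pt") as f:
    metadata = f.metadata() or {}

model_hash, _legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
# Should match the stored value after this fix; before it, the stored hash could be
# stale whenever save_precision differed from the precision used for the calculation.
print(model_hash == metadata.get("sshs_model_hash"))
```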
