fix bf16 model, conv2d 3x3, >320 dim close #127 #128

Merged 2 commits on Mar 10, 2023
7 changes: 7 additions & 0 deletions README.md
@@ -160,6 +160,13 @@ Composable LoRA はサブプロンプトごとに LoRA の適用有無を切り

## Change History

- 9 Mar. 2023, 2023/3/9: Release v0.5.1
  - Fix an error when loading a model saved with `bf16` (see the short illustration after this excerpt). https://github.com/kohya-ss/sd-webui-additional-networks/issues/127
  - Fix an issue where some Conv2d-3x3 LoRA modules were not effective. https://github.com/kohya-ss/sd-scripts/issues/275
  - Fix an error with LoRA modules whose dim (rank) is greater than 320.
  - `bf16` で学習されたモデルが読み込めない不具合を修正しました。 https://github.com/kohya-ss/sd-webui-additional-networks/issues/127
  - いくつかの Conv2d-3x3 LoRA モジュールが有効にならない不具合を修正しました。 https://github.com/kohya-ss/sd-scripts/issues/275
  - dim (rank) が 320 を超えるLoRAモデルが読み込めない不具合を修正しました。
- 8 Mar. 2023, 2023/3/8: Release v0.5.0
  - Support the current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
  - LoCon will be enhanced in the future. Compatibility with future versions is not guaranteed.
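The `bf16` items above (and the `to(torch.float)` change further down in `scripts/lora_compvis.py`) appear to come down to NumPy having no bfloat16 dtype: calling `.numpy()` on a bf16 tensor raises a `TypeError`, so the tensor has to be cast to a float dtype first. A minimal sketch of the failure and the workaround, with an illustrative value:

```python
import torch

alpha = torch.tensor(4.0, dtype=torch.bfloat16)

# alpha.numpy() would raise a TypeError here, because NumPy has no bfloat16 dtype.
# Casting to float32 first (as v0.5.1 does) avoids the error.
value = float(alpha.detach().to(torch.float).cpu().numpy())
print(value)  # 4.0
```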
18 changes: 9 additions & 9 deletions scripts/lora_compvis.py
@@ -34,9 +34,9 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_
       in_dim = org_module.in_channels
       out_dim = org_module.out_channels

-      self.lora_dim = min(self.lora_dim, in_dim, out_dim)
-      if self.lora_dim != lora_dim:
-        print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+      # self.lora_dim = min(self.lora_dim, in_dim, out_dim)
+      # if self.lora_dim != lora_dim:
+      #   print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")

       kernel_size = org_module.kernel_size
       stride = org_module.stride
@@ -46,12 +46,12 @@
     else:
       in_dim = org_module.in_features
       out_dim = org_module.out_features
-      self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
-      self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
+      self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+      self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

     if type(alpha) == torch.Tensor:
       alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
-    alpha = lora_dim if alpha is None or alpha == 0 else alpha
+    alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
     self.scale = alpha / self.lora_dim
     self.register_buffer('alpha', torch.tensor(alpha))  # 定数として扱える
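For context on the `lora_dim` and `alpha` handling in this hunk: a LoRA module factors the weight update into a down projection and an up projection of rank `lora_dim`, and scales their product by `alpha / lora_dim` (with `alpha` defaulting to the rank, i.e. no extra scaling). A stripped-down sketch of the idea, not the extension's actual class and with illustrative names:

```python
import torch

class TinyLoRA(torch.nn.Module):
    """Simplified stand-in for a LoRA-wrapped Linear layer (illustration only)."""

    def __init__(self, org_linear: torch.nn.Linear, lora_dim=4, alpha=None, multiplier=1.0):
        super().__init__()
        self.org = org_linear
        # Low-rank pair; shapes must line up with the checkpoint's lora_down / lora_up weights.
        self.lora_down = torch.nn.Linear(org_linear.in_features, lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(lora_dim, org_linear.out_features, bias=False)
        # alpha == 0 or None means "no extra scaling", i.e. alpha = rank.
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / lora_dim
        self.multiplier = multiplier

    def forward(self, x):
        return self.org(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
```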

@@ -128,7 +128,7 @@ def create_network_and_apply_compvis(du_state_dict, multiplier_tenc, multiplier_

     lora_name = key.split('.')[0]
     if 'alpha' in key:
-      modules_alpha[lora_name] = float(value.detach().cpu().numpy())
+      modules_alpha[lora_name] = float(value.detach().to(torch.float).cpu().numpy())
     elif 'lora_down' in key:
       dim = value.size()[0]
       modules_dim[lora_name] = dim
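The hunk above reads each module's rank straight from the first dimension of its `lora_down` weight. The layers created later must keep exactly that rank, or `load_state_dict` fails with a shape mismatch, which appears to be why ranks above a layer's channel count (e.g. > 320) errored before the clamping was removed. A small sketch with illustrative shapes:

```python
import torch

# Illustrative checkpoint tensors for one Linear module:
# lora_down.weight is (rank, in_features), lora_up.weight is (out_features, rank).
lora_down_weight = torch.zeros(384, 768)
lora_up_weight = torch.zeros(768, 384)

dim = lora_down_weight.size()[0]                    # rank stored in the file, here 384
lora_down = torch.nn.Linear(768, dim, bias=False)   # must use the stored rank unchanged;
lora_up = torch.nn.Linear(dim, 768, bias=False)     # a clamped rank would fail to load below

lora_down.load_state_dict({"weight": lora_down_weight})
lora_up.load_state_dict({"weight": lora_up_weight})
print(dim)
```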
@@ -346,7 +346,7 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules,
             if '_resblocks_23_' in lora_name:  # ignore last block in StabilityAi Text Encoder
               break
             if lora_name not in comp_vis_loras_dim_alpha:
-              break
+              continue

             dim, alpha = comp_vis_loras_dim_alpha[lora_name]
             lora = LoRAModule(lora_name, child_module, multiplier, dim, alpha)
@@ -363,7 +363,7 @@

             lora_name = module_name + '_' + suffix
             if lora_name not in comp_vis_loras_dim_alpha:
-              break
+              continue
             dim, alpha = comp_vis_loras_dim_alpha[lora_name]
             lora_info = LoRAInfo(lora_name, module_name, child_module, multiplier, dim, alpha)
             loras.append(lora_info)
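The two `break` → `continue` changes share one cause: the scan walks every candidate child module, and a module that simply is not present in the loaded LoRA file should be skipped, not end the whole scan; with `break`, every module after the first missing one silently got no LoRA, which is one way Conv2d-3x3 modules ended up ineffective. A toy loop with made-up names showing the difference:

```python
# Modules present in a hypothetical LoRA file, and the children being scanned.
available = {"block_attn1", "block_conv2"}
children = ["block_attn1", "block_conv1", "block_conv2"]

applied = []
for name in children:
    if name not in available:
        continue            # with `break`, "block_conv2" would never be reached
    applied.append(name)

print(applied)  # ['block_attn1', 'block_conv2']
```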