
Commit a0c2299

Disable PEFT input autocast when using fp8 layerwise casting (#10685)

* disable peft input autocast
* use new peft method name; only disable peft input autocast if submodule layerwise casting active
* add test; reference PeftInputAutocastDisableHook in peft docs
* add load_lora_weights test
* casted -> cast
* Update tests/lora/utils.py

1 parent 97abdd2 commit a0c2299

File tree: 3 files changed, +151 -2 lines changed

docs/source/en/tutorials/using_peft_for_inference.md (+4)

@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
 pipe.get_active_adapters()
 ["pixel"]
 ```
+
+## PeftInputAutocastDisableHook
+
+[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
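For orientation, here is a hedged usage sketch of the flow this documentation entry covers. The pipeline class, model id, and LoRA path are placeholders, not taken from this commit; `enable_layerwise_casting` routes through `apply_layerwise_casting`, which attaches `PeftInputAutocastDisableHook` to any PEFT layer that ends up with layerwise casting active.

```python
import torch
from diffusers import FluxPipeline

# Placeholder model id and LoRA path; any LoRA-compatible pipeline follows the same pattern.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA checkpoint

# Store weights in fp8, compute in bf16. Where layerwise casting is active on PEFT layers,
# PeftInputAutocastDisableHook keeps their inputs in the compute dtype instead of fp8.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```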

src/diffusers/hooks/layerwise_casting.py (+56 -2)

@@ -17,14 +17,16 @@
 
 import torch
 
-from ..utils import get_logger
+from ..utils import get_logger, is_peft_available, is_peft_version
 from .hooks import HookRegistry, ModelHook
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # fmt: off
+_LAYERWISE_CASTING_HOOK = "layerwise_casting"
+_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
 SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -34,6 +36,11 @@
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
+_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
+if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+    from peft.helpers import disable_input_dtype_casting
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
 
 class LayerwiseCastingHook(ModelHook):
     r"""
@@ -70,6 +77,32 @@ def post_forward(self, module: torch.nn.Module, output):
         return output
 
 
+class PeftInputAutocastDisableHook(ModelHook):
+    r"""
+    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
+    casts the inputs to the weight dtype of the module, which can lead to precision loss.
+
+    The reasons for needing this are:
+    - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
+      inputs will be cast to the, possibly lower-precision, storage dtype. Reference:
+      https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
+    - We could, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
+      that the inputs are always cast to the computation dtype correctly. However, there are two goals we are
+      hoping to achieve:
+        1. Making forward implementations independent of device/dtype casting operations as much as possible.
+        2. Performing inference without losing information from casting to different precisions. With the current
+           PEFT implementation (as linked in the reference above), and assuming layerwise casting inference runs
+           with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
+           torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
+           forward pass in the PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+           LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
+    """
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        with disable_input_dtype_casting(module):
+            return self.fn_ref.original_forward(*args, **kwargs)
+
+
 def apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
         skip_modules_classes,
         non_blocking,
     )
+    _disable_peft_input_autocast(module)
 
 
 def _apply_layerwise_casting(
@@ -188,4 +222,24 @@ def apply_layerwise_casting_hook(
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_casting")
+    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
+
+
+def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
+    for submodule in module.modules():
+        if (
+            hasattr(submodule, "_diffusers_hook")
+            and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
+        ):
+            return True
+    return False
+
+
+def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
+    if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+        return
+    for submodule in module.modules():
+        if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
+            registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            hook = PeftInputAutocastDisableHook()
+            registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
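The precision-loss argument in the `PeftInputAutocastDisableHook` docstring can be reproduced with a tiny round-trip check. This is a standalone illustrative sketch, not code from the diff, and it assumes a PyTorch build with `torch.float8_e4m3fn` support.

```python
import torch

# Round-tripping a bfloat16 activation through the fp8 storage dtype is lossy.
# This is what would happen to LoRA inputs if PEFT's input autocast stayed enabled
# while layerwise casting stores weights in torch.float8_e4m3fn.
x = torch.randn(4, 64, dtype=torch.bfloat16)
roundtrip = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print((x - roundtrip).abs().max())  # > 0: information is lost by the downcast
```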

tests/lora/utils.py (+91)

@@ -2157,3 +2157,94 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
 
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks that layerwise casting works correctly with PEFT layers and the forward pass does not fail.
+        It differs from `test_layerwise_casting_inference_denoiser`, which disables the application of layerwise
+        cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
+        In this test, we enable layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0,
+        this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        def check_module(denoiser):
+            # This will also check that the PEFT layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+            for name, module in denoiser.named_modules():
+                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(module, "weight", None) is not None:
+                    self.assertEqual(module.weight.dtype, dtype_to_check)
+                if getattr(module, "bias", None) is not None:
+                    self.assertEqual(module.bias.dtype, dtype_to_check)
+                if isinstance(module, BaseTunerLayer):
+                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+        # 1. Test forward with add_adapter
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+        check_module(denoiser)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # 2. Test forward with load_lora_weights
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device, dtype=compute_dtype)
+            pipe.set_progress_bar_config(disable=None)
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            apply_layerwise_casting(
+                denoiser,
+                storage_dtype=storage_dtype,
+                compute_dtype=compute_dtype,
+                skip_modules_pattern=patterns_to_check,
+            )
+            check_module(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            pipe(**inputs, generator=torch.manual_seed(0))[0]
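As a usage note, the assertion inside `check_module` can be condensed into a small helper for ad-hoc checks after calling `apply_layerwise_casting` on a LoRA-adapted denoiser. This is a hedged sketch: the helper name and the `denoiser` setup are assumptions, not part of the test suite, and it requires PEFT > 0.14.0.

```python
import torch
from peft.tuners.tuners_utils import BaseTunerLayer

def assert_peft_input_autocast_disabled(denoiser: torch.nn.Module) -> None:
    """Check that every PEFT layer received the peft_autocast_disable hook."""
    for name, module in denoiser.named_modules():
        if isinstance(module, BaseTunerLayer):
            registry = getattr(module, "_diffusers_hook", None)
            assert registry is not None, f"{name} has no diffusers hook registry"
            assert registry.get_hook("peft_autocast_disable") is not None, (
                f"PEFT input autocast was not disabled for {name}"
            )
```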
