@@ -2157,3 +2157,94 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):

         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks if layerwise casting works correctly with PEFT layers and that the forward pass does not
+        fail. This is different from `test_layerwise_casting_inference_denoiser`, as that test disables the
+        application of layerwise cast hooks on the PEFT layers (relevant logic in
+        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
+        the PEFT layers as well. If run with PEFT version <= 0.14.0, this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        def check_module(denoiser):
+            # This also checks that the PEFT layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+            for name, module in denoiser.named_modules():
+                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(module, "weight", None) is not None:
+                    self.assertEqual(module.weight.dtype, dtype_to_check)
+                if getattr(module, "bias", None) is not None:
+                    self.assertEqual(module.bias.dtype, dtype_to_check)
+                if isinstance(module, BaseTunerLayer):
+                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+        # 1. Test forward with add_adapter
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+        check_module(denoiser)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # 2. Test forward with load_lora_weights
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device, dtype=compute_dtype)
+            pipe.set_progress_bar_config(disable=None)
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            apply_layerwise_casting(
+                denoiser,
+                storage_dtype=storage_dtype,
+                compute_dtype=compute_dtype,
+                skip_modules_pattern=patterns_to_check,
+            )
+            check_module(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            pipe(**inputs, generator=torch.manual_seed(0))[0]
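For readers unfamiliar with the hooks this test relies on, below is a minimal, self-contained sketch of the layerwise-casting flow it exercises. It assumes a diffusers build that exposes `diffusers.hooks.layerwise_casting` (the same module the test imports from) and uses a toy two-layer model in place of the pipeline's denoiser; in the real test, the denoiser additionally carries LoRA layers, which receive the `_PEFT_AUTOCAST_DISABLE_HOOK` so their inputs are upcast before the float8 matmul. This is a sketch for illustration, not part of the PR.

```python
# Standalone sketch (not part of this PR): store weights in float8 but compute in float32.
# The toy model below is illustrative only.
import torch
from diffusers.hooks.layerwise_casting import (
    DEFAULT_SKIP_MODULES_PATTERN,
    SUPPORTED_PYTORCH_LAYERS,
    apply_layerwise_casting,
)

storage_dtype = torch.float8_e4m3fn  # dtype the weights are kept in between forward passes
compute_dtype = torch.float32        # dtype the hooks cast to for the actual computation

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 4)).to(compute_dtype)

apply_layerwise_casting(
    model,
    storage_dtype=storage_dtype,
    compute_dtype=compute_dtype,
    skip_modules_pattern=DEFAULT_SKIP_MODULES_PATTERN,
)

# Supported layers should now hold float8 weights, mirroring the `check_module` helper above.
for name, module in model.named_modules():
    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        assert module.weight.dtype == storage_dtype, (name, module.weight.dtype)

# The forward pass should still succeed: the hooks upcast weights to compute_dtype on the fly
# and cast them back to storage_dtype afterwards.
with torch.no_grad():
    out = model(torch.randn(2, 8, dtype=compute_dtype))
print(out.dtype)  # expected: torch.float32
```

The design trade-off the test exercises: keeping weights in `torch.float8_e4m3fn` halves weight memory relative to bfloat16, while the per-layer upcast keeps the forward numerics in `compute_dtype`; the PEFT-specific hook exists so LoRA inputs are not fed to a float8 matmul, which is exactly the `RuntimeError` quoted in the docstring.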