From a99663d37a5e20f843394aa8b28a8930955650ea Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 20 May 2025 18:24:38 +0200 Subject: [PATCH 1/3] load tensors on cuda --- src/diffusers/models/modeling_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..a6b00e8d11ca 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1438,8 +1438,17 @@ def _load_pretrained_model( if len(resolved_model_file) > 1: resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") + map_location = "cpu" + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantfization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + for shard_file in resolved_model_file: - state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location) def _find_mismatched_keys( state_dict, From cad495446dd484b8d85a0604df6a8e9572f241cc Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 20 May 2025 18:39:46 +0200 Subject: [PATCH 2/3] quick fix --- src/diffusers/models/modeling_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a6b00e8d11ca..3cbd09a3de18 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1185,8 +1185,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P state_dict = None if not is_sharded: + map_location = "cpu" + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) # Time to load the checkpoint - state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries) + state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location) # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. model._fix_state_dict_keys_on_load(state_dict) @@ -1443,10 +1451,9 @@ def _load_pretrained_model( device_map is not None and hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantfization_config.quant_type in ["int4_weight_only", "autoquant"] + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] ): map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) - for shard_file in resolved_model_file: state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location) From 93c38d2094ce79c941aaf7c79977ad9a05eed2d4 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 20 May 2025 18:43:52 +0200 Subject: [PATCH 3/3] style --- src/diffusers/models/modeling_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 3cbd09a3de18..6ab660192917 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1194,7 +1194,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ): map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) # Time to load the checkpoint - state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location) + state_dict = load_state_dict( + resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location + ) # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. model._fix_state_dict_keys_on_load(state_dict)