Torchao int4 serialization #11591

Open · wants to merge 3 commits into base: main
22 changes: 20 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -1185,8 +1185,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         state_dict = None
         if not is_sharded:
+            map_location = "cpu"
+            if (
+                device_map is not None
+                and hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+                and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            ):
+                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
Member:

Sufficiently safe to say that we would always have a non-None device_map?

Also, what happens if the device_map has multiple CUDA devices specified? Would the indexing make sense there?

Okay for this PR but we could potentially have a resolve_map_location() per quantizer class, maybe.

Member Author:

> Sufficiently safe to say that we would always have a non-None device_map?

I check that the device_map is not None, so this should be safe enough; I took the approach from transformers. There shouldn't be an issue with the indexing either: if multiple devices are specified, the tensors are moved to their final devices again afterwards anyway.

Yeah, I can switch to update_map_location.
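For illustration only, here is a minimal sketch of the per-quantizer hook floated above; the class and method names are assumptions, not the existing diffusers quantizer API:

```python
# Hypothetical sketch (not part of this PR): a per-quantizer resolve_map_location()
# hook. Class/method names are illustrative assumptions only.
import torch


class TorchAoLikeQuantizer:
    def __init__(self, quant_type: str):
        # e.g. "int4_weight_only" or "autoquant"
        self.quant_type = quant_type

    def resolve_map_location(self, device_map):
        """Decide where load_state_dict should materialize the checkpoint tensors."""
        # int4/autoquant torchao tensors cannot simply be moved after deserialization,
        # so load them directly onto the first accelerator listed in the device_map.
        if device_map is not None and self.quant_type in ["int4_weight_only", "autoquant"]:
            accelerators = [d for d in device_map.values() if d not in ["cpu", "disk"]]
            if accelerators:
                return torch.device(accelerators[0])
        return "cpu"


# from_pretrained / _load_pretrained_model could then reduce to something like:
# map_location = hf_quantizer.resolve_map_location(device_map) if hf_quantizer is not None else "cpu"
```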

             # Time to load the checkpoint
-            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+            state_dict = load_state_dict(
+                resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location
+            )
             # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
             model._fix_state_dict_keys_on_load(state_dict)

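(For context, outside the diff: this is roughly the serialization round-trip the map_location change above is meant to support. The checkpoint id, output path, and device_map value are illustrative; the TorchAoConfig usage follows the diffusers torchao docs.)

```python
# Illustrative round-trip for torchao int4 serialization (paths/model ids are examples).
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize at load time with torchao int4 and save the quantized weights.
quant_config = TorchAoConfig("int4_weight_only")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# torchao tensor subclasses need pickle-based serialization rather than safetensors.
transformer.save_pretrained("flux-transformer-int4", safe_serialization=False)

# Reload: the int4 tensors have to be deserialized directly on the accelerator,
# which is what the new map_location selection takes care of.
transformer = FluxTransformer2DModel.from_pretrained(
    "flux-transformer-int4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="cuda",
)
```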
@@ -1438,8 +1448,16 @@ def _load_pretrained_model(
         if len(resolved_model_file) > 1:
             resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

+        map_location = "cpu"
+        if (
+            device_map is not None
+            and hf_quantizer is not None
+            and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+        ):
+            map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
         for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)

             def _find_mismatched_keys(
                 state_dict,
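As a small standalone note on the review question about multi-GPU device_maps: the selection expression used in both hunks just picks the first non-CPU/disk device, and any tensors assigned elsewhere are dispatched to their final devices afterwards. A quick illustration with a made-up device_map:

```python
import torch

# Made-up device_map splitting a model across two GPUs with some CPU offload.
device_map = {
    "transformer_blocks.0": "cuda:0",
    "transformer_blocks.1": "cuda:1",
    "norm_out": "cpu",
}

# Same expression as in the diff: the first entry that is neither "cpu" nor "disk".
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
print(map_location)  # cuda:0 — tensors assigned to cuda:1 are moved again during dispatch
```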