testing torchao config migration #28


Open · wants to merge 1 commit into main
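For context (not part of this PR's diff): pytorch/ao#1690 tracks migrating torchao's quantization settings from strings such as "int8wo" or "int4_weight_only" to explicit config objects passed to quantize_. The sketch below shows the config-object style that the test scripts in this PR exercise indirectly; it assumes a torchao build that ships quantize_ and Int4WeightOnlyConfig, and the toy model is illustrative only, not taken from the PR.

import torch
import torch.nn as nn
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# toy model; int4 weight-only quantization expects bf16 weights on CUDA
model = nn.Sequential(nn.Linear(256, 256, bias=False)).to("cuda", torch.bfloat16)
# quantize_ mutates the model in place, swapping Linear weights for quantized tensors
quantize_(model, Int4WeightOnlyConfig(group_size=128))
print(model)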
61 changes: 61 additions & 0 deletions 20250212_torchao_migration_test/test_hf.py
@@ -0,0 +1,61 @@
"""
Test that https://github.com/pytorch/ao/issues/1690 does not break HF
"""

import fire

import torch
import torchao
import transformers

def run():
print(f"torch version: {torch.__version__}")
print(f"torchao version: {torchao.__version__}")
print(f"transformers version: {transformers.__version__}")

# test code copy-pasted from
# https://huggingface.co/docs/transformers/main/en/quantization/torchao

from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
# quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

# benchmark the performance
import torch.utils.benchmark as benchmark

def benchmark_fn(f, *args, **kwargs):
# Manual warmup
for _ in range(5):
f(*args, **kwargs)

t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"

MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

pass

if __name__ == '__main__':
fire.Fire(run)
50 changes: 50 additions & 0 deletions 20250212_torchao_migration_test/test_hf_diffusers.py
@@ -0,0 +1,50 @@
"""
Test that https://huggingface.co/docs/diffusers/en/quantization/torchao is not
broken by https://github.com/pytorch/ao/issues/1690
"""

import fire

def run():
# copy-pasted from https://huggingface.co/docs/diffusers/en/quantization/torchao

import torch
import diffusers
import torchao
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

print(f"torch version: {torch.__version__}")
print(f"torchao version: {torchao.__version__}")
print(f"diffusers version: {diffusers.__version__}")

model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
print(quantization_config)
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=dtype,
)
print(transformer)
pipe = FluxPipeline.from_pretrained(
model_id,
transformer=transformer,
torch_dtype=dtype,
)
pipe.to("cuda")

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")

if __name__ == '__main__':
fire.Fire(run)
30 changes: 30 additions & 0 deletions 20250212_torchao_migration_test/test_sglang.py
@@ -0,0 +1,30 @@
"""
Test that SGLANG is not broken by
https://github.com/pytorch/ao/issues/1690

Note: I don't have a working sglang install, so the test below is a hack to verify
that just the torchao API still works.
"""

import fire
import torch
import torchao
import torch.nn as nn


def run():
import sglang
import sglang.srt.layers.torchao_utils as torchao_utils

print(f"torch version: {torch.__version__}")
print(f"torchao version: {torchao.__version__}")
print(f"sglang version: {sglang.__version__}")

m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda"))
torchao_config = "int8wo"
filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear)
m = torchao_utils.apply_torchao_config_to_model(m, torchao_config, filter_fn)
print(m)

if __name__ == '__main__':
fire.Fire(run)