v0.10.0
Highlights
We are excited to announce the 0.10.0 release of torchao! This release adds support for end-to-end mxfp8 training on NVIDIA B200, PARQ (for quantization-aware training), a module swap quantization API for research, and updates to the low bit kernels!
Low Bit Optimizers moved to Official Support (#1864)
Low bit optimizers (added in 0.4) have been moved out of prototype and now have official support in torchao.
[Prototype] End to End Training Support for mxfp8 on NVIDIA B200 (#1786, #1841, #1951, #1932, #1980)
We have an early version of the end-to-end training workflow for the mxfp8 dtype with torch.compile on NVIDIA B200. The cuBLAS mxfp8 gemm shows an observed speedup of over 2x over the bfloat16 gemm, and casts from bfloat16 to mxfp8 reach up to 5.5 TB/s. Please see our MX README.md for more information. We plan to improve performance further in future releases.
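The training integration goes through the quantize_ API (#1970). Below is a minimal sketch, assuming the MXLinearConfig interface described in the MX README; the field names shown (elem_dtype, block_size) are assumptions and may differ in your installed version.
import torch
import torch.nn as nn
from torchao.prototype.mx_formats.config import MXLinearConfig  # assumed import path, see the MX README
from torchao.quantization.quant_api import quantize_
# Hypothetical toy model; mxfp8 training targets nn.Linear layers
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda().to(torch.bfloat16)
# e4m3 elements with the standard MX block size of 32 (assumed fields)
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(model, config)
# compile so inductor can fuse the bfloat16 -> mxfp8 casts
model = torch.compile(model)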
[Prototype] Piecewise-Affine Regularized Quantization (#1738)
- PARQ is a new theoretical framework for inducing quantization through regularization. It supports standard QAT as well as new gradual quantization methods in an easy-to-use, optimizer-only interface. No modifications to a model's forward or backward pass are needed for quantization.
import torch
from torchao.prototype.parq.optim import QuantOptimizer, ProxHardQuant
from torchao.prototype.parq.quant import UnifQuantizer
# Separate quantizable (weights) from non-quantizable (others) parameter groups
param_groups = [
    {"params": weights, "quant_bits": 2},  # add extra quant_bits key for QAT
    {"params": others},
]
# Initialize any torch.optim.Optimizer
base_optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)
# Apply a simple wrapper to quantize in optimizer.step()
optimizer = QuantOptimizer(
    base_optimizer, quantizer=UnifQuantizer(), prox_map=ProxHardQuant()
)
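Training then proceeds as usual; QuantOptimizer applies the quantization inside optimizer.step(). A minimal sketch with a hypothetical model and train_loader:
import torch.nn.functional as F
# model and train_loader are placeholders for your own setup
for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()  # quantization of the "quant_bits" parameter groups happens here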
[Prototype] Module Swap Quantization API (#1886)
We added a prototype API for post-training quantization. Users can swap their linear or embedding layers for their QuantizedLinear and QuantizedEmbedding counterparts, and set the quantizers that specify how the input activations or weights should be quantized:
quantized_linear = QuantizedLinear(...)
quantized_linear.weight_quantization = IntQuantizer(
    num_bits=4,
    group_size=32,
    dynamic=True,
    quantization_mode="symmetric",
)
quantized_linear.input_quantization = CodeBookQuantizer(
    num_bits=8,
    features=10,
)
Note: The API is highly subject to change and will be integrated with quantize_ in the future. For more detail, please see the README.
[Prototype] Low Bit Kernels Updates (#1826, #1935, #1998, #1652)
Low-bit CPU and MPS kernels are now pip installable from source. To install torchao with low-bit CPU kernels, you can use the following command on an Arm-based Mac:
USE_CPP=1 pip install git+https://github.com/pytorch/ao.git
You can then quantize your model to run on Arm-based Macs with high-performance CPU kernels in torchao. SharedEmbeddingQuantizer, EmbeddingQuantizer, and Int8DynamicActivationIntxWeightConfig all support 1-8 bit quantization.
import torch
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig, SharedEmbeddingQuantizer, EmbeddingQuantizer
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.quant_api import quantize_

# Quantize embedding/unembedding to 8 bits with SharedEmbeddingQuantizer.
# SharedEmbeddingQuantizer is for quantizing models like Llama1B/3B
# where the embedding/unembedding layers share weights.
# If the embedding/unembedding layers do not share weights, use
# EmbeddingQuantizer instead.
SharedEmbeddingQuantizer(
    weight_dtype=torch.int8,
    granularity=PerRow(),
    has_weight_zeros=True,
).quantize(model)

# Quantize linear layers to 4 bits
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(128),
        has_weight_zeros=False,
    )
)
BC Breaking
Delete delayed scaling from torchao.float8 (#1753)
The following usage of `Float8LinearConfig` is no longer supported as of torchao v0.10.0:
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
If you would like to use float8 training with delayed scaling, please use an earlier release of torchao. Please see #1680 for more context about this deprecation.
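Dynamic scaling remains the supported path. A minimal sketch of float8 training with the default (dynamic) scaling, assuming the convert_to_float8_training entry point described in the torchao.float8 README:
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
# Hypothetical toy model
model = nn.Sequential(nn.Linear(2048, 4096), nn.ReLU(), nn.Linear(4096, 2048)).cuda().to(torch.bfloat16)
# The default Float8LinearConfig uses dynamic scaling for input, weight, and grad_output
convert_to_float8_training(model, config=Float8LinearConfig())
model = torch.compile(model)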
Enforce AOBaseConfig type in quantize_'s config argument (#1861)
This was done following a deprecation window to simplify the arguments of quantize_. Please see #1690 for more context.
# torchao v0.9.0
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
    set_inductor_config: Optional[bool] = None,
    device: Optional[torch.types.Device] = None,
):
# torchao v0.10.0
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
    set_inductor_config: Optional[bool] = None,
    device: Optional[torch.types.Device] = None,
):
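Workflow configs are AOBaseConfig subclasses and are passed directly. A minimal sketch of the new-style call, assuming Int8WeightOnlyConfig is exported from torchao.quantization in this release:
import torch
import torch.nn as nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16)
# v0.10.0: the config argument must be an AOBaseConfig instance, not an arbitrary callable
quantize_(model, Int8WeightOnlyConfig())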
Remove the set_inductor_config argument of quantize_ (#1865)
This was done following a deprecation window to decouple quantize_ from torchinductor. Please see #1715 for more context.
# torchao v0.9.0
def quantize_(
    ...,
    set_inductor_config: Optional[bool] = None,
    ...,
):
    # if set_inductor_config != None, throw a deprecation warning
    # if set_inductor_config == None, set it to True to stay consistent with old behavior

# torchao v0.10.0
def quantize_(
    ...,
):
    # set_inductor_config is removed from quantize_ and moved to relevant individual workflows
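If your workflow relied on quantize_ tuning torchinductor for you, set the options explicitly before compiling. A sketch using standard torch._inductor.config knobs; exactly which knobs the old behavior set depends on your workflow, so treat these as examples rather than a drop-in replacement:
import torch
# Set inductor options yourself instead of relying on quantize_
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch.set_float32_matmul_precision("high")
# model = torch.compile(model)  # compile after setting the options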
Deprecations
We removed some prototype features that were not being used, including DORA (#1815), the split_k kernel (#1816), the profiler (#1862), and bitnet (#1866).
New Features
QAT
- Added PARQ (#1738)
Low Bit Optimizers
- Promote Low Bit Optim out of prototype (#1864)
Module swap quantization API
- Add module swap quantization API from Quanty (#1886)
Benchmarking
- Micro-benchmark inference (#1759)
- Add sparsity to benchmarking (#1917)
- Add float8 training benchmarking scripts (#1802)
Improvement
Kernels
- 1-8 bit CPU and MPS kernels are now pip installable from source (#1826)
- Added 1-8 bit shared embedding ops to further compress models like Llama1B/3B where the embedding/unembedding weights are shared (#1935)
- CPU kernels added runtime microkernel selection based on CPU features and matrix size (#1998)
- KleidiAI microkernel library was integrated with CPU kernels to improve GEMM performance on Arm CPUs (#1652)
- Add build flag to set parallel_backend (#1870)
- Add quant api + python test for shared embedding (#1937)
- Add dynamic shape support for lowbit kernels (#1942)
- Add LUT-based bitpacking for 1-4 bits (#1987)
- Add lut support to linear kernel (#1990)
- Quantized matmul (#1994)
- Add fp32xint8 matmul (#2004)
- Add quantized q @ k test for intended use in quantized attention (#2006)
- ROCm Support : Tile_Layout kernel (#1201)
- Metal lowbit kernels: pip install (#1785)
- Metal lowbit ops: ci (#1825)
- ROCm Sparse Marlin Kernels #1206 (#1834)
- ROCm OCP FP8 Support (#1677)
- Migrate to int args (#1846)
- Add bias support to torchao kernels (#1879)
- Write weight packing/unpacking functions for universal kernels (#1921)
- Unpack weights at col (#1933)
- Shared embedding kernel (#1934)
- Bug fixes for shared_embedding (#1941)
- Update linear.h (#1963)
- Reintroduce has_weight_zeros as a template param (#1991)
AOConfigs
- Support Serialization for AOConfigs (#1875)
- Migrate to config for Int8DynamicActivationIntxWeightConfig (#1836)
- Migrate sparsify_ to configs (#1856)
SAM2
- SAM2: Use torch.export for VOS (#1708)
QAT
- Add linear bias support for QAT (#1755)
MX
- Allow for scales to be in new e8m0 dtype (#1742)
- Support MXFP6 packing and fused unpack-dequantize kernel (#1810)
- Implemented RCEIL (CUBLAS-style) MXFP scale factor derivation, with test cases. (#1835)
- Use torch.float8_e8m0fnu in mx_formats (#1966)
- Mx_formats: move training to the quantize_ API (#1970)
Affine Quantization
- Add support for copy_ for plain layout and tensor core tiled layout (#1791)
- Add bias support for Int8DynActInt4WeightLinear (#1845)
- Move config out of experimental (#1954)
Bug Fixes
- Fix potential out-of-bound access in int8_mm.py (#1751)
- Fixing DORA imports (#1795)
- Avoid assert error when there's bias (#1839)
- Update triton import error message (#1842)
- Enable the CPU int4 with HQQ quant (#1824)
- Do not override requires_grad=False when enable_float8_all_gather=True (#1873)
- Add MI300X specs to roofline benchmark (#1913)
- Fix dynamic shape for shared embedding (#1946)
Performance
- Modify cast from hp to mx to help inductor fuse (#1786)
- Enable torch.compile for mxfp8_cublas recipe (#1841)
- Optimize tensor_flatten for runtime (#1951)
- Triton kernel to cast to mx and write in col-major (#1932)
- small speedup with dim0 cast for mx (#1980)
Documentation
- Updating Cuda 12.1/12.4 to 12.4/12.6 to reflect current state (#1794)
- Update float8 training benchmark readme (#1872)
- Add perf benchmarks for float8 training with rowwise + tensorwise scaling (#1793)
- Fix link markdown in readme (#1881)
- Refresh torchao.float8 README (#1986)
- Refresh float8 training section of main README (#1985)
- Refresh MX README (#1989)
New Contributors
- @jithunnair-amd made their first contribution in #1749
- @facebook-github-bot made their first contribution in #1752
- @mark14wu made their first contribution in #1751
- @lisjin made their first contribution in #1738
- @mayank31398 made their first contribution in #1849
- @alex-titterton made their first contribution in #1810
- @mreso made their first contribution in #1913
- @frsun-nvda made their first contribution in #1835
Full Changelog: v0.9.0...v0.10.0-rc1