diff --git a/src/compressed_tensors/compressors/compress_to_fp4.py b/src/compressed_tensors/compressors/compress_to_fp4.py
new file mode 100644
index 00000000..a523f9c0
--- /dev/null
+++ b/src/compressed_tensors/compressors/compress_to_fp4.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import torch
+
+
+# The eight non-negative values representable in FP4 (E2M1)
+FLOAT_TO_E2M1 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+]
+
+# Map each fp4 value to its index in the table above
+conversion_dict = {value: index for index, value in enumerate(FLOAT_TO_E2M1)}
+
+
+def fp4_to_index(value):
+    sign = torch.signbit(value)
+    x = torch.abs(value)
+    index = conversion_dict.get(x.item())
+
+    if not sign:  # positive values map directly to their table index
+        return index
+    else:  # negative values set the sign bit (bit 3)
+        return index + 8
+
+
+def pack_fp4_values(x: torch.Tensor):
+    m, n = x.shape
+    x_flatten = x.flatten()
+    # convert each fp4 value to its 4-bit index, then unpack to bits
+    x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8)
+    x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to(x.device)
+
+    packed_shape = torch.zeros([x_index_bits.shape[0] // 2], dtype=torch.uint8).to(
+        x.device
+    )
+    start = 0
+    end = 16
+    i = 0
+
+    # For each pair of values, keep only the low nibble of each byte and
+    # interleave them: the second value lands in the high nibble and the
+    # first in the low nibble, matching the layout break_fp4_bytes expects
+    while end <= len(x_index_bits):
+        subset = x_index_bits[start:end]
+
+        subset_a = subset[4:8]
+        subset_b = subset[12:16]
+        packed_shape[i + 4 : i + 8] = subset_a
+        packed_shape[i : i + 4] = subset_b
+        start = end
+        end = start + 16
+        i += 8
+
+    # pack the bit stream back into bytes, two fp4 values per byte
+    packed = numpy.packbits(packed_shape.cpu().numpy())
+    packed = torch.from_numpy(packed).to(torch.uint8)
+    packed = packed.reshape(m, n // 2)
+    return packed
+
+
+kE2M1ToFloat = torch.tensor(
+    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
+)
+
+
+# reference: https://github.com/vllm-project/vllm/pull/16362
+def break_fp4_bytes(a, dtype=torch.float32):
+    assert a.dtype == torch.uint8
+    m, n = a.shape
+
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices
+
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+
+    # Reshape to final form
+    return values.reshape(m, n * 2).to(dtype=dtype)
+
+
+# fp4 tensor
+x = torch.Tensor(
+    [
+        [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, 0.0000, -0.0000],
+        [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, 0.0000],
+        [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
+        [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000],
+    ]
+)
+
+m, n = x.shape
+
+packed = pack_fp4_values(x)
+out = break_fp4_bytes(packed)
+# torch.equal treats -0.0 and 0.0 as equal, so also compare sign bits below
+assert torch.equal(out, x)
+sign_bitx = torch.signbit(x)
+sign_bitout = torch.signbit(out)
+assert torch.equal(sign_bitout, sign_bitx)
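+
+# NOTE: an equivalent vectorized packing (a sketch, not part of the compressor
+# path): since each fp4 index fits in a nibble, adjacent pairs can be combined
+# directly, element 2k in the low nibble and element 2k + 1 in the high nibble
+idx = torch.tensor([fp4_to_index(v) for v in x.flatten()], dtype=torch.uint8)
+packed_vec = ((idx[1::2] << 4) | idx[0::2]).reshape(m, n // 2)
+assert torch.equal(packed_vec, packed)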
diff --git a/src/compressed_tensors/compressors/quantized_compressors/__init__.py b/src/compressed_tensors/compressors/quantized_compressors/__init__.py
index 51e8b8e2..496519d4 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/__init__.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -14,5 +14,6 @@
 # flake8: noqa
 
 from .base import *
+from .modelopt_quantized import *
 from .naive_quantized import *
 from .pack_quantized import *
diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py
index 098328be..16cdcd7c 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -113,6 +113,9 @@ def compress(
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                global_scale = model_state.get(
+                    merge_names(prefix, "weight_global_scale"), None
+                )
                 if scale is not None:
                     # weight is quantized, compress it
                     if isinstance(names_to_scheme[prefix], tuple):
@@ -125,6 +128,7 @@ def compress(
                         scale=scale,
                         zero_point=zp,
                         g_idx=g_idx,
+                        global_scale=global_scale,
                         quantization_args=quant_args,
                         device="cpu",
                     )
diff --git a/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py
new file mode 100644
index 00000000..8068b046
--- /dev/null
+++ b/src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict, Optional, Tuple
+
+import numpy
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.quantized_compressors.base import (
+    BaseQuantizationCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from torch import Tensor
+
+
+# The eight non-negative values representable in FP4 (E2M1)
+FLOAT_TO_E2M1 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+]
+
+# Map each fp4 value to its index in the table above
+conversion_dict = {value: index for index, value in enumerate(FLOAT_TO_E2M1)}
+
+
+def fp4_to_index(value):
+    sign = torch.signbit(value)
+    x = torch.abs(value)
+    index = conversion_dict.get(x.item())
+
+    if not sign:  # positive values map directly to their table index
+        return index
+    else:  # negative values set the sign bit (bit 3)
+        return index + 8
+
+
+@BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value)
+class ModelOptCompressor(BaseQuantizationCompressor):
+    """
+    Implements compression for NVFP4-quantized models: the weight of each
+    quantized layer is packed to two FP4 (E2M1) values per uint8 and stored
+    alongside its per-group scales and global scale.
+    """
+
+    @property
+    def compression_param_names(self) -> Tuple[str]:
+        """
+        Returns a tuple of compression parameter names introduced by
+        the compressor during compression
+        """
+        return (
+            "weight_packed",
+            "weight_scale",
+            "weight_zero_point",
+            "weight_global_scale",
+        )
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        global_scale: Tensor,
+        quantization_args: QuantizationArgs,
+        device: Optional[torch.device] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+        quantized_weight = quantize(
+            x=weight,
+            scale=scale,
+            global_scale=global_scale,
+            zero_point=zero_point,
+            args=quantization_args,
+        )
+        compressed_dict = {}
+        weight_packed = pack_fp4_to_uint8(quantized_weight)
+        compressed_dict["weight_packed"] = weight_packed
+        return compressed_dict
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        global_scale = compressed_data["weight_global_scale"]
+        m, n = weight.shape
+        # TODO: need a way to pass in the output_dtype - it can't be assumed
+        # from the scales for nvfp4 (maybe the global scale can be non-fp32?)
+        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
+        )
+
+        return decompressed_weight
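+
+
+# Packed layout: each uint8 holds two FP4 codes, with element 2k in the low
+# nibble and element 2k + 1 in the high nibble. Within a code, bit 3 is the
+# sign and bits 0-2 index into FLOAT_TO_E2M1.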
+def pack_fp4_to_uint8(x: torch.Tensor):
+    m, n = x.shape
+    x_flatten = x.flatten()
+    # convert each fp4 value to its 4-bit index, then unpack to bits
+    x_index = numpy.array([fp4_to_index(i) for i in x_flatten], dtype=numpy.uint8)
+    x_index_bits = torch.from_numpy(numpy.unpackbits(x_index)).to(x.device)
+
+    packed_shape = torch.zeros([x_index_bits.shape[0] // 2], dtype=torch.uint8).to(
+        x.device
+    )
+    start = 0
+    end = 16
+    i = 0
+
+    # For each pair of values, keep only the low nibble of each byte and
+    # interleave them so the second value lands in the high nibble and the
+    # first in the low nibble
+    while end <= len(x_index_bits):
+        subset = x_index_bits[start:end]
+
+        subset_a = subset[4:8]
+        subset_b = subset[12:16]
+        packed_shape[i + 4 : i + 8] = subset_a
+        packed_shape[i : i + 4] = subset_b
+        start = end
+        end = start + 16
+        i += 8
+
+    # pack the bit stream back into bytes, two fp4 values per byte
+    packed = numpy.packbits(packed_shape.cpu().numpy())
+    packed = torch.from_numpy(packed).to(x.device)
+    packed = packed.reshape(m, n // 2)
+    return packed
+
+
+kE2M1ToFloat = torch.tensor(
+    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
+)
+
+
+# reference: https://github.com/vllm-project/vllm/pull/16362
+def unpack_fp4_from_uint8(a: torch.Tensor, m: int, n: int, dtype=torch.float16):
+    assert a.dtype == torch.uint8
+
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices
+
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+
+    # Reshape to final form
+    return values.reshape(m, n).to(dtype=dtype)
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 9ca6f2cf..3ec3bc46 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    modelopt_quantized = "modelopt-quantized"
 
 
 @unique
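
A quick round-trip sanity check of the new helpers (a sketch; the CPU tensor is an assumption, and `dtype=torch.float32` is passed to `unpack_fp4_from_uint8` so the output dtype matches the input):

    import torch
    from compressed_tensors.compressors.quantized_compressors.modelopt_quantized import (
        pack_fp4_to_uint8,
        unpack_fp4_from_uint8,
    )

    # any tensor whose entries are valid E2M1 values: +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}
    x = torch.tensor([[0.5, -6.0, 1.0, -0.0], [3.0, 2.0, -1.5, 4.0]])
    m, n = x.shape

    packed = pack_fp4_to_uint8(x)  # shape (m, n // 2), dtype uint8
    out = unpack_fp4_from_uint8(packed, m, n, dtype=torch.float32)
    assert torch.equal(out, x)  # torch.equal treats -0.0 == 0.0
    assert torch.equal(torch.signbit(out), torch.signbit(x))  # so check signs too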