Add support for copy_ for plain layout and tensor core tiled layout #1791

Merged · 5 commits · Feb 28, 2025
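
In short: aten.copy_ is now implemented for affine quantized tensors with plain, tensor core tiled, float8, and CUTLASS int4 packed layouts (plus LinearActivationQuantizedTensor), so quantized weights can be copied in place between identically-configured modules, which previously was not supported for these layouts. A minimal usage sketch, mirroring test_copy_ below and assuming torchao's quantize_ and int4_weight_only APIs (int4_weight_only defaults to the tensor core tiled layout):

import torch
from torchao.quantization import quantize_, int4_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, int4_weight_only())
quantize_(linear2, int4_weight_only())

# In-place copy of the packed quantized weight between the two modules.
linear2.weight.copy_(linear.weight)

The copy succeeds only when the two tensors' metadata (shape, block size, quantization bounds, layout type) matches; otherwise a ValueError is raised, as exercised by test_copy__mismatch_metadata below.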
47 changes: 47 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -209,6 +209,53 @@ def test_print_quantized_module(self, apply_quant):
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
)
def test_copy_(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
quantize_(linear2, apply_quant)
ql2 = linear2
else:
ql = apply_quant(linear)
ql2 = apply_quant(linear2)

example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
output = ql(example_input)
ql2.weight.copy_(ql.weight)
ql2.bias = ql.bias
output2 = ql2(example_input)
self.assertEqual(output, output2)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
)
def test_copy__mismatch_metadata(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")

if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
quantize_(linear2, apply_quant)
ql2 = linear2
else:
ql = apply_quant(linear)
ql2 = apply_quant(linear2)

# copy should fail due to shape mismatch
with self.assertRaisesRegex(
ValueError, "Not supported args for copy_ due to metadata mistach:"
):
ql2.weight.copy_(ql.weight)


class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
35 changes: 35 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -97,6 +97,27 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
)


def _same_metadata(self: AffineQuantizedTensor, src: AffineQuantizedTensor) -> bool:
return (
isinstance(self, AffineQuantizedTensor)
and isinstance(src, AffineQuantizedTensor)
and all(
[
getattr(self, attr) == getattr(src, attr)
for attr in [
"block_size",
"shape",
"quant_min",
"quant_max",
"zero_point_domain",
"dtype",
]
]
)
and type(self.tensor_impl) == type(src.tensor_impl)
)


class QuantizedLinearNotImplementedError(NotImplementedError):
"""Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table"""

@@ -331,6 +352,20 @@ def _(func, types, args, kwargs):
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)
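
All the copy_ overloads in this PR follow the same pattern: validate metadata compatibility with _same_metadata, then copy each constituent plain tensor in place. An illustrative standalone sketch of that shared loop (the _copy_inner_tensors helper name is hypothetical, not part of this PR):

def _copy_inner_tensors(self, src):
    # __tensor_flatten__ returns (inner_tensor_attr_names, metadata);
    # for a plain layout impl the names would be e.g.
    # ["int_data", "scale", "zero_point"].
    tensor_names, _ = self.__tensor_flatten__()
    for name in tensor_names:
        # Copy each inner tensor in place; the metadata itself is never
        # rewritten, which is why the _same_metadata gate must pass first.
        getattr(self, name).copy_(getattr(src, name))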


@implements(aten.t.default)
def _(func, types, args, kwargs):
block_size = args[0].block_size
23 changes: 23 additions & 0 deletions torchao/dtypes/floatx/float8_layout.py
@@ -23,6 +23,18 @@
aten = torch.ops.aten


def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
return (
isinstance(self, Float8AQTTensorImpl)
and isinstance(src, Float8AQTTensorImpl)
and self.shape == src.shape
and self.float8_data.shape == src.float8_data.shape
and self.scale.shape == src.scale.shape
and self.transposed == src.transposed
and type(self._layout) == type(src._layout)
)


@dataclass(frozen=True)
class Float8Layout(Layout):
"""Represents the layout configuration for Float8 affine quantized tensors.
@@ -126,6 +138,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
"""
args[0].transposed = not args[0].transposed
return return_and_correct_aliasing(func, args, kwargs, args[0])
elif func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
23 changes: 23 additions & 0 deletions torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -28,6 +28,17 @@ def _aqt_is_int4(aqt):
)


def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
return (
isinstance(self, Int4PackedTensorImpl)
and isinstance(src, Int4PackedTensorImpl)
and self.shape == src.shape
and self.int_data.shape == src.int_data.shape
and self.scale.shape == src.scale.shape
and type(self._layout) == type(src._layout)
)


@dataclass(frozen=True)
class CutlassInt4PackedLayout(Layout):
"""Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
@@ -77,6 +88,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

elif func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)

raise NotImplementedError(
f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
)
31 changes: 30 additions & 1 deletion torchao/dtypes/uintx/plain_layout.py
@@ -22,6 +22,23 @@
aten = torch.ops.aten


def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> bool:
return (
isinstance(self, PlainAQTTensorImpl)
and isinstance(src, PlainAQTTensorImpl)
and self.shape == src.shape
and self.int_data.shape == src.int_data.shape
and self.scale.shape == src.scale.shape
and (self.zero_point is None and src.zero_point is None)
or (
self.zero_point is not None
and src.zero_point is not None
and self.zero_point.shape == src.zero_point.shape
)
and type(self._layout) == type(src._layout)
)


@register_layout(PlainLayout)
class PlainAQTTensorImpl(AQTTensorImpl):
"""
@@ -108,11 +125,23 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

        elif func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

elif func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)

elif func is aten.t.default:
tensor = args[0]
new = tensor.__class__(
26 changes: 26 additions & 0 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -32,6 +32,20 @@ def _aqt_is_tensor_core_tile_uint4(aqt):
)


def _same_metadata(
self: "TensorCoreTiledAQTTensorImpl", src: "TensorCoreTiledAQTTensorImpl"
) -> bool:
return (
isinstance(self, TensorCoreTiledAQTTensorImpl)
and isinstance(src, TensorCoreTiledAQTTensorImpl)
and self.shape == src.shape
and self.packed_weight.shape == src.packed_weight.shape
and self.scale_and_zero.shape == src.scale_and_zero.shape
and self.transposed == src.transposed
and type(self._layout) == type(src._layout)
)


def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
return (
# input is native bfloat16 tensor
@@ -290,6 +304,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)

if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
26 changes: 26 additions & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -112,6 +112,18 @@ def to(self, *args, **kwargs):
)


def _same_metadata(
self: LinearActivationQuantizedTensor, src: LinearActivationQuantizedTensor
) -> bool:
return (
isinstance(self, LinearActivationQuantizedTensor)
and isinstance(src, LinearActivationQuantizedTensor)
and self.shape == src.shape
and self.input_quant_func == src.input_quant_func
and self.quant_kwargs == src.quant_kwargs
)


implements = LinearActivationQuantizedTensor.implements


@@ -191,6 +203,20 @@ def _(func, types, args, kwargs):
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)


@implements(aten.t.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(