Commit 79e3366

Add support for copy_ for plain layout and tensor core tiled layout (#1791)

* Add support for copy_ for plain layout and tensor core tiled layout

  Summary: as titled; only supports copy_ from one AQT to another AQT with the
  same metadata (shapes etc.). Tested with int4wo, int8wo, and int8dq.

  Test Plan: python test/dtypes/test_affine_quantized.py -k test_copy_

* remove print
* add metadata mismatch test
* rebase and add float8
* cutlass int4 support

1 parent f478692 commit 79e3366
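
With this change, copy_ between two AffineQuantizedTensors succeeds whenever
their metadata (shape, block_size, dtype, layout type, etc.) matches exactly.
A minimal usage sketch, assuming a CUDA-enabled torchao install and its
int8_weight_only API; the module sizes mirror the new test and are otherwise
arbitrary:

    import torch
    from torchao.quantization import quantize_, int8_weight_only

    src = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    dst = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

    # quantize_ swaps each weight for an AffineQuantizedTensor in place
    quantize_(src, int8_weight_only())
    quantize_(dst, int8_weight_only())

    # identical shapes and quantization config => identical metadata,
    # so the in-place copy is now supported
    dst.weight.copy_(src.weight)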

File tree

7 files changed: +210 −1 lines changed
test/dtypes/test_affine_quantized.py

Lines changed: 47 additions & 0 deletions

@@ -209,6 +209,53 @@ def test_print_quantized_module(self, apply_quant):
         ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize(
+        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+    )
+    def test_copy_(self, apply_quant):
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+            quantize_(linear2, apply_quant)
+            ql2 = linear2
+        else:
+            ql = apply_quant(linear)
+            ql2 = apply_quant(linear2)
+
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        output = ql(example_input)
+        ql2.weight.copy_(ql.weight)
+        ql2.bias = ql.bias
+        output2 = ql2(example_input)
+        self.assertEqual(output, output2)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize(
+        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+    )
+    def test_copy__mismatch_metadata(self, apply_quant):
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
+
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+            quantize_(linear2, apply_quant)
+            ql2 = linear2
+        else:
+            ql = apply_quant(linear)
+            ql2 = apply_quant(linear2)
+
+        # copy should fail due to shape mismatch
+        with self.assertRaisesRegex(
+            ValueError, "Not supported args for copy_ due to metadata mismatch:"
+        ):
+            ql2.weight.copy_(ql.weight)
+
 
 
 class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 35 additions & 0 deletions

@@ -97,6 +97,27 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
     )
 
 
+def _same_metadata(self: AffineQuantizedTensor, src: AffineQuantizedTensor):
+    return (
+        isinstance(self, AffineQuantizedTensor)
+        and isinstance(src, AffineQuantizedTensor)
+        and all(
+            [
+                getattr(self, attr) == getattr(src, attr)
+                for attr in [
+                    "block_size",
+                    "shape",
+                    "quant_min",
+                    "quant_max",
+                    "zero_point_domain",
+                    "dtype",
+                ]
+            ]
+        )
+        and type(self.tensor_impl) == type(src.tensor_impl)
+    )
+
+
 class QuantizedLinearNotImplementedError(NotImplementedError):
     """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table"""
 
@@ -331,6 +352,20 @@ def _(func, types, args, kwargs):
     )
 
 
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     block_size = args[0].block_size
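
The dispatch above is deliberately layout-agnostic: PyTorch's tensor-subclass
flattening protocol, __tensor_flatten__, returns the attribute names of the
inner plain tensors (plus a context object), so once the metadata check passes,
copy_ only has to walk those names and copy each inner tensor in place. A toy
sketch of the protocol; ToyQuantTensor and generic_copy_ are illustrative
stand-ins, not torchao classes:

    import torch

    class ToyQuantTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, int_data, scale):
            # wrapper subclass: outer shape/dtype here, data in inner tensors
            return torch.Tensor._make_wrapper_subclass(
                cls, int_data.shape, dtype=scale.dtype
            )

        def __init__(self, int_data, scale):
            self.int_data = int_data
            self.scale = scale

        def __tensor_flatten__(self):
            # first element: inner tensor attribute names; second: extra context
            return ["int_data", "scale"], None

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs=None):
            raise NotImplementedError(f"{func} not supported on ToyQuantTensor")

    def generic_copy_(dst, src):
        # the same loop each copy_ implementation in this commit runs
        for name in dst.__tensor_flatten__()[0]:
            getattr(dst, name).copy_(getattr(src, name))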

torchao/dtypes/floatx/float8_layout.py

Lines changed: 23 additions & 0 deletions

@@ -23,6 +23,18 @@
 aten = torch.ops.aten
 
 
+def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
+    return (
+        isinstance(self, Float8AQTTensorImpl)
+        and isinstance(src, Float8AQTTensorImpl)
+        and self.shape == src.shape
+        and self.float8_data.shape == src.float8_data.shape
+        and self.scale.shape == src.scale.shape
+        and self.transposed == src.transposed
+        and type(self._layout) == type(src._layout)
+    )
+
+
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     """Represents the layout configuration for Float8 affine quantized tensors.
@@ -126,6 +138,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:

torchao/dtypes/uintx/cutlass_int4_packed_layout.py

Lines changed: 23 additions & 0 deletions

@@ -28,6 +28,17 @@ def _aqt_is_int4(aqt):
     )
 
 
+def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
+    return (
+        isinstance(self, Int4PackedTensorImpl)
+        and isinstance(src, Int4PackedTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        and type(self._layout) == type(src._layout)
+    )
+
+
 @dataclass(frozen=True)
 class CutlassInt4PackedLayout(Layout):
     """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
@@ -77,6 +88,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
         raise NotImplementedError(
             f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
         )

torchao/dtypes/uintx/plain_layout.py

Lines changed: 30 additions & 1 deletion

@@ -22,6 +22,23 @@
 aten = torch.ops.aten
 
 
+def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> bool:
+    return (
+        isinstance(self, PlainAQTTensorImpl)
+        and isinstance(src, PlainAQTTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        and (
+            (self.zero_point is None and src.zero_point is None)
+            or (self.zero_point is not None
+                and src.zero_point is not None
+                and self.zero_point.shape == src.zero_point.shape)
+        )
+        and type(self._layout) == type(src._layout)
+    )
+
+
 @register_layout(PlainLayout)
 class PlainAQTTensorImpl(AQTTensorImpl):
     """
@@ -108,11 +125,23 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
         elif func is aten.t.default:
             tensor = args[0]
             new = tensor.__class__(

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 26 additions & 0 deletions

@@ -32,6 +32,20 @@ def _aqt_is_tensor_core_tile_uint4(aqt):
     )
 
 
+def _same_metadata(
+    self: "TensorCoreTiledAQTTensorImpl", src: "TensorCoreTiledAQTTensorImpl"
+) -> bool:
+    return (
+        isinstance(self, TensorCoreTiledAQTTensorImpl)
+        and isinstance(src, TensorCoreTiledAQTTensorImpl)
+        and self.shape == src.shape
+        and self.packed_weight.shape == src.packed_weight.shape
+        and self.scale_and_zero.shape == src.scale_and_zero.shape
+        and self.transposed == src.transposed
+        and type(self._layout) == type(src._layout)
+    )
+
+
 def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
     return (
         # input is native bfloat16 tensor
@@ -290,6 +304,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
+        if func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 26 additions & 0 deletions

@@ -112,6 +112,18 @@ def to(self, *args, **kwargs):
         )
 
 
+def _same_metadata(
+    self: LinearActivationQuantizedTensor, src: LinearActivationQuantizedTensor
+):
+    return (
+        isinstance(self, LinearActivationQuantizedTensor)
+        and isinstance(src, LinearActivationQuantizedTensor)
+        and self.shape == src.shape
+        and self.input_quant_func == src.input_quant_func
+        and self.quant_kwargs == src.quant_kwargs
+    )
+
+
 implements = LinearActivationQuantizedTensor.implements
 
 
@@ -191,6 +203,20 @@ def _(func, types, args, kwargs):
     )
 
 
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
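
When the metadata differs, every implementation in this commit takes the same
raise path rather than silently reinterpreting the packed data. A sketch of the
failure mode, mirroring the mismatch test (same assumptions as the usage sketch
above):

    import torch
    from torchao.quantization import quantize_, int8_weight_only

    a = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    b = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
    quantize_(a, int8_weight_only())
    quantize_(b, int8_weight_only())

    try:
        b.weight.copy_(a.weight)  # 256 vs 512 output features
    except ValueError as e:
        print(e)  # Not supported args for copy_ due to metadata mismatch: ...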
