Skip to content

Commit 04c26a2

Browse files
committed
add metadata mismatch test
1 parent b5c8acb commit 04c26a2

File tree

4 files changed

+19
-0
lines changed

4 files changed

+19
-0
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,22 @@ def test_copy_(self, apply_quant):
192192
output2 = ql2(example_input)
193193
self.assertEqual(output, output2)
194194

195+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
196+
@common_utils.parametrize(
197+
"apply_quant", get_quantization_functions(False, True, "cuda", False)
198+
)
199+
def test_copy__mismatch_metadata(self, apply_quant):
200+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
201+
ql = apply_quant(linear)
202+
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
203+
ql2 = apply_quant(linear2)
204+
205+
# copy should fail due to shape mismatch
206+
with self.assertRaisesRegex(
207+
ValueError, "Not supported args for copy_ due to metadata mistach:"
208+
):
209+
ql2.weight.copy_(ql.weight)
210+
195211

196212
class TestAffineQuantizedBasic(TestCase):
197213
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

torchao/dtypes/uintx/plain_layout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> boo
2626
return (
2727
isinstance(self, PlainAQTTensorImpl)
2828
and isinstance(src, PlainAQTTensorImpl)
29+
and self.shape == src.shape
2930
and self.int_data.shape == src.int_data.shape
3031
and self.scale.shape == src.scale.shape
3132
and self.zero_point.shape == src.zero_point.shape

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def _same_metadata(
3939
return (
4040
isinstance(self, TensorCoreTiledAQTTensorImpl)
4141
and isinstance(src, TensorCoreTiledAQTTensorImpl)
42+
and self.shape == src.shape
4243
and self.packed_weight.shape == src.packed_weight.shape
4344
and self.scale_and_zero.shape == src.scale_and_zero.shape
4445
and self.transposed == src.transposed

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def _same_metadata(
118118
return (
119119
isinstance(self, LinearActivationQuantizedTensor)
120120
and isinstance(src, LinearActivationQuantizedTensor)
121+
and self.shape == src.shape
121122
and self.input_quant_func == src.input_quant_func
122123
and self.quant_kwargs == src.quant_kwargs
123124
)

0 commit comments

Comments
 (0)