Skip to content

Commit 0885a49

Browse files
committed
cutlass int4 support
1 parent 3093f6d commit 0885a49

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

torchao/dtypes/uintx/cutlass_int4_packed_layout.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ def _aqt_is_int4(aqt):
2828
)
2929

3030

31+
def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
32+
return (
33+
isinstance(self, Int4PackedTensorImpl)
34+
and isinstance(src, Int4PackedTensorImpl)
35+
and self.shape == src.shape
36+
and self.int_data.shape == src.int_data.shape
37+
and self.scale.shape == src.scale.shape
38+
and type(self._layout) == type(src._layout)
39+
)
40+
41+
3142
@dataclass(frozen=True)
3243
class CutlassInt4PackedLayout(Layout):
3344
"""Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
@@ -77,6 +88,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
7788
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
7889
)
7990

91+
elif func is aten.copy_.default:
92+
self = args[0]
93+
src = args[1]
94+
if _same_metadata(self, src):
95+
self_tensors = self.__tensor_flatten__()[0]
96+
for tensor_name in self_tensors:
97+
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
98+
return
99+
raise ValueError(
100+
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
101+
)
102+
80103
raise NotImplementedError(
81104
f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
82105
)

0 commit comments

Comments
 (0)