@@ -28,6 +28,17 @@ def _aqt_is_int4(aqt):
28
28
)
29
29
30
30
31
def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
    """Return True when `src`'s metadata is compatible with `self`'s.

    Used to decide whether an in-place `copy_` between two packed int4
    tensor impls is safe: both objects must be Int4PackedTensorImpl
    instances whose outer shape, packed `int_data` shape, `scale` shape,
    and exact layout class all agree.
    """
    # Guard: anything that is not an Int4PackedTensorImpl can never match.
    if not isinstance(self, Int4PackedTensorImpl) or not isinstance(src, Int4PackedTensorImpl):
        return False
    # Exact layout-class equality (not isinstance) is deliberate: a
    # subclass layout may pack data differently, so it must not match.
    return (
        self.shape == src.shape
        and self.int_data.shape == src.int_data.shape
        and self.scale.shape == src.scale.shape
        and type(self._layout) == type(src._layout)
    )
40
+
41
+
31
42
@dataclass (frozen = True )
32
43
class CutlassInt4PackedLayout (Layout ):
33
44
"""Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
@@ -77,6 +88,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
77
88
func , args , kwargs , args [0 ]._apply_fn_to_data (torch .detach )
78
89
)
79
90
91
+ elif func is aten .copy_ .default :
92
+ self = args [0 ]
93
+ src = args [1 ]
94
+ if _same_metadata (self , src ):
95
+ self_tensors = self .__tensor_flatten__ ()[0 ]
96
+ for tensor_name in self_tensors :
97
+ getattr (self , tensor_name ).copy_ (getattr (src , tensor_name ))
98
+ return
99
+ raise ValueError (
100
+ f"Not supported args for copy_ due to metadata mistach: { args [0 ], args [1 ]} "
101
+ )
102
+
80
103
raise NotImplementedError (
81
104
f"Int4PackedTensorImpl dispatch: attempting to run { func } , this is not supported"
82
105
)
0 commit comments