@@ -35,21 +35,32 @@ class Target(Enum):
 
     # AUTO target will automatically select a packing format
     # based on the available hardware.
-    # TODO: in future, add the ability to specify specific
-    # hardware targets
     AUTO = auto()
+    UNIVERSAL = auto()
+    KLEIDIAI = auto()
 
     # ATEN target will use the ATen operator
     ATEN = auto()
 
 
+_TARGET_AND_STR = [
+    (Target.AUTO, "auto"),
+    (Target.ATEN, "aten"),
+    (Target.UNIVERSAL, "universal"),
+    (Target.KLEIDIAI, "kleidiai"),
+]
+
+
+def target_to_str(target: Target) -> str:
+    mapping = {t: s for t, s in _TARGET_AND_STR}
+    return mapping[target]
+
+
 def target_from_str(target: str) -> Target:
-    if target.lower() == "auto":
-        return Target.AUTO
-    elif target.lower() == "aten":
-        return Target.ATEN
-    else:
-        raise ValueError(f"Invalid target: {target}")
+    str_to_target = {s: t for t, s in _TARGET_AND_STR}
+    if target.lower() in str_to_target:
+        return str_to_target[target.lower()]
+    raise ValueError(f"Invalid target: {target}")
 
 
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
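
Reviewer note: `target_to_str` and `target_from_str` are case-insensitive inverses over the single `_TARGET_AND_STR` table, so registering a new target is a one-line change. A minimal sketch of the round trip (the import path is hypothetical; the helpers are module-level in the file this diff touches):

```python
# Hypothetical import path -- adjust to wherever this layout module lives.
from packed_linear_int8_dynamic_activation_intx_weight_layout import (
    Target,
    target_from_str,
    target_to_str,
)

assert target_from_str("KleidiAI") is Target.KLEIDIAI  # parsing lowercases its input
assert target_to_str(Target.UNIVERSAL) == "universal"  # canonical lowercase spelling
for target in (Target.AUTO, Target.ATEN, Target.UNIVERSAL, Target.KLEIDIAI):
    assert target_from_str(target_to_str(target)) is target  # round trip holds
```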
@@ -146,10 +157,9 @@ def from_plain(
 ):
     assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
     assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-    assert layout.target in {
-        Target.AUTO,
-        Target.ATEN,
-    }, f"Unexpected target: {layout.target}"
+    assert layout.target in [
+        t for t, _ in _TARGET_AND_STR
+    ], f"Unexpected target: {layout.target}"
 
     n, k = int_data.shape
     if layout.target == Target.ATEN:
@@ -174,7 +184,7 @@ def from_plain(
         zero_point.reshape(-1).to(torch.int8) if layout.has_weight_zeros else None,
         layout.group_size,
         bias if layout.has_bias else None,
-        None,  # target, if not passed a packing format will be chosen on C++ side
+        target_to_str(layout.target) if layout.target != Target.AUTO else None,
     ]
 
     packed_weight = getattr(
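
Net effect of the two `from_plain` changes above: any registered target now passes validation, and an explicit target is forwarded to the C++ packing op by name, while `Target.AUTO` keeps the previous behavior of passing `None` so a packing format is still chosen on the C++ side from the available hardware. A condensed sketch of that decision, reusing `Target` and `target_to_str` from the first hunk (illustrative only, not code from this diff):

```python
def packing_target_arg(target):
    # AUTO: pass None and let the C++ side choose the packing format.
    if target is Target.AUTO:
        return None
    # Explicit targets (aten/universal/kleidiai) are forwarded by name.
    return target_to_str(target)

assert packing_target_arg(Target.AUTO) is None
assert packing_target_arg(Target.KLEIDIAI) == "kleidiai"
```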
@@ -223,7 +233,7 @@ def _linear_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-    def _impl_2d_auto(input_tensor, weight_tensor):
+    def _impl_2d_non_aten(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
 
@@ -272,8 +282,8 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     if target == Target.ATEN:
         assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
-    elif target == Target.AUTO:
-        _impl_2d = _impl_2d_auto
+    else:
+        _impl_2d = _impl_2d_non_aten
 
     if input_tensor.dim() == 2:
         res = _impl_2d(input_tensor, weight_tensor)
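
With the rename, dispatch in `_linear_impl` reduces to a two-way split: only `Target.ATEN` routes to the ATen operator, and `AUTO`, `UNIVERSAL`, and `KLEIDIAI` all share the non-ATen kernel, which is why `_impl_2d_auto` becomes `_impl_2d_non_aten`. A condensed sketch of the selection (names from the diff; the real implementations take tensors):

```python
def choose_impl(target):
    # ATEN requires torch >= 2.6.0 and uses the ATen operator; every
    # other target, including AUTO, takes the shared torchao kernel path.
    if target is Target.ATEN:
        return _impl_2d_aten
    return _impl_2d_non_aten
```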