@@ -337,15 +337,16 @@ def _check_per_channel_group_params(
337
337
# For now group quantization is only supported for 4b weights
338
338
assert quant_params .is_qc4w , "Only 4b group quantization is supported"
339
339
340
- def define_tensor (
340
+ def define_tensor ( # noqa: C901
341
341
self ,
342
342
tensor : torch .fx .Node ,
343
343
xnn_graph : XNNGraph ,
344
344
vals_to_ids : Dict [torch .fx .Node , int ],
345
345
convert_to_nhwc : bool = False ,
346
- swap_nc_for_depthwise_weights : bool = False ,
346
+ swap_in_out_for_weights : bool = False ,
347
347
quant_params : Optional [QuantParams ] = None ,
348
348
fp32_static_weights : bool = False ,
349
+ groups : int = 1 ,
349
350
) -> None :
350
351
"""
351
352
Defines a tensor value in the XNNGraph
@@ -357,16 +358,21 @@ def define_tensor(
357
358
their corresponding ids in XNNGraph
358
359
convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
359
360
reflect the nhwc memory format.
360
- swap_nc_for_depthwise_weights : bool to indicate whether tensor shape
361
- should be permuted such that the N and C dimensions are
362
- swapped , which should be used for depthwise convolution
361
+ swap_in_out_for_weights : bool to indicate whether the tensor shape should be
362
+ permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width),
363
+ which should be used for depthwise/transpose convolution
363
364
weights. This is only valid for tensors which hold
364
365
constant data. If used along with convert_to_nhwc, this
365
366
swap will happen before converting to nhwc.
366
367
quant_params: Quantization meta data for this tensor, None if it is not quantized
367
368
fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
369
+ groups: number of groups for swap_in_out_for_weights
368
370
"""
369
371
372
+ assert (
373
+ swap_in_out_for_weights or groups == 1
374
+ ), "groups is only valid with swap_in_out_for_weights"
375
+
370
376
if tensor in vals_to_ids :
371
377
return
372
378
@@ -394,15 +400,16 @@ def define_tensor(
394
400
xnn_graph ,
395
401
vals_to_ids ,
396
402
convert_to_nhwc ,
397
- swap_nc_for_depthwise_weights ,
403
+ swap_in_out_for_weights ,
398
404
quant_params ,
399
405
fp32_static_weights ,
406
+ groups ,
400
407
)
401
408
402
409
# convert tensor shape must reflect memory format, default is contiguous, so
403
410
# only permute shape if we are converting the tensor to nhwc format
404
- if swap_nc_for_depthwise_weights :
405
- dims = [dims [1 ], dims [0 ]] + dims [2 :]
411
+ if swap_in_out_for_weights :
412
+ dims = [dims [1 ] * groups , dims [0 ] // groups ] + dims [2 :]
406
413
if convert_to_nhwc :
407
414
check_or_raise (len (dims ) == 4 , "Converting to nhwc requires 4d tensor" )
408
415
dims = [dims [i ] for i in PERM_NCHW_TO_NHWC ]
@@ -422,16 +429,16 @@ def define_tensor(
422
429
)
423
430
424
431
# Override the quant params axis since we have
425
- # updated the weights for depthwise, with that the out_channels dim
432
+ # updated the weights for depthwise/ transposed_conv2d , with that the out_channels dim
426
433
# will be dims[3] instead of dims[0]. Let's update the per_channel
427
434
# quant axis to match the new weight tensor before serializing
428
- if swap_nc_for_depthwise_weights and (
429
- quant_params and quant_params .per_channel
430
- ):
435
+ if swap_in_out_for_weights and (quant_params and quant_params .per_channel ):
431
436
if quant_params .axis == 0 :
432
437
quant_params .axis = len (dims ) - 1
438
+ elif quant_params .axis == 1 :
439
+ quant_params .axis = 0
433
440
else :
434
- assert f"Unsupported weight per channel quantization axis for depthwise conv2d: { quant_params .axis } , expecting 0."
441
+ assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : { quant_params .axis } , expecting 0 / 1 ."
435
442
436
443
# Serialize tensor value
437
444
ser_val = (
@@ -492,9 +499,10 @@ def get_serialized_buffer_index(
492
499
xnn_graph : XNNGraph ,
493
500
vals_to_ids : Dict [torch .fx .Node , int ],
494
501
convert_to_nhwc : bool ,
495
- swap_nc_for_depthwise_weights : bool ,
502
+ swap_in_out_for_weights : bool ,
496
503
quant_params : Optional [QuantParams ],
497
504
fp32_static_weights : bool = False ,
505
+ groups : int = 1 ,
498
506
) -> int :
499
507
"""
500
508
If tensor holds some constant data, serialize it and return the
@@ -507,24 +515,30 @@ def get_serialized_buffer_index(
507
515
their corresponding ids in XNNGraph
508
516
convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
509
517
reflect the nhwc memory format.
510
- swap_nc_for_depthwise_weights : bool to indicate whether tensor shape
511
- should be permuted such that the N and C dimensions are
512
- swapped , which should be used for depthwise convolution
518
+ swap_in_out_for_weights : bool to indicate whether the tensor shape should be
519
+ permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width),
520
+ which should be used for depthwise/transpose convolution
513
521
weights. This is only valid for tensors which hold
514
522
constant data. If used along with convert_to_nhwc, this
515
523
swap will happen before converting to nhwc.
516
524
quant_params: Quantization meta data for this tensor, None if it is not quantize
517
525
fp32_static_weights: bool to indicate whether tensor is fp32 static weights
526
+ groups: groups for swap_in_out_for_weights
518
527
519
528
Returns:
520
529
buffer_idx: idx of the serialized data. 0 If not associated constant
521
530
data
522
531
"""
532
+
533
+ assert (
534
+ swap_in_out_for_weights or groups == 1
535
+ ), "groups is only valid with swap_in_out_for_weights"
536
+
523
537
# The get_attr node is the input to quant_params.
524
538
get_attr_node = tensor if quant_params is None else quant_params .q_input
525
539
if not is_param_node (self .exported_program , get_attr_node ):
526
540
check_or_raise (
527
- not swap_nc_for_depthwise_weights ,
541
+ not swap_in_out_for_weights ,
528
542
"Swapping N and C dimensions is only valid for constant data tensors" ,
529
543
)
530
544
return 0
@@ -541,9 +555,16 @@ def get_serialized_buffer_index(
541
555
# ensure that the const is fp32
542
556
const_val = const_val .to (dtype = torch .float32 ).contiguous ()
543
557
544
- if swap_nc_for_depthwise_weights :
545
- const_val = const_val .permute (
546
- dims = ((1 , 0 ) + tuple (range (2 , const_val .dim ())))
558
+ if swap_in_out_for_weights :
559
+ # Permute and reshape the tensor from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
560
+ # which should be used for depthwise/transpose convolution weights for XNNPACK
561
+ shape = const_val .shape
562
+ const_val = const_val .reshape (
563
+ (groups , const_val .shape [0 ] // groups ) + const_val .shape [1 :]
564
+ )
565
+ const_val = const_val .permute ((0 , 2 , 1 ) + tuple (range (3 , const_val .dim ())))
566
+ const_val = const_val .reshape (
567
+ (shape [1 ] * groups , shape [0 ] // groups ) + shape [2 :]
547
568
).contiguous ()
548
569
549
570
if convert_to_nhwc :
0 commit comments