Commit 63e6136

[XNNPACK] Support 2d Transposed Convolution in XNNPACK delegate (#7514)
* Support Transposed Convolution in XNNPACK delegate
* Apply suggestions
* Remove invalid restriction for transpose convolution batch normalization fusion
* Fix size analysis tool test
1 parent 8870fae commit 63e6136

17 files changed, +575 -172 lines changed
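Before this change the partitioner rejected any convolution node with transposed=True; after it, ConvTranspose2d modules can be lowered to XNNPACK. A minimal sketch of the flow this enables (module and shapes are illustrative; the export entry points follow current ExecuTorch docs and may differ across releases):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class UpBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 2d transposed convolution, now eligible for XNNPACK delegation
        self.deconv = torch.nn.ConvTranspose2d(8, 4, kernel_size=3, stride=2)

    def forward(self, x):
        return self.deconv(x)

exported = torch.export.export(UpBlock().eval(), (torch.randn(1, 8, 16, 16),))
# conv_transpose nodes are now partitioned to the delegate instead of skipped
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
et_program = edge.to_executorch()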

backends/xnnpack/_passes/fuse_activation_pass.py

+4
@@ -68,6 +68,10 @@ def call(self, graph_module: torch.fx.GraphModule):
                 preceding_op.op == "call_function"
                 and preceding_op.target in self.FUSEABLE_OPS
             ):
+                # Check that current activation is the only user of the preceding op
+                # so that we can fuse the activation into the preceding op
+                if len(preceding_op.users) > 1:
+                    continue
                 # Delete activation, and embed metadata into preceding op
                 output_min_max = self.get_output_min_max_from_activation(
                     activation_node
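Why the guard is needed: the pass deletes the activation node and embeds its output min/max into the preceding op, so if the conv output feeds anything besides the activation, fusion would clamp that other consumer too. A hypothetical module that now correctly skips fusion:

import torch

class Branchy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, 3)

    def forward(self, x):
        y = self.conv(x)          # conv output has two users: relu and add
        return torch.relu(y) + y  # fusing relu into conv would also clamp y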

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py

+2
@@ -82,6 +82,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             # as an arg)
             eps = bn.args[-1]

+            is_transpose = conv.args[6]
             # Compute the updated weight and bias after fusing conv op
             # with batchnorm op.
             fused_weight, fused_bias = fuse_conv_bn_weights(
@@ -92,6 +93,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 eps,
                 bn_weight,
                 bn_bias,
+                is_transpose,
             )

             # Modify the graph by updating the weight and bias of conv op
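fuse_conv_bn_weights lives in torch.nn.utils.fusion and already takes a transpose flag: for ConvTranspose2d the output-channel axis of the weight is dim 1 rather than dim 0, so the batch-norm scale must broadcast along a different axis. A standalone sketch (shapes illustrative):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

# Transposed-conv weight layout is (inc, oc/groups, kh, kw): oc = 4 lives on dim 1
conv_w = torch.randn(8, 4, 3, 3)
conv_b = torch.randn(4)
bn_rm, bn_rv = torch.zeros(4), torch.ones(4)   # BN running mean / variance
bn_w, bn_b = torch.ones(4), torch.zeros(4)     # BN affine weight / bias

fused_w, fused_b = fuse_conv_bn_weights(
    conv_w, conv_b, bn_rm, bn_rv, 1e-5, bn_w, bn_b, transpose=True
)
assert fused_w.shape == conv_w.shape and fused_b.shape == conv_b.shape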

backends/xnnpack/_passes/tag_implicit_q_dq_pass.py

+10-5
@@ -83,11 +83,16 @@ def is_dynamically_quantized(self, node: torch.fx.Node) -> bool:
         return is_dynamic_qdq(node)

     def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
-        return (
-            node.op == "call_function"
-            and cast(torch._ops.OpOverload, node.target).name()
-            in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
-        )
+        if node.op != "call_function":
+            return False
+
+        op_name = cast(torch._ops.OpOverload, node.target).name()
+
+        # Weight and Input should both be quantized
+        if op_name == exir_ops.edge.aten.convolution.default.name():
+            return is_dequant(node.args[1])
+
+        return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET

     def is_supported_quant_module(self, node: torch.fx.Node) -> bool:
         is_supported = (
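The convolution special case exists because a conv can consume a quantized activation while keeping an fp32 weight; only when the weight argument (node.args[1]) is itself produced by a dequant should the surrounding q/dq nodes be tagged implicit. Schematically (a hypothetical pair of graphs):

# Tagged implicit: both activation and weight flow through dequant
#   dq(x) --+
#           +-- aten.convolution -- q(out)
#   dq(w) --+
#
# Not tagged: weight is a plain fp32 tensor, node.args[1] is not a dequant
#   dq(x) --+
#           +-- aten.convolution -- q(out)
#    w  ----+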

backends/xnnpack/operators/node_visitor.py

+42-21
@@ -337,15 +337,16 @@ def _check_per_channel_group_params(
         # For now group quantization is only supported for 4b weights
         assert quant_params.is_qc4w, "Only 4b group quantization is supported"

-    def define_tensor(
+    def define_tensor(  # noqa: C901
         self,
         tensor: torch.fx.Node,
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool = False,
-        swap_nc_for_depthwise_weights: bool = False,
+        swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
         fp32_static_weights: bool = False,
+        groups: int = 1,
     ) -> None:
         """
         Defines a tensor value into the XNNGraph
@@ -357,16 +358,21 @@ def define_tensor(
             their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                         reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                        should be permuted such that the N and C dimensions are
-                        swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                        permuted and reshaped from (inc, oc/groups, height, width) to
+                        (oc, inc/groups, height, width), which should be used for
+                        depthwise/transpose convolution
                         weights. This is only valid for tensors which hold
                         constant data. If used along with convert_to_nhwc, this
                         swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
             fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+            groups: number of groups for swap_in_out_for_weights
         """

+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid with swap_in_out_for_weights"
+
         if tensor in vals_to_ids:
             return

@@ -394,15 +400,16 @@ def define_tensor(
             xnn_graph,
             vals_to_ids,
             convert_to_nhwc,
-            swap_nc_for_depthwise_weights,
+            swap_in_out_for_weights,
             quant_params,
             fp32_static_weights,
+            groups,
         )

         # convert tensor shape must reflect memory format, default is contiguous, so
         # only permute shape if we are converting the tensor to nhwc format
-        if swap_nc_for_depthwise_weights:
-            dims = [dims[1], dims[0]] + dims[2:]
+        if swap_in_out_for_weights:
+            dims = [dims[1] * groups, dims[0] // groups] + dims[2:]
         if convert_to_nhwc:
             check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
             dims = [dims[i] for i in PERM_NCHW_TO_NHWC]
@@ -422,16 +429,16 @@ def define_tensor(
         )

         # Override the quant params axis since we have
-        # updated the weights for depthwise, with that the out_channels dim
+        # updated the weights for depthwise / transposed conv2d, with that the out_channels dim
         # will be dims[3] instead of dims[0]. Let's update the per_channel
         # quant axis to match the new weight tensor before serializing
-        if swap_nc_for_depthwise_weights and (
-            quant_params and quant_params.per_channel
-        ):
+        if swap_in_out_for_weights and (quant_params and quant_params.per_channel):
             if quant_params.axis == 0:
                 quant_params.axis = len(dims) - 1
+            elif quant_params.axis == 1:
+                quant_params.axis = 0
             else:
-                assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
+                assert False, f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d: {quant_params.axis}, expecting 0 / 1."

         # Serialize tensor value
         ser_val = (
@@ -492,9 +499,10 @@ def get_serialized_buffer_index(
         xnn_graph: XNNGraph,
         vals_to_ids: Dict[torch.fx.Node, int],
         convert_to_nhwc: bool,
-        swap_nc_for_depthwise_weights: bool,
+        swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
         fp32_static_weights: bool = False,
+        groups: int = 1,
     ) -> int:
         """
         If tensor holds some constant data, serialize it and return the
@@ -507,24 +515,30 @@ def get_serialized_buffer_index(
             their corresponding ids in XNNGraph
             convert_to_nhwc: bool to indicate whether tensor shape should be permuted to
                         reflect the nhwc memory format.
-            swap_nc_for_depthwise_weights: bool to indicate whether tensor shape
-                        should be permuted such that the N and C dimensions are
-                        swapped, which should be used for depthwise convolution
+            swap_in_out_for_weights: bool to indicate whether tensor shape should be
+                        permuted and reshaped from (inc, oc/groups, height, width) to
+                        (oc, inc/groups, height, width), which should be used for
+                        depthwise/transpose convolution
                         weights. This is only valid for tensors which hold
                         constant data. If used along with convert_to_nhwc, this
                         swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
             fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            groups: groups for swap_in_out_for_weights

         Returns:
             buffer_idx: idx of the serialized data. 0 if not associated constant
             data
         """
+
+        assert (
+            swap_in_out_for_weights or groups == 1
+        ), "groups is only valid with swap_in_out_for_weights"
+
         # The get_attr node is the input to quant_params.
         get_attr_node = tensor if quant_params is None else quant_params.q_input
         if not is_param_node(self.exported_program, get_attr_node):
             check_or_raise(
-                not swap_nc_for_depthwise_weights,
+                not swap_in_out_for_weights,
                 "Swapping N and C dimensions is only valid for constant data tensors",
             )
             return 0
@@ -541,9 +555,16 @@ def get_serialized_buffer_index(
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()

-        if swap_nc_for_depthwise_weights:
-            const_val = const_val.permute(
-                dims=((1, 0) + tuple(range(2, const_val.dim())))
+        if swap_in_out_for_weights:
+            # Permute and reshape the tensor from (inc, oc/groups, height, width)
+            # to (oc, inc/groups, height, width), as XNNPACK expects for
+            # depthwise/transpose convolution weights
+            shape = const_val.shape
+            const_val = const_val.reshape(
+                (groups, const_val.shape[0] // groups) + const_val.shape[1:]
+            )
+            const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim())))
+            const_val = const_val.reshape(
+                (shape[1] * groups, shape[0] // groups) + shape[2:]
             ).contiguous()

         if convert_to_nhwc:
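A standalone check of the reshape/permute/reshape above (values illustrative): a grouped transposed-conv weight of shape (inc, oc/groups, kh, kw) comes out as (oc, inc/groups, kh, kw), which is the layout expected before the NCHW-to-NHWC permute:

import torch

groups, inc, oc = 2, 8, 6
w = torch.randn(inc, oc // groups, 3, 3)  # PyTorch transposed-conv layout

shape = w.shape
w = w.reshape((groups, shape[0] // groups) + shape[1:])  # (g, inc/g, oc/g, kh, kw)
w = w.permute((0, 2, 1) + tuple(range(3, w.dim())))      # (g, oc/g, inc/g, kh, kw)
w = w.reshape((shape[1] * groups, shape[0] // groups) + shape[2:]).contiguous()

assert w.shape == (oc, inc // groups, 3, 3)  # (6, 4, 3, 3)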

backends/xnnpack/operators/op_conv2d.py

+33-12
@@ -16,6 +16,7 @@
 from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNConv2d,
+    XNNConvTranspose2d,
     XNNDepthwiseConv2d,
     XNNGraph,
     XNode,
@@ -52,35 +53,57 @@ def define_node(
         )  # NHWC input
         kwargs["input1_id"] = vals_to_ids[get_input_node(node, 0)]

-        # filter shape for pytorch convolution is (oc, inc/groups, height, width)
-        # shape for xnnpack convolution is (oc, height, width, inc/groups), to convert
-        # to the proper shape, this is essentially a NCHW to NHWC conversion
+        # filter shape for pytorch convolution is (oc, inc/groups, height, width),
+        # filter shape for pytorch transpose convolution is (inc, oc/groups, height, width),
+        # shape for xnnpack convolution is (oc, height, width, inc/groups),
+        # shape for xnnpack transpose convolution is (oc, height, width, inc/groups),
+        # to convert to the proper shape, this is essentially a NCHW to NHWC conversion
         kernel_node = get_input_node(node, 1)
         kernel_shape = get_shape(kernel_node)
         groups = cast(int, node.args[8])
-        group_input_channels = kernel_shape[1]
-        group_output_channels = int(kernel_shape[0] / groups)
+        is_transpose = node.args[6]
+
+        if is_transpose:
+            group_input_channels = int(kernel_shape[0] / groups)
+            group_output_channels = kernel_shape[1]
+        else:
+            group_input_channels = kernel_shape[1]
+            group_output_channels = int(kernel_shape[0] / groups)

         # XNNPACK expects the kernel's N and C dimensions to be swapped for
         # Depthwise Convolution, which occurs under the following conditions:
         # 1) groups = input_channels (i.e. group_input_channels = 1)
         # 2) output_channels is a positive integer multiple of input channels
-        is_depthwise_conv = (group_input_channels == 1) and (
-            group_output_channels % group_input_channels == 0
+        is_depthwise_conv = (
+            (group_input_channels == 1)
+            and (group_output_channels % group_input_channels == 0)
+            and not is_transpose
         )
         weight_quant_params = QuantParams.from_weights(
             kernel_node, self._exported_program
         )
         fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

+        if weight_quant_params is not None and weight_quant_params.per_channel:
+            if is_transpose:
+                check_or_raise(
+                    weight_quant_params.axis == 1 and groups == 1,
+                    "XNNPACK currently only supports per output channel quantization with groups == 1 for transpose convolutions",
+                )
+            elif is_depthwise_conv:
+                check_or_raise(
+                    weight_quant_params.axis == 0,
+                    "XNNPACK currently only supports per input channel quantization for depthwise convolutions",
+                )
         self.define_tensor(
             kernel_node,
             xnn_graph,
             vals_to_ids,
             convert_to_nhwc=True,
-            swap_nc_for_depthwise_weights=is_depthwise_conv,
+            swap_in_out_for_weights=is_depthwise_conv or is_transpose,
             quant_params=weight_quant_params,
             fp32_static_weights=fp32_static_weights,
+            groups=groups if is_transpose else 1,
         )
         kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]
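For a concrete reading of the branch above, the same kernel shape means different channel counts depending on direction (numbers illustrative):

kernel_shape, groups = (6, 4, 3, 3), 2

# Regular conv: weight is (oc, inc/groups, kh, kw) -> inc = 8, oc = 6
group_input_channels = kernel_shape[1]               # 4
group_output_channels = kernel_shape[0] // groups    # 3

# Transposed conv: weight is (inc, oc/groups, kh, kw) -> inc = 6, oc = 8
group_input_channels = kernel_shape[0] // groups     # 3
group_output_channels = kernel_shape[1]              # 4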

@@ -120,10 +143,6 @@ def define_node(
         if len(padding) == 1:
             padding = padding + padding

-        # args[6] = transposed
-        check_or_raise(
-            not cast(bool, node.args[6]), "No support for transposed convolution"
-        )
         # args[7] = output padding
         check_or_raise(
             all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
@@ -152,6 +171,8 @@ def define_node(

         if is_depthwise_conv:
             conv_node_type = XNNDepthwiseConv2d
+        elif is_transpose:
+            conv_node_type = XNNConvTranspose2d
         else:
             conv_node_type = XNNConv2d
backends/xnnpack/partition/config/gemm_configs.py

+19-5
@@ -9,6 +9,7 @@
 from typing import cast, List, Optional, Tuple

 import torch
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
     XNNPartitionerConfig,
@@ -317,7 +318,7 @@ def __init__(self, **kwargs):

     def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         """
-        Currently we have no support for convolution 3d and transposed convolution
+        Currently we have no support for convolution 3d
         """
         if not super().check_constraints(node, ep):
             return False
@@ -327,11 +328,24 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             why(node, "Only support 1D + 2D Conv")
             return False  # Only support 1D + 2D Conv

-        transposed = cast(bool, node.args[6])
-        if transposed:
-            why(node, "Transposed Conv is not supported")
-            return False  # Currently don't support transposed conv
+        kernel_node = get_input_node(node, 1)
+        weight_quant_params = QuantParams.from_weights(kernel_node, ep)

+        is_transpose = node.args[6]
+        groups = cast(int, node.args[8])
+
+        if (
+            is_transpose
+            and weight_quant_params is not None
+            and weight_quant_params.per_channel
+            and (groups > 1 or weight_quant_params.axis != 1)
+        ):
+            why(
+                node,
+                "XNNPACK does not support per input channel quantization "
+                "for transpose convolutions with groups > 1",
+            )
+            return False
         return True

     def supported_precision_types(self):
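A hedged sketch of the per-channel quantized path the new constraint admits (groups == 1, per output channel, i.e. axis 1), using the PT2E flow; it assumes the XNNPACK quantizer annotates transposed convolutions in your torch release:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.ConvTranspose2d(8, 4, 3, groups=1).eval()
sample = (torch.randn(1, 8, 16, 16),)

quantizer = XNNPACKQuantizer()
# Per-channel weight quantization; for transposed conv the channel axis is 1
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

captured = torch.export.export_for_training(model, sample).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*sample)                   # calibrate
quantized = convert_pt2e(prepared)  # ready for XNNPACK partitioning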

backends/xnnpack/partition/configs.py

+1
@@ -73,6 +73,7 @@
     torch.nn.BatchNorm2d,
     torch.nn.BatchNorm1d,
     torch.nn.Conv2d,
+    torch.nn.ConvTranspose2d,
     torch.nn.Linear,
     torch.nn.functional.linear,
     torch.nn.PReLU,  # Without this, the PReLU weight becomes not a get_attr
