
Commit c5b6202

feat(aten::matmul|aten::addmm): Adds support for aten::matmul and aten::addmm

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent d945eb9 commit c5b6202

17 files changed: +197 −95 lines

Diff for: core/conversion/conversion.cpp

+5 −1

@@ -73,7 +73,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
             LOG_DEBUG(ctx->logger, "Node input is a value that needs to be evaluated");
             auto eval = EvaluateNode(ctx, input_node);
             if (eval) {
-                LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
+                if (!eval.value().isTensor()) {
+                    LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
+                } else {
+                    LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
+                }
                 ctx->evaluated_value_map[input] = std::move(eval.value());
                 node_args.push_back(&(ctx->evaluated_value_map[input]));
             } else {
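This change keeps evaluator debug logs readable: streaming an at::Tensor into LOG_DEBUG prints every element, while sizes() prints only the shape. A minimal standalone sketch of the difference (hedged: hypothetical values and standard ATen calls, not part of the commit):

#include "torch/torch.h"
#include <iostream>

int main() {
    auto t = torch::ones({2, 3});
    std::cout << t << '\n';          // streams all six elements plus dtype/shape metadata
    std::cout << t.sizes() << '\n';  // streams only the shape: [2, 3]
    return 0;
}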

Diff for: core/conversion/converters/Arg.cpp

+2 −0

@@ -89,6 +89,7 @@ std::string Arg::type_name() const {
 }
 
 const torch::jit::IValue* Arg::IValue() const {
+    TRTORCH_CHECK(isIValue(), "Requested IValue from Arg, however arg type is " << type_name());
     if (type_ == Type::kIValue) {
         return ptr_.ivalue;
     } else {
@@ -97,6 +98,7 @@ const torch::jit::IValue* Arg::IValue() const {
 }
 
 nvinfer1::ITensor* Arg::ITensor() const {
+    TRTORCH_CHECK(isITensor(), "Requested ITensor from Arg, however arg type is " << type_name());
     if (type_ == Type::kITensor) {
         return ptr_.tensor;
     } else {
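The two checks turn a mismatched payload request into an immediate, readable failure rather than a fall-through to the wrong branch. A hedged sketch of a call site (hypothetical function; assumes the Arg accessors used by the converters elsewhere in this commit):

#include "core/conversion/converters/converters.h"

// Hypothetical call site, not part of the commit.
void inspect(trtorch::core::conversion::converters::Arg& arg) {
    if (arg.isITensor()) {
        nvinfer1::ITensor* t = arg.ITensor(); // payload matches the request
        (void)t;
    } else if (arg.isIValue()) {
        // Calling arg.ITensor() here would now abort conversion with
        // "Requested ITensor from Arg, however arg type is IValue".
        const torch::jit::IValue* v = arg.IValue();
        (void)v;
    }
}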

Diff for: core/conversion/converters/BUILD

+1 −0

@@ -15,6 +15,7 @@ cc_library(
         "impl/conv_deconv.cpp",
         "impl/element_wise.cpp",
         "impl/linear.cpp",
+        "impl/matrix_multiply.cpp",
         "impl/pooling.cpp",
         "impl/reduce.cpp",
         "impl/shuffle.cpp",

Diff for: core/conversion/converters/impl/element_wise.cpp

+9 −0

@@ -14,6 +14,15 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
 
     TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
 
+    if (self_dims != other_dims) {
+        LOG_DEBUG("Input shapes don't match, inserting shuffle layer to reshape to " << self_dims);
+        auto other_shuffle = ctx->net->addShuffle(*other);
+        other_shuffle->setReshapeDimensions(self_dims);
+        other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str());
+        other = other_shuffle->getOutput(0);
+    }
+
+
     nvinfer1::ILayer* ele;
     if (scalar != 1) {
         LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform");
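The new branch fires only when the two inputs hold the same number of elements but differ in shape, e.g. self [2, 3] against other [1, 2, 3]. A standalone sketch of that condition (hedged: volume() here stands in for trtorch's util::volume, and the converter itself compares nvinfer1::Dims via trtorch's own helpers):

#include "NvInfer.h"
#include <cstdint>

// Stand-in for trtorch's util::volume: total element count of a Dims.
static int64_t volume(const nvinfer1::Dims& d) {
    int64_t v = 1;
    for (int i = 0; i < d.nbDims; i++) v *= d.d[i];
    return v;
}

// A shuffle (reshape) layer is needed only when element counts agree
// but the dimension vectors differ, e.g. [2, 3] vs [1, 2, 3].
static bool needs_reshape(const nvinfer1::Dims& self, const nvinfer1::Dims& other) {
    if (volume(self) != volume(other)) return false; // the converter asserts this case away first
    if (self.nbDims != other.nbDims) return true;
    for (int i = 0; i < self.nbDims; i++) {
        if (self.d[i] != other.d[i]) return true;
    }
    return false;
}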

Diff for: core/conversion/converters/impl/linear.cpp

+1 −8

@@ -9,13 +9,6 @@ namespace impl {
 namespace {
 
 auto linear_registrations = RegisterNodeConversionPatterns()
-    // .pattern({
-    //     "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
-    //     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> {
-    //         auto in = args[0].ITensor();
-
-    //     }
-    // })
     .pattern({
         "aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -71,4 +64,4 @@ auto linear_registrations = RegisterNodeConversionPatterns()
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // trtorch
+} // namespace trtorch

Diff for: core/conversion/converters/impl/matrix_multiply.cpp

+55 −0

@@ -0,0 +1,55 @@
+#include "core/util/prelude.h"
+#include "core/conversion/converters/converters.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto mm_registrations = RegisterNodeConversionPatterns()
+    .pattern({
+        "aten::matmul(Tensor self, Tensor other) -> (Tensor)",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            nvinfer1::ITensor* self;
+            if (args[0].isIValue()) {
+                auto t = args[0].unwrapToTensor();
+                auto t_weights = Weights(ctx, t);
+                auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
+                TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n);
+                const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str());
+                self = const_layer->getOutput(0);
+            } else {
+                self = args[0].ITensor();
+            }
+            LOG_DEBUG("self tensor shape: " << self->getDimensions());
+
+            nvinfer1::ITensor* other;
+            if (args[1].isIValue()) {
+                auto t = args[1].unwrapToTensor();
+                auto t_weights = Weights(ctx, t);
+                auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
+                TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n);
+                const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str());
+                other = const_layer->getOutput(0);
+            } else {
+                other = args[1].ITensor();
+            }
+            LOG_DEBUG("other tensor shape: " << other->getDimensions());
+
+            auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+            TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
+            mm_layer->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
+
+            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    });
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
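The self and other branches repeat the same freeze-to-constant sequence. A hypothetical refactor, not part of the commit, built only from calls the converter above already uses (Weights, addConstant, TRTORCH_CHECK, util::node_info):

// Hedged sketch of a helper the two branches above could share.
nvinfer1::ITensor* freeze_or_unwrap(ConversionCtx* ctx, const torch::jit::Node* n,
                                    Arg& arg, const std::string& label) {
    if (arg.isIValue()) {
        // Freeze a Torch tensor held in an IValue into a TRT constant layer.
        auto t = arg.unwrapToTensor();
        auto t_weights = Weights(ctx, t);
        auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
        TRTORCH_CHECK(const_layer, "Unable to freeze tensor " << label << " for node: " << *n);
        const_layer->setName((util::node_info(n) + " [Freeze Tensor(" + label + ")]").c_str());
        return const_layer->getOutput(0);
    }
    // Already a TRT tensor flowing through the network.
    return arg.ITensor();
}

With it, the pattern body would reduce to two freeze_or_unwrap calls plus the addMatrixMultiply wiring.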

Diff for: core/lowering/BUILD

+3 −1

@@ -8,12 +8,14 @@ cc_library(
     srcs = [
         "lowering.cpp",
         "drop_unused_nodes.cpp",
+        "register_const_op.cpp"
     ],
     deps = [
         "@libtorch//:libtorch",
         "//core/lowering/passes",
         "//core/util:prelude"
-    ]
+    ],
+    alwayslink = True
 )
 
 load("@rules_pkg//:pkg.bzl", "pkg_tar")

Diff for: core/lowering/lowering.cpp

+1 −0

@@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
     torch::jit::FuseLinear(g);
     passes::RemoveDropout(g);
     passes::FuseFlattenLinear(g);
+    passes::UnpackAddMM(g);
     passes::ExpandLogSoftmax(g);
     //passes::RemoveDimExeception(g);
     //irfusers::UnpackBatchNorm(g);

Diff for: core/lowering/passes/BUILD

+2 −1

@@ -10,7 +10,8 @@ cc_library(
         "expand_log_softmax.cpp",
         "remove_dropout.cpp",
         "unpack_batch_norm.cpp",
-        "exception_elimination.cpp"
+        "exception_elimination.cpp",
+        "unpack_addmm.cpp"
     ],
     deps = [
         "//core/util:prelude",

Diff for: core/lowering/passes/fuse_flatten_linear.cpp

+0 −33

@@ -40,39 +40,6 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
     flatten_linear_bias_none_to_linear.runOnGraph(graph);
 }
 
-void FuseFlattenAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
-    //TensorRT implicitly adds a flatten layer infront of FC layers if necessary
-    std::string flatten_linear_pattern = R"IR(
-        graph(%input, %6, %7, %weight, %bias):
-            %flat = aten::flatten(%input, %6, %7)
-            %res = aten::linear(%flat, %weight, %bias)
-            return (%res))IR";
-    std::string flatten_linear_bias_none_pattern = R"IR(
-        graph(%input, %6, %7, %weight):
-            %flat = aten::flatten(%input, %6, %7)
-            %bias: Tensor? = prim::Constant()
-            %res = aten::linear(%flat, %weight, %bias)
-            return (%res))IR";
-    std::string fused_linear = R"IR(
-        graph(%input, %6, %7, %weight, %bias):
-            %res = aten::linear(%input, %weight, %bias)
-            return (%res))IR";
-
-    std::string fused_linear_bias_none = R"IR(
-        graph(%input, %6, %7, %weight):
-            %bias: Tensor? = prim::Constant()
-            %res = aten::linear(%input, %weight, %bias)
-            return (%res))IR";
-
-    torch::jit::SubgraphRewriter flatten_linear_to_linear;
-    flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
-    flatten_linear_to_linear.runOnGraph(graph);
-
-    torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
-    flatten_linear_bias_none_to_linear.RegisterRewritePattern(
-        flatten_linear_bias_none_pattern, fused_linear_bias_none);
-    flatten_linear_bias_none_to_linear.runOnGraph(graph);
-}
 } // namespace passes
 } // namespace lowering
 } // namespace core

Diff for: core/lowering/passes/passes.h

+1 −0

@@ -11,6 +11,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
 void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 
 } // namespace irfusers

Diff for: core/lowering/passes/unpack_addmm.cpp

+32 −0

@@ -0,0 +1,32 @@
+#include "torch/csrc/jit/passes/fuse_linear.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
+    // Decompose aten::addmm into aten::matmul + aten::add_, freezing the bias as a TRT constant
+    std::string addmm_pattern = R"IR(
+        graph(%b, %x, %w, %1):
+            %out: Tensor = aten::addmm(%b, %x, %w, %1, %1)
+            return (%out))IR";
+    std::string mm_add_pattern = R"IR(
+        graph(%b, %x, %w, %1):
+            %mm: Tensor = aten::matmul(%x, %w)
+            %bias: Tensor = trt::const(%b)
+            %out: Tensor = aten::add_(%bias, %mm, %1)
+            return (%out))IR";
+
+
+    torch::jit::SubgraphRewriter unpack_addmm;
+    unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
+    unpack_addmm.runOnGraph(graph);
+}
+
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
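Two details are worth noting: the pattern binds both of addmm's scalars to the same value %1, so the rewrite matches only when beta and alpha coincide (the common beta = alpha = 1 case), and the bias is routed through trt::const, the marker op registered in register_const_op.cpp below. A hypothetical driver sketch (not in the commit) showing the before/after, mirroring how LowerGraph invokes the pass above:

#include <memory>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"

// Hedged sketch: run the pass the way LowerGraph does.
void unpack_addmm_example(std::shared_ptr<torch::jit::Graph>& g) {
    // before: %out = aten::addmm(%b, %x, %w, %1, %1)
    trtorch::core::lowering::passes::UnpackAddMM(g);
    // after:  %mm   = aten::matmul(%x, %w)
    //         %bias = trt::const(%b)
    //         %out  = aten::add_(%bias, %mm, %1)
}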

Diff for: core/lowering/passes/unpack_batch_norm.cpp

+0 −20

@@ -1,25 +1,5 @@
-#include "torch/csrc/jit/runtime/custom_operator.h"
-#include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
 
-namespace torch {
-namespace jit {
-
-c10::AliasAnalysisKind aliasAnalysisFromSchema() {
-    return c10::AliasAnalysisKind::FROM_SCHEMA;
-}
-
-RegisterOperators trt_const_op_reg({
-    Operator(
-        "trt::const(Tensor val) -> Tensor",
-        [](Stack& stack) {
-            return 0; //nop
-        },
-        aliasAnalysisFromSchema())});
-
-} // namespace jit
-} // namespace torch
-
 namespace trtorch {
 namespace core {
 namespace lowering {

Diff for: core/lowering/register_const_op.cpp

+21 −0

@@ -0,0 +1,21 @@
+#include "torch/csrc/jit/runtime/custom_operator.h"
+
+namespace torch {
+namespace jit {
+
+c10::AliasAnalysisKind aliasAnalysisFromSchema() {
+    return c10::AliasAnalysisKind::FROM_SCHEMA;
+}
+
+/// Op marks a Tensor to be converted from a Torch Tensor
+/// to a TRT constant Tensor
+RegisterOperators trt_const_op_reg({
+    Operator(
+        "trt::const(Tensor val) -> Tensor",
+        [](Stack& stack) {
+            return 0; //noop
+        },
+        aliasAnalysisFromSchema())});
+
+} // namespace jit
+} // namespace torch

Diff for: tests/core/converters/BUILD

+26 −21

@@ -1,54 +1,59 @@
 load("//tests/core/converters:converter_test.bzl", "converter_test")
 
 converter_test(
-    name = "test_softmax"
+    name = "test_activation"
 )
 
 converter_test(
-    name = "test_shuffle"
+    name = "test_conv"
 )
 
 converter_test(
-    name = "test_activation"
+    name = "test_element_wise"
 )
 
 converter_test(
-    name = "test_pooling"
+    name = "test_linear"
 )
 
 converter_test(
-    name = "test_unary"
+    name = "test_matrix_multiply"
 )
 
 converter_test(
-    name = "test_linear"
+    name = "test_pooling"
 )
 
 converter_test(
-    name = "test_element_wise"
+    name = "test_reduce"
 )
 
 converter_test(
-    name = "test_conv"
+    name = "test_shuffle"
 )
 
 converter_test(
-    name = "test_reduce"
+    name = "test_softmax"
+)
+
+converter_test(
+    name = "test_unary"
 )
 
 test_suite(
-  name = "test_converters",
-  tests = [
-    ":test_softmax",
-    ":test_shuffle",
-    ":test_activation",
-    ":test_pooling",
-    ":test_unary",
-    ":test_linear",
-    ":test_element_wise",
-    ":test_conv",
-    ":test_reduce"
-  ]
+    name = "test_converters",
+    tests = [
+        ":test_activation",
+        ":test_conv",
+        ":test_element_wise",
+        ":test_linear",
+        ":test_matrix_multiply",
+        ":test_pooling",
+        ":test_reduce",
+        ":test_shuffle",
+        ":test_softmax",
+        ":test_unary",
+    ]
 )