Commit c5057f8

fix(aten::linear): Fixes new issues in PyTorch 1.8 that cause scripted models to fail while traced models work. The two paths produce different graphs: scripted graphs retain aten::linear nodes, which were not being handled. Linear ops are now lowered to a weight transpose (aten::t) plus aten::matmul and a bias add, and an evaluator for aten::t is registered for the frozen-weight case.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 71c4dcb · commit c5057f8
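For context, the script/trace difference is visible when dumping a module's TorchScript graph: scripting a model that uses torch.nn.Linear keeps aten::linear nodes in the IR, while tracing typically records an already-decomposed form. A minimal inspection sketch in libtorch (the model path here is hypothetical):

#include <iostream>
#include <torch/script.h>

int main() {
  // Hypothetical path to a scripted (not traced) TorchScript export.
  torch::jit::Module mod = torch::jit::load("model_scripted.ts");
  // Print the forward graph; for scripted exports of nn.Linear, expect
  // aten::linear nodes, which this commit teaches the lowering to handle.
  std::cout << *mod.get_method("forward").graph() << std::endl;
  return 0;
}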

File tree

9 files changed: +76 −22 lines changed

Diff for: core/conversion/evaluators/aten.cpp

+14
@@ -396,6 +396,20 @@ auto aten_registrations TRTORCH_UNUSED =
         EvalOptions().validSchemas({
             "aten::numel(Tensor self) -> int",
         })})
+        .evaluator({c10::Symbol::fromQualString("aten::t"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto tensor_var = args.at(n->input(0));
+                      if (tensor_var.IValue()->isTensor()) {
+                        auto tensor = tensor_var.unwrapToTensor();
+                        return tensor.t();
+                      } else {
+                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
+                        return {};
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::t(Tensor self) -> Tensor",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
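The evaluator runs at conversion time and only handles the case where the input is a concrete at::Tensor (e.g. frozen weights); ITensor inputs raise an error. A quick sketch of the tensor semantics it forwards to (plain libtorch, outside TRTorch):

#include <torch/torch.h>

int main() {
  auto w = torch::randn({4, 3});
  auto wt = w.t();  // matrix transpose: shape [4, 3] -> [3, 4]
  TORCH_CHECK(wt.size(0) == 3 && wt.size(1) == 4);
  return 0;
}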

Diff for: core/lowering/lowering.cpp

+1 −1

@@ -36,7 +36,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   torch::jit::LowerAllTuples(g);
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
-  passes::FuseFlattenLinear(g);
+  passes::LinearToAddMM(g);
   passes::Conv2DToConvolution(g);
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);

Diff for: core/lowering/passes/BUILD

+1 −1

@@ -17,7 +17,7 @@ cc_library(
         "conv3d_to_convolution.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
-        "fuse_flatten_linear.cpp",
+        "linear_to_addmm.cpp",
         "remove_bn_dim_check.cpp",
         "remove_contiguous.cpp",
         "remove_dropout.cpp",

Diff for: core/lowering/passes/fuse_flatten_linear.cpp renamed to core/lowering/passes/linear_to_addmm.cpp

+17 −15

@@ -7,29 +7,31 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
   std::string flatten_linear_pattern = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
-      %flat = aten::flatten(%input, %6, %7)
-      %res = aten::linear(%flat, %weight, %bias)
+    graph(%input, %weight, %bias):
+      %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
   std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %6, %7, %weight):
-      %flat = aten::flatten(%input, %6, %7)
+    graph(%input, %weight):
       %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%flat, %weight, %bias)
-      return (%res))IR";
-  std::string fused_linear = R"IR(
-    graph(%input, %6, %7, %weight, %bias):
       %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
 
+  std::string fused_linear = R"IR(
+    graph(%input, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
   std::string fused_linear_bias_none = R"IR(
-    graph(%input, %6, %7, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
+    graph(%input, %weight_t):
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      return (%mm))IR";
 
   torch::jit::SubgraphRewriter flatten_linear_to_linear;
   flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
@@ -38,7 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
   flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
   flatten_linear_bias_none_to_linear.runOnGraph(graph);
-  LOG_GRAPH("Post flatten linear: " << *graph);
+  LOG_GRAPH("Post linear to addmm: " << *graph);
 }
 
 } // namespace passes
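The rewrite implements the standard decomposition linear(x, W, b) = matmul(x, t(W)) + b; trt::const appears to be TRTorch's placeholder op marking the bias as a frozen constant for the converter. A numerical sanity sketch of the decomposition in plain libtorch (the variable names here are illustrative, not part of the pass):

#include <torch/torch.h>

int main() {
  auto x = torch::randn({8, 16});
  auto W = torch::randn({32, 16});  // [out_features, in_features], as nn.Linear stores it
  auto b = torch::randn({32});
  auto ref = at::linear(x, W, b);
  // Lowered form: transpose the weight, matmul, then broadcast-add the bias.
  auto dec = torch::matmul(x, W.t()) + b;
  TORCH_CHECK(torch::allclose(ref, dec, /*rtol=*/1e-4, /*atol=*/1e-5));
  return 0;
}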

Diff for: core/lowering/passes/passes.h

+1 −1

@@ -10,7 +10,7 @@ namespace passes {
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
-void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
+void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);

Diff for: tests/core/lowering/BUILD

+5

@@ -7,6 +7,10 @@ config_setting(
     }
 )
 
+lowering_test(
+    name = "test_linear_to_addmm",
+)
+
 lowering_test(
     name = "test_remove_contiguous_pass",
 )
@@ -30,6 +34,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
+        ":test_linear_to_addmm",
        ":test_remove_contiguous_pass",
        ":test_remove_to",
        ":test_remove_detach_pass",
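With the target registered, the new pass test can be run on its own (assuming the repo's usual Bazel workflow) via bazel test //tests/core/lowering:test_linear_to_addmm, or together with the rest of the :lowering_tests suite.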

Diff for: tests/core/lowering/test_linear_to_addmm.cpp

+34

@@ -0,0 +1,34 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, LinearToAddMM) {
+  std::string source_graph = R"IR(
+    graph(%input, %6, %7, %weight, %bias):
+      %flat = aten::flatten(%input, %6, %7)
+      %res = aten::linear(%flat, %weight, %bias)
+      return (%res))IR";
+  std::string target_graph = R"IR(
+    graph(%input, %6, %7, %weight_t, %bias):
+      %1: int = prim::Constant[value=1]()
+      %flat = aten::flatten(%input, %6, %7)
+      %weight = aten::t(%weight_t)
+      %mm: Tensor = aten::matmul(%flat, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add_(%b_f, %mm, %1)
+      return (%out))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::LinearToAddMM(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
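Note the argument order in torch::jit::findPatternMatches(pattern, graph): tg supplies the expected lowered pattern and sg is the graph that was just rewritten, so a non-empty match list means the pass produced the aten::t / aten::matmul / aten::add_ form.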

Diff for: tests/modules/hub.py

+3 −3

@@ -46,15 +46,15 @@
         "path": "both"
     },
     "resnet18": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet18', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True),
         "path": "both"
     },
     "resnet50": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'resnet50', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
     "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.8.2', 'fcn_resnet101', pretrained=True),
+        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
         "path": "script"
     },
     "ssd": {
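The torchvision bump lines up with the PyTorch release this fix targets: torchvision v0.9.0 is the release paired with PyTorch 1.8, while v0.8.2 was paired with the 1.7.x series.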

Diff for: tests/py/test_api.py

−1

@@ -79,7 +79,6 @@ def test_is_colored_output_on(self):
 def test_suite():
     suite = unittest.TestSuite()
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
-    suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet50(pretrained=True)))
     suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
     suite.addTest(unittest.makeSuite(TestLoggingAPIs))
