Skip to content

Commit 24b422e

Browse files
committed
feat(aten::view): Adds support for ATen view also fixes some tests
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c4b62a6 commit 24b422e

File tree

6 files changed

+114
-26
lines changed

6 files changed

+114
-26
lines changed

Diff for: core/conversion/converters/impl/shuffle.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,24 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
4343
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
4444
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
4545

46+
return true;
47+
}
48+
}).pattern({
49+
"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
50+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
51+
auto in = args[0].ITensor();
52+
auto in_shape = util::toVec(in->getDimensions());
53+
auto ex_tensor = torch::rand(in_shape);
54+
auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();
55+
56+
auto shuffle = ctx->net->addShuffle(*in);
57+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
58+
shuffle->setReshapeDimensions(util::toDims(new_shape));
59+
shuffle->setName(util::node_info(n).c_str());
60+
61+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
62+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
63+
4664
return true;
4765
}
4866
});

Diff for: tests/core/converters/BUILD

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
load("//tests/core/converters:converter_test.bzl", "converter_test")
22

3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
310
converter_test(
411
name = "test_activation"
512
)

Diff for: tests/core/converters/converter_test.bzl

-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
config_setting(
2-
name = "use_pre_cxx11_abi",
3-
values = {
4-
"define": "abi=pre_cxx11_abi",
5-
}
6-
)
71

82
def converter_test(name, visibility=None):
93
native.cc_test(

Diff for: tests/core/converters/test_shuffle.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,29 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
7373
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
7474
auto trt = trt_results[0].reshape_as(jit_results[0]);
7575

76+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
77+
}
78+
79+
TEST(Converters, ATenViewConvertsCorrectly) {
80+
const auto graph = R"IR(
81+
graph(%0 : Tensor):
82+
%1 : int = prim::Constant[value=3]()
83+
%2 : int = prim::Constant[value=2]()
84+
%3 : int[] = prim::ListConstruct(%1, %2)
85+
%4 : Tensor = aten::view(%0, %3)
86+
return (%4))IR";
87+
88+
auto g = std::make_shared<torch::jit::Graph>();
89+
torch::jit::parseIR(graph, &*g);
90+
91+
auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
92+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
93+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
94+
95+
in = at::clone(in);
96+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
97+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
98+
auto trt = trt_results[0].reshape_as(jit_results[0]);
99+
76100
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
77101
}

Diff for: tests/modules/hub.py

+65-18
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,72 @@
22
import torchvision.models as models
33

44
models = {
5-
"alexnet": models.alexnet(pretrained=True),
6-
"vgg16": models.vgg16(pretrained=True),
7-
"squeezenet": models.squeezenet1_0(pretrained=True),
8-
"densenet": models.densenet161(pretrained=True),
9-
"inception_v3": models.inception_v3(pretrained=True),
5+
"alexnet": {
6+
"model": models.alexnet(pretrained=True),
7+
"path": "both"
8+
},
9+
"vgg16": {
10+
"model": models.vgg16(pretrained=True),
11+
"path": "both"
12+
},
13+
"squeezenet": {
14+
"model": models.squeezenet1_0(pretrained=True),
15+
"path": "both"
16+
},
17+
"densenet": {
18+
"model": models.densenet161(pretrained=True),
19+
"path": "both"
20+
},
21+
"inception_v3": {
22+
"model": models.inception_v3(pretrained=True),
23+
"path": "both"
24+
},
1025
#"googlenet": models.googlenet(pretrained=True),
11-
"shufflenet": models.shufflenet_v2_x1_0(pretrained=True),
12-
"mobilenet_v2": models.mobilenet_v2(pretrained=True),
13-
"resnext50_32x4d": models.resnext50_32x4d(pretrained=True),
14-
"wideresnet50_2": models.wide_resnet50_2(pretrained=True),
15-
"mnasnet": models.mnasnet1_0(pretrained=True),
16-
"resnet18": torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True),
17-
"resnet50": torch.hub.load('pytorch/vision:v0.5.0', 'resnet50', pretrained=True)}
26+
"shufflenet": {
27+
"model": models.shufflenet_v2_x1_0(pretrained=True),
28+
"path": "both"
29+
},
30+
"mobilenet_v2": {
31+
"model": models.mobilenet_v2(pretrained=True),
32+
"path": "both"
33+
},
34+
"resnext50_32x4d": {
35+
"model": models.resnext50_32x4d(pretrained=True),
36+
"path": "both"
37+
},
38+
"wideresnet50_2": {
39+
"model": models.wide_resnet50_2(pretrained=True),
40+
"path": "both"
41+
},
42+
"mnasnet": {
43+
"model": models.mnasnet1_0(pretrained=True),
44+
"path": "both"
45+
},
46+
"resnet18": {
47+
"model": torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True),
48+
"path": "both"
49+
},
50+
"resnet50": {
51+
"model":torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True),
52+
"path": "both"
53+
},
54+
"fcn_resnet101": {
55+
"model": torch.hub.load('pytorch/vision:v0.6.0', 'fcn_resnet101', pretrained=True),
56+
"path": "script"
57+
},
58+
"ssd": {
59+
"model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
60+
"path": "trace"
61+
}
62+
}
1863

1964
for n, m in models.items():
2065
print("Downloading {}".format(n))
21-
m = m.eval().cuda()
22-
x = torch.ones((1, 3, 224, 224)).cuda()
23-
trace_model = torch.jit.trace(m, x)
24-
torch.jit.save(trace_model, n + '_traced.jit.pt')
25-
script_model = torch.jit.script(m)
26-
torch.jit.save(script_model, n + '_scripted.jit.pt')
66+
m["model"] = m["model"].eval().cuda()
67+
x = torch.ones((1, 3, 300, 300)).cuda()
68+
if m["path"] == "both" or m["path"] == "trace":
69+
trace_model = torch.jit.trace(m["model"], [x])
70+
torch.jit.save(trace_model, n + '_traced.jit.pt')
71+
if m["path"] == "both" or m["path"] == "script":
72+
script_model = torch.jit.script(m["model"])
73+
torch.jit.save(script_model, n + '_scripted.jit.pt')

Diff for: tests/util/BUILD

-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ cc_library(
2020
],
2121
deps = [
2222
"@tensorrt//:nvinfer",
23-
"@libtorch//:libtorch",
24-
"@libtorch//:caffe2",
2523
"//core/conversion",
2624
"//core/util:prelude",
2725
"//cpp/api:trtorch",

0 commit comments

Comments
 (0)