Skip to content

Commit 4f1a9df

Browse files
committed
fix(): first commit
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 8171f79 commit 4f1a9df

File tree

4 files changed

+105
-0
lines changed

4 files changed

+105
-0
lines changed

Diff for: core/conversion/converters/BUILD

100644100755
+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ cc_library(
2929
"impl/shuffle.cpp",
3030
"impl/softmax.cpp",
3131
"impl/unary.cpp",
32+
"impl/interpolate.cpp"
3233
],
3334
deps = [
3435
"@tensorrt//:nvinfer",

Diff for: core/conversion/converters/impl/interpolate.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "torch/torch.h"
2+
#include "core/util/prelude.h"
3+
#include "core/conversion/converters/converters.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace conversion {
8+
namespace converters {
9+
namespace impl {
10+
namespace {
11+
12+
auto interpolate_registrations = RegisterNodeConversionPatterns()
13+
.pattern({
14+
"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)",
15+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
16+
auto in = args[0].ITensor();
17+
18+
auto shape = util::toVec(in->getDimensions());
19+
20+
LOG_DEBUG("Shape of input is" << in);
21+
22+
std::cout << "TEST!" << std::endl;
23+
24+
return true;
25+
}
26+
});
27+
28+
29+
} // namespace
30+
} // namespace impl
31+
} // namespace converters
32+
} // namespace conversion
33+
} // namespace core
34+
} // namespace trtorch

Diff for: test.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# import torch.nn as nn
2+
# import torch
3+
4+
# class FeatureExtractor(nn.Module):
5+
# def __init__(self):
6+
# super(FeatureExtractor, self).__init__()
7+
# self.conv1 = nn.Conv2d(1, 6, 3)
8+
# self.conv2 = nn.Conv2d(6, 16, 3)
9+
10+
# def forward(self, x):
11+
# x = torch.max_pool2d(torch.relu(self.conv1(x)), (2, 2))
12+
# x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
13+
14+
# return x
15+
16+
# class Classifier(nn.Module):
17+
# def __init__(self):
18+
# super(Classifier, self).__init__()
19+
20+
# self.fc1 = nn.Linear(16*6*6, 120)
21+
# self.fc2 = nn.Linear(120, 84)
22+
# self.fc3 = nn.Linear(84, 10)
23+
24+
# def forward(self, x):
25+
# x = torch.flatten(x, 1)
26+
# x = torch.relu(self.fc1(x))
27+
# x = torch.relu(self.fc2(x))
28+
# x = self.fc3(x)
29+
30+
# return x
31+
32+
# class LeNet(nn.Module):
33+
# def __init__(self):
34+
# super(LeNet, self).__init__()
35+
# self.feat = FeatureExtractor()
36+
# self.classifier = Classifier()
37+
38+
# def forward(self, x):
39+
# x = self.feat(x)
40+
# x = self.classifier(x)
41+
42+
# return x
43+
44+
# model = LeNet()
45+
# model.eval()
46+
# traced_model = torch.jit.trace(model, torch.empty([1, 1, 32, 32]))
47+
# torch.jit.save(traced_model, 'traced_model.ts')
48+
# torch.jit.save(torch.jit.script(model), 'script_model.ts')
49+
50+
51+
import torch.nn as nn
52+
import torch
53+
import torch.nn.functional as F
54+
#import trtorch
55+
56+
class Interp(nn.Module):
57+
def __init__(self):
58+
super(Interp, self).__init__()
59+
60+
def forward(self, x):
61+
return F.interpolate(x, scale_factor=(5,5), mode='nearest')
62+
63+
model = Interp()
64+
model.eval()
65+
trace = torch.jit.trace(model, torch.empty([1, 1, 2, 2]))
66+
torch.jit.save(trace, 'trace.ts')
67+
68+
#trtorch.check_method_op_support(trace, "forward")
69+
70+

Diff for: trace.ts

2.26 KB
Binary file not shown.

0 commit comments

Comments
 (0)