Skip to content

Commit 1b50484

Browse files
committed
feat(//core/conversion/converters/impl): all function schemas for upsample_nearest
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 4f1a9df commit 1b50484

File tree

1 file changed

+59
-3
lines changed

1 file changed

+59
-3
lines changed

Diff for: core/conversion/converters/impl/interpolate.cpp

+59-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "core/util/prelude.h"
33
#include "core/conversion/converters/converters.h"
44

5+
#include <csignal>
6+
57
namespace trtorch {
68
namespace core {
79
namespace conversion {
@@ -11,15 +13,69 @@ namespace {
1113

1214
auto interpolate_registrations = RegisterNodeConversionPatterns()
1315
.pattern({
16+
"aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)",
17+
[](ConversionCtx* ctx, const torch::jit::Node*n, args& args) -> bool {
18+
TRTORCH_ASSERT(args[0].IValue()->isTensor(), "Input expected to be of type Tensor");
19+
20+
auto in = args[0].ITensor();
21+
auto in_shape = util::toVec(in->getDimensions());
22+
23+
// Case 1: user uses output size and not scales
24+
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
25+
auto output_size = util::toDims(args[1].unwrapToIntList());
26+
27+
TRTORCH_ASSERT(output_size.nbDims == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch");
28+
} else {
29+
LOG_DEBUG("scale factor parameters not supported yet.");
30+
}
31+
32+
return true;
33+
}
34+
}).pattern({
1435
"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)",
1536
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
37+
// std::raise(SIGINT);
38+
TRTORCH_ASSERT(args[0].IValue()->isTensor(), "Input expected to be of type Tensor");
39+
40+
auto in = args[0].ITensor();
41+
auto in_shape = util::toVec(in->getDimensions());
42+
43+
// Case 1: user uses output_size and not scales_h, scales_w
44+
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone()){
45+
auto output_size = util::toDims(args[1].unwrapToIntList());
46+
47+
TRTORCH_ASSERT( (output_size.nbDims == 1 || output_size.nbDims == 2), "aten::upsample_nearest2d input Tensor and output size dimension mismatch");
48+
49+
nvinfer1::ILayer* new_layer;
50+
51+
52+
53+
//util::toDims(args[1].unwrapToIntList());
54+
55+
} else {
56+
LOG_DEBUG("scale factor parameters not supported yet.");
57+
}
58+
59+
return true;
60+
}
61+
}).pattern({
62+
"aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
63+
[](ConversionCtx* ctx, const torch::jit::Node*n, args& args) -> bool {
64+
TRTORCH_ASSERT(args[0].IValue()->isTensor(), "Input expected to be of type Tensor");
65+
1666
auto in = args[0].ITensor();
67+
auto in_shape = util::toVec(in->getDimensions());
1768

18-
auto shape = util::toVec(in->getDimensions());
69+
// Case 1: user uses output size and not scales_d, scales_h, scales_w
70+
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
71+
auto output_size = util::toDims(args[1].unwrapToIntList());
1972

20-
LOG_DEBUG("Shape of input is" << in);
73+
TRTORCH_ASSERT( (output_size.nbDims == 1 || output_size.nbDims == 3), "aten::upsample_nearest3d input Tensor and output size dimension mismatch");
74+
2175

22-
std::cout << "TEST!" << std::endl;
76+
} else {
77+
LOG_DEBUG("scale factor parameters not supported yet.");
78+
}
2379

2480
return true;
2581
}

0 commit comments

Comments
 (0)