Skip to content

Commit 0548540

Browse files
committed
feat(aten::size [static]): Implement a aten::size converter for static input size
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c5b6202 commit 0548540

File tree

8 files changed

+45
-7
lines changed

8 files changed

+45
-7
lines changed

Diff for: core/conversion/conversion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
7878
} else {
7979
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
8080
}
81-
ctx->evaluated_value_map[input] = std::move(eval.value());
81+
ctx->AssociateValueAndIValue(input, eval.value());
8282
node_args.push_back(&(ctx->evaluated_value_map[input]));
8383
} else {
8484
LOG_DEBUG(ctx->logger, "Found the value is None");;

Diff for: core/conversion/conversionctx/ConversionCtx.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,15 @@ nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Valu
103103
return tensor;
104104
}
105105

106+
torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue ivalue) {
107+
this->evaluated_value_map[value] = std::move(ivalue);
108+
return &this->evaluated_value_map[value];
109+
}
110+
106111
std::string ConversionCtx::SerializeEngine() {
107112
auto engine = builder->buildEngineWithConfig(*net, *cfg);
108113
auto serialized_engine = engine->serialize();
114+
engine->destroy();
109115
return std::string((const char*)serialized_engine->data(), serialized_engine->size());
110116
}
111117

Diff for: core/conversion/conversionctx/ConversionCtx.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "core/util/prelude.h"
1212

13-
1413
namespace trtorch {
1514
namespace core {
1615
namespace conversion {
@@ -39,6 +38,7 @@ struct ConversionCtx {
3938
ConversionCtx(BuilderSettings settings);
4039
std::string SerializeEngine();
4140
nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
41+
torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue tensor);
4242
bool CheckLayerAddition(const torch::jit::Node* n);
4343

4444
~ConversionCtx();

Diff for: core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"impl/matrix_multiply.cpp",
1919
"impl/pooling.cpp",
2020
"impl/reduce.cpp",
21+
"impl/shape.cpp",
2122
"impl/shuffle.cpp",
2223
"impl/softmax.cpp",
2324
"impl/unary.cpp",

Diff for: core/conversion/converters/impl/shape.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "core/conversion/converters/converters.h"
2+
3+
#include "torch/torch.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace conversion {
8+
namespace converters {
9+
namespace impl {
10+
namespace {
11+
12+
static auto shape_registrations = RegisterNodeConversionPatterns()
13+
.pattern({
14+
// To use in static input size cases (explicit batch)
15+
"aten::size.int(Tensor self, int dim) -> (Tensor)",
16+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17+
auto in = args[0].ITensor();
18+
auto in_shape = util::toVec(in->getDimensions());
19+
20+
auto size = in_shape[args[1].unwrapToInt()];
21+
22+
ctx->AssociateValueAndIValue(n->outputs()[0], size);
23+
LOG_DEBUG("Output Value: " << size);
24+
return true;
25+
}
26+
});
27+
} // namespace
28+
} // namespace impl
29+
} // namespace converters
30+
} // namespace conversion
31+
} // namespace core
32+
} // namespace trtorch

Diff for: core/conversion/converters/impl/shuffle.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
3232
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
3333
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
3434
auto in = args[0].ITensor();
35-
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);
35+
auto in_shape = util::toVec(in->getDimensions());
36+
auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes();
3637

3738
auto shuffle = ctx->net->addShuffle(*in);
3839
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
39-
shuffle->setReshapeDimensions(new_shape);
40+
shuffle->setReshapeDimensions(util::toDims(new_shape));
4041
shuffle->setName(util::node_info(n).c_str());
4142

4243
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));

Diff for: core/execution/TRTEngine.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ namespace trtorch {
1010
namespace core {
1111
namespace execution {
1212

13-
TRTEngine::TRTEngine() {}
14-
1513
TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine) {
1614
rt = nvinfer1::createInferRuntime(logger);
1715

Diff for: core/execution/execution.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct TRTEngine {
1717
std::pair<uint64_t, uint64_t> num_io;
1818
EngineID id;
1919

20-
TRTEngine();
20+
TRTEngine() = default;
2121
TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine);
2222
TRTEngine& operator=(const TRTEngine& other);
2323
};

0 commit comments

Comments
 (0)