Skip to content

Commit 34a3328

Browse files
authored
Merge pull request #1710 from mfeliz-cruise/michael.feliz/dynamic_aten_mul
[feat] Add dynamic conversion path to aten::mul evaluator
2 parents 78b571c + 02d502c commit 34a3328

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

core/conversion/evaluators/aten.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,16 @@ auto aten_registrations TORCHTRT_UNUSED =
455455
.evaluator(
456456
{c10::Symbol::fromQualString("aten::mul"),
457457
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
458+
if (!constTypesOnly(args)) {
459+
auto a = args.at(n->input(0)).ITensorOrFreeze(ctx);
460+
auto b = args.at(n->input(1)).ITensorOrFreeze(ctx);
461+
auto mul =
462+
converters::add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, a, b, util::node_info(n));
463+
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
464+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
465+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
466+
return {};
467+
}
458468
if (args.at(n->input(0)).IValue()->isInt()) {
459469
auto a = args.at(n->input(0)).unwrapToInt();
460470
auto b = args.at(n->input(1)).unwrapToInt();

core/conversion/var/Var.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,18 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9292
}
9393

9494
TORCHTRT_CHECK(
95-
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
96-
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
95+
isITensor() ||
96+
(isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isScalar() || ptr_.ivalue->isCustomClass())),
97+
"Requested either IValue containing a Tensor, Scalar or ITensor, however Var type is " << type_name());
9798

9899
nvinfer1::ITensor* out;
99100

100101
if (isIValue()) {
101102
if (ptr_.ivalue->isTensor()) {
102103
auto tensor = ptr_.ivalue->toTensor();
103104
out = converters::tensor_to_const(ctx, tensor);
105+
} else if (ptr_.ivalue->isScalar()) {
106+
out = converters::scalar_to_tensor(ctx, ptr_.ivalue->toScalar());
104107
} else {
105108
// Split converter generates c10::IValue which hold TensorContainer.
106109
auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();

tests/cpp/test_dynamic_size.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,37 @@ TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
8787

8888
auto trt = trt_results[0].reshape(jit_results[0].sizes());
8989

90+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
91+
}
92+
93+
TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
  // Exercises the dynamic-shape path of the aten::mul evaluator: a dimension
  // fetched at runtime via aten::size + aten::__getitem__ is multiplied by a
  // constant (2) before being used as a target dim in aten::reshape, so the
  // evaluator must emit an elementwise kPROD layer rather than folding to a
  // constant.
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int = prim::Constant[value=0]()
      %3 : int = prim::Constant[value=-1]()
      %4 : int = prim::Constant[value=2]()
      %size.1 : int[] = aten::size(%x.1)
      %37 : int = aten::__getitem__(%size.1, %2)
      %38 : int = aten::mul(%37, %4)
      %39 : int[] = prim::ListConstruct(%38, %3)
      %7 : Tensor = aten::reshape(%x.1, %39)
      return (%7))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});

  // Reference run through the TorchScript interpreter on a clone, so the
  // TRT run below sees an independent copy of the same input.
  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  // Bug fix: the engine run previously consumed `in` directly while `trt_in`
  // sat unused; pass the clone so both paths get identical, isolated inputs.
  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results =
      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);

  // The engine output may come back with a different (flattened) shape than
  // the JIT reference; reshape before the element-wise comparison.
  auto trt = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 commit comments

Comments
 (0)