
Commit a35fbf1

inocsin authored and narendasan committed

feat: support true_divide, floor_divide, max, min, rsub

Signed-off-by: inocsin <[email protected]>

1 parent 4d3ac4f · commit a35fbf1

File tree

3 files changed: +180 −3 lines

Diff for: core/conversion/converters/impl/element_wise.cpp (+112)
@@ -200,6 +200,61 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
+        .pattern({"aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto scaleW = Weights(ctx, scalar);
+                      auto unuse = Weights();
+                      // IScaleLayer requires shift, scale and power to have the same dtype
+                      auto scaleLayer = ctx->net->addScale(
+                          *self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
+                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto scaleW = Weights(ctx, scalar);
+                      auto unuse = Weights();
+                      // IScaleLayer requires shift, scale and power to have the same dtype
+                      auto scaleLayer = ctx->net->addScale(
+                          *self, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
+                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
         .pattern({"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     // Should implement self / other
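
Both rsub patterns implement other - alpha * self: when alpha != 1, an IScaleLayer first folds alpha into self, and the operands are then passed to add_elementwise in reversed order so kSUB computes other - self. A standalone libtorch snippet (my illustration, not part of this commit) confirming the expected semantics of the scalar overload:

// Sanity check of aten::rsub semantics, independent of the converter:
// rsub(self, other, alpha) == other - alpha * self.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::tensor({1.0f, 2.0f, 3.0f});
  // Scalar overload, mirroring the aten::rsub.Scalar pattern above.
  auto out = torch::rsub(self, /*other=*/10.0f, /*alpha=*/2.0f);
  std::cout << out << std::endl;  // 8, 6, 4
  return 0;
}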
@@ -352,6 +407,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     pow->setName(util::node_info(n).c_str());
                     auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto floor_divide =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto floor_divide =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::max.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto max =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, other, util::node_info(n));
+                    TRTORCH_CHECK(max, "Unable to create max layer from node: " << *n);
+
+                    max->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], max->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::min.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto min =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, self, other, util::node_info(n));
+                    TRTORCH_CHECK(min, "Unable to create min layer from node: " << *n);
+
+                    min->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], min->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }});
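
TensorRT's kFLOOR_DIV performs floor division, rounding the quotient toward negative infinity, which differs from truncation for negative operands. A standalone illustration (not from this commit; it assumes a recent libtorch in which floor_divide genuinely floors):

// Floor division rounds toward negative infinity: 3.5 -> 3, -3.5 -> -4.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::tensor({7.0f, -7.0f});
  auto other = torch::tensor({2.0f});  // broadcasts over self
  std::cout << torch::floor_divide(self, other) << std::endl;  // 3, -4
  return 0;
}

The max and min converters reduce directly to TensorRT's kMAX/kMIN elementwise ops; add_elementwise handles broadcasting between mismatched ranks, which is what the new tests below exercise with shapes like {3, 4} against {4}. A quick sketch of the expected broadcasting behavior (again my illustration, not commit code):

// Elementwise max/min with broadcasting, mirroring aten::max.other / aten::min.other.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::tensor({{1.0f, 5.0f}, {3.0f, 2.0f}});  // shape {2, 2}
  auto b = torch::tensor({2.0f, 4.0f});                  // shape {2}, broadcast over rows
  std::cout << torch::max(a, b) << std::endl;  // {{2, 5}, {3, 4}}
  std::cout << torch::min(a, b) << std::endl;  // {{1, 4}, {2, 2}}
  return 0;
}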

Diff for: core/lowering/passes/BUILD (+1 −2)
@@ -40,5 +40,4 @@ pkg_tar(
     name = "include",
     package_dir = "core/lowering/passes/",
     srcs = ["passes.h"],
-)
-
+)

Diff for: tests/core/conversion/converters/test_element_wise.cpp (+67 −1)
@@ -161,5 +161,71 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
         %3 : Tensor = aten::ne(%x.1, %2)
         return (%3))IR";
   pointwise_test_helper(graph, true, false, {3, 4, 2});
-  ;
+  pointwise_test_helper(graph, true);
 }
+
+TEST(Converters, ATenFloorDivideConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::floor_divide(%0, %1)
+        return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %scalar : float = prim::Constant[value=2.4]()
+        %1 : Tensor = aten::floor_divide(%0, %scalar)
+        return (%1))IR";
+  pointwise_test_helper(graph, true);
+}
+
+TEST(Converters, ATenMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::max(%0, %1)
+        return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenMinConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : Tensor = aten::min(%0, %1)
+        return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=2]()
+        %3 : Tensor = aten::rsub(%0, %1, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
+}
+
+TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %2 : int = prim::Constant[value=2]()
+        %scalar : float = prim::Constant[value=2.4]()
+        %3 : Tensor = aten::rsub(%0, %scalar, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
+}
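
For reference, pointwise_test_helper is defined earlier in this test file and is not shown in the diff. Inferring from the call sites above, its signature is approximately the following, where the boolean flags select single- vs. two-input graphs and static vs. dynamic shapes, and the trailing shapes drive the broadcasting cases; this is an inference from usage, not the verbatim declaration:

// Approximate signature, inferred from the call sites above; the real
// helper lives near the top of test_element_wise.cpp and is not part
// of this diff.
void pointwise_test_helper(
    std::string graph_ir,
    bool singleInput,                    // graph takes one tensor instead of two
    bool dynamicInput = false,           // compile the engine with dynamic shapes
    std::vector<int64_t> shape1 = {5},   // shape of the first input
    std::vector<int64_t> shape2 = {5});  // shape of the second input (if any)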
