
Commit 137e849

feat: Add support for aten::where with scalar other (#1855)
1 parent 1c9b2a1 commit 137e849

File tree

2 files changed (+63, −0):
  core/conversion/converters/impl/select.cpp (+33)
  tests/core/conversion/converters/test_select.cpp (+30)

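Before the diffs, a minimal libtorch sketch (not part of the commit) of the operation this converter adds support for: aten::where whose else-value is a scalar rather than a second tensor. Wrapping the scalar in a 0-dim tensor is a convenience of this standalone example; in TorchScript, torch.where(condition, input, -1) lowers to the aten::where.ScalarOther schema handled by the converter below.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Boolean mask deciding, per element, whether to keep the value from "input"
      // or fall back to the scalar.
      auto condition = torch::randint(0, 2, {2, 3}).to(torch::kBool);
      auto input = torch::randn({2, 3});

      // Equivalent of torch.where(condition, input, -1); the scalar is passed here
      // as a 0-dim float tensor so the tensor overload of the C++ API can be used.
      auto out = torch::where(condition, input, torch::scalar_tensor(-1.0, torch::dtype(torch::kFloat)));
      std::cout << out << std::endl;
      return 0;
    }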

core/conversion/converters/impl/select.cpp (+33)
@@ -800,6 +800,39 @@ auto select_registrations TORCHTRT_UNUSED =
 
                layer->setName(util::node_info(n).c_str());
 
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
+             }})
+       .pattern(
+           {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> (Tensor)",
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              auto condition = args[0].ITensorOrFreeze(ctx);
+              auto condition_nbDims = condition->getDimensions().nbDims;
+              auto self = args[1].ITensorOrFreeze(ctx);
+              auto x_nbDims = self->getDimensions().nbDims;
+
+              // Get maximum rank of all input tensors
+              auto max_nbDims = std::max(condition_nbDims, x_nbDims);
+
+              // TensorRT requires all inputs to Select layers to have the same rank, so for each
+              // tensor input, ensure that its rank is equal to the maximum number of dimensions
+              // If not, left-pad the tensor dimension with 1s until the max rank is achieved
+              condition =
+                  addPadding(ctx, n, condition, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
+              self = addPadding(ctx, n, self, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
+
+              // Create a scalar tensor of rank max_nbDims from scalar other
+              auto scalar_value = args[2].unwrapToScalar();
+              std::vector<int64_t> dims_vec(max_nbDims, 1);
+              auto self_dtype = util::TRTDataTypeToScalarType(self->getType());
+              auto constant_tensor = torch::full(dims_vec, scalar_value, {torch::dtype(self_dtype)});
+              auto constant_itensor = converters::tensor_to_const(ctx, constant_tensor);
+
+              auto layer = ctx->net->addSelect(*condition, *self, *constant_itensor);
+              TORCHTRT_CHECK(layer, "Unable to create select layer for aten::where.ScalarOther");
+              layer->setName(util::node_info(n).c_str());
+
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
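
As a side note, here is a standalone libtorch sketch (again, not part of the commit) of the shape handling the new converter relies on: the scalar other is materialized as a tensor of rank max_nbDims whose dimensions are all 1, so element-wise selection can broadcast it against self. The shapes {1, 435, 1} and {1, 435, 11} are borrowed from the test below; the rest is illustrative only.

    #include <torch/torch.h>
    #include <cassert>
    #include <vector>

    int main() {
      const int64_t max_nbDims = 3;                    // rank after addPadding in the converter
      std::vector<int64_t> dims_vec(max_nbDims, 1);    // {1, 1, 1}
      auto constant_tensor = torch::full(dims_vec, /*scalar_value=*/-1.0, torch::dtype(torch::kFloat));

      auto condition = torch::randint(0, 2, {1, 435, 1}).to(torch::kBool);
      auto self = torch::randn({1, 435, 11});

      // torch::where broadcasts the all-ones-shaped constant against "self", mirroring
      // the broadcasting the converter expects from TensorRT's Select layer once every
      // input has been padded to the same rank.
      auto out = torch::where(condition, self, constant_tensor);
      assert(out.sizes().equals(self.sizes()));
      return 0;
    }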

tests/core/conversion/converters/test_select.cpp (+30)
@@ -1389,3 +1389,33 @@ TEST(Converters, WhereConvertsMismatchedShapesCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, WhereScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%input : Tensor,
+            %condition : Tensor):
+        %scalar : int = prim::Constant[value=-1]()
+        %10 : Tensor = aten::where(%condition, %input, %scalar)
+        return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto condition = at::randint(0, 2, {1, 435, 1}, {at::kCUDA}).to(at::kBool);
+  auto input = at::randn({1, 435, 11}, {at::kCUDA});
+
+  auto jit_condition = at::clone(condition);
+  auto jit_input = at::clone(input);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_input, jit_condition});
+
+  auto trt_condition = at::clone(condition);
+  auto trt_input = at::clone(input);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_input, trt_condition});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
+  }
+}
