Skip to content

Commit c5cc6e3

Browse files
feat: Add converter support for logical_and (#1856)
1 parent 137e849 commit c5cc6e3

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

core/conversion/converters/impl/element_wise.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,29 @@ auto element_wise_registrations TORCHTRT_UNUSED =
810810
return true;
811811
}})
812812
.pattern(
813+
{"aten::logical_and(Tensor self, Tensor other) -> (Tensor)",
814+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
815+
// torch.logical_and autocasts inputs to bool
816+
auto input_as_bool = [&](int idx) {
817+
auto x = args[idx].ITensorOrFreeze(ctx);
818+
if (x->getType() != nvinfer1::DataType::kBOOL) {
819+
x = castITensor(
820+
ctx, x, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_bool_" + str(idx)).c_str());
821+
}
822+
return x;
823+
};
824+
auto self = input_as_bool(0);
825+
auto other = input_as_bool(1);
826+
827+
auto and_layer =
828+
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kAND, self, other, util::node_info(n) + "_and");
829+
TORCHTRT_CHECK(and_layer, "Unable to create and layer from node: " << *n);
830+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], and_layer->getOutput(0));
831+
832+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
833+
return true;
834+
}})
835+
.pattern(
813836
{"aten::atan2(Tensor self, Tensor other) -> (Tensor)",
814837
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
815838
// Element-wise divide input Tensors, apply atan unary, apply quadrant correction

tests/core/conversion/converters/test_comparators.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,14 @@ TEST(Converters, ATenMinConvertsCorrectly) {
134134
pointwise_test_helper(graph, false, false, {4}, {3, 4});
135135
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
136136
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
137-
}
137+
}
138+
139+
TEST(Converters, ATenLogicalAndConvertsCorrectly) {
140+
const auto graph = R"IR(
141+
graph(%0 : Tensor, %1 : Tensor):
142+
%2 : Tensor = aten::logical_and(%0, %1)
143+
return (%2))IR";
144+
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kBool, at::kBool);
145+
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kBool);
146+
pointwise_test_helper(graph, false, false, {5, 5}, {5, 5}, false, at::kInt, at::kInt);
147+
}

0 commit comments

Comments
 (0)