diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 8b08a5505a..274cbe6a70 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -1,7 +1,8 @@
-#include <torch/torch.h>
+#include "c10/util/MathConstants.h"
 #include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"
 
 namespace torch_tensorrt {
 namespace core {
@@ -804,6 +805,94 @@ auto element_wise_registrations TORCHTRT_UNUSED =
 
                LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::atan2(Tensor self, Tensor other) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // Element-wise divide input Tensors, apply atan unary, apply quadrant correction
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto other = args[1].ITensorOrFreeze(ctx);
+
+               // atan(self / other)
+               auto intermediate_div = add_elementwise(
+                   ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_intermediate_div");
+               auto atan2_intermediate =
+                   ctx->net->addUnary(*intermediate_div->getOutput(0), nvinfer1::UnaryOperation::kATAN);
+
+               // Constant tensors used for quadrant correction
+               auto ZERO = tensor_to_const(ctx, torch::tensor({0.}));
+               auto ONE = tensor_to_const(ctx, torch::tensor({1.}));
+               auto TWO = tensor_to_const(ctx, torch::tensor({2.}));
+               // Using PI float for TRT compatibility, however double is preferred for PyTorch
+               auto PI = tensor_to_const(ctx, torch::tensor({c10::pi<float>}));
+
+               // Quadrant correction is only needed when (other < 0) (elementwise)
+               // In this scenario, the correction is +/- pi, depending on the sign of self (elementwise)
+
+               // Full atan2 Formula is given by:
+               // atan2(self, other) = atan(self / other) - (other < 0) * (2 * (self < 0) - 1) * pi
+
+               // Mask of (other < 0)
+               auto other_mask = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kLESS,
+                   other,
+                   ZERO,
+                   util::node_info(n) + "_less_than_zero_other_mask");
+
+               // Mask of (self < 0)
+               auto self_mask = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kLESS,
+                   self,
+                   ZERO,
+                   util::node_info(n) + "_greater_than_zero_self_mask");
+
+               // Apply 2 * x - 1 to translate mask from {0, 1} to {-1, 1}
+               self_mask = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kPROD,
+                   self_mask->getOutput(0),
+                   TWO,
+                   util::node_info(n) + "_greater_than_zero_times_two_self_mask");
+               self_mask = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kSUB,
+                   self_mask->getOutput(0),
+                   ONE,
+                   util::node_info(n) + "_greater_than_zero_normalized_self_mask");
+
+               // Multiply mask by pi
+               self_mask = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kPROD,
+                   self_mask->getOutput(0),
+                   PI,
+                   util::node_info(n) + "_greater_than_zero_times_pi_self_mask");
+
+               // Take product of masks to generate correction term
+               auto correction_term = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kPROD,
+                   other_mask->getOutput(0),
+                   self_mask->getOutput(0),
+                   util::node_info(n) + "_correction_term");
+
+               // Add correction term to atan(self/other) to obtain atan2(self, other)
+               auto atan2 = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kSUB,
+                   atan2_intermediate->getOutput(0),
+                   correction_term->getOutput(0),
+                   util::node_info(n) + "_corrected_atan2");
+
+               TORCHTRT_CHECK(atan2, "Unable to create atan2 layer from node: " << *n);
+
+               atan2->setName(util::node_info(n).c_str());
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], atan2->getOutput(0));
+
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+               return true;
              }});
 
 } // namespace
diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp
index 6b1c26bbab..712b9ade11 100644
--- a/tests/core/conversion/converters/test_element_wise.cpp
+++ b/tests/core/conversion/converters/test_element_wise.cpp
@@ -538,3 +538,72 @@ TEST(Converters, ATenRemainderWithScalarConvertsCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, ATenAtan2ConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.0 : Tensor, %x.1 : Tensor):
+        %2 : Tensor = aten::atan2(%x.0, %x.1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Resize range to [-1, 1] to span multiple quadrants
+  auto in_0 = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;
+  auto in_1 = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenAtan2ManagesPosInfCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.0 : Tensor, %x.1 : Tensor):
+        %2 : Tensor = aten::atan2(%x.0, %x.1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Expecting PI/2
+  auto in_0 = at::ones({4, 1, 7, 8}, {at::kCUDA});
+  auto in_1 = at::zeros({4, 1, 7, 8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenAtan2ManagesNegInfCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.0 : Tensor, %x.1 : Tensor):
+        %2 : Tensor = aten::atan2(%x.0, %x.1)
+        return (%2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Expecting -PI/2
+  auto in_0 = -1 * at::ones({4, 1, 7, 8}, {at::kCUDA});
+  auto in_1 = at::zeros({4, 1, 7, 8}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}