Skip to content

Commit dfd8c36

Browse files
authored
Atan2 converter (#1381)
1 parent 73a13c5 commit dfd8c36

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

core/conversion/converters/impl/element_wise.cpp

+90-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
#include <torch/torch.h>
1+
#include "c10/util/MathConstants.h"
22
#include "core/conversion/converters/converter_util.h"
33
#include "core/conversion/converters/converters.h"
44
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
56

67
namespace torch_tensorrt {
78
namespace core {
@@ -804,6 +805,94 @@ auto element_wise_registrations TORCHTRT_UNUSED =
804805

805806
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
806807
return true;
808+
}})
809+
.pattern(
    {"aten::atan2(Tensor self, Tensor other) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // aten::atan2: element-wise divide the inputs, apply the atan unary,
       // then apply an element-wise quadrant correction. TensorRT only
       // exposes a single-argument atan, so the full formula used here is:
       //   atan2(self, other) = atan(self / other) - (other < 0) * (2 * (self < 0) - 1) * pi
       auto self = args[0].ITensorOrFreeze(ctx);
       auto other = args[1].ITensorOrFreeze(ctx);

       // atan(self / other) — already correct wherever (other > 0)
       auto intermediate_div = add_elementwise(
           ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_intermediate_div");
       // Check every layer before dereferencing its output: getOutput(0) on a
       // failed (null) layer is undefined behavior rather than a diagnostic.
       TORCHTRT_CHECK(intermediate_div, "Unable to create div layer from node: " << *n);
       auto atan2_intermediate =
           ctx->net->addUnary(*intermediate_div->getOutput(0), nvinfer1::UnaryOperation::kATAN);
       TORCHTRT_CHECK(atan2_intermediate, "Unable to create atan layer from node: " << *n);

       // Constant tensors used for quadrant correction
       auto ZERO = tensor_to_const(ctx, torch::tensor({0.}));
       auto ONE = tensor_to_const(ctx, torch::tensor({1.}));
       auto TWO = tensor_to_const(ctx, torch::tensor({2.}));
       // Using PI float for TRT compatibility, however double is preferred for PyTorch
       auto PI = tensor_to_const(ctx, torch::tensor({c10::pi<float>}));

       // Quadrant correction is only needed when (other < 0) (elementwise)
       // In this scenario, the correction is +/- pi, depending on the sign of self (elementwise)

       // Mask of (other < 0)
       auto other_mask = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kLESS,
           other,
           ZERO,
           util::node_info(n) + "_less_than_zero_other_mask");
       TORCHTRT_CHECK(other_mask, "Unable to create less layer from node: " << *n);

       // Mask of (self < 0)
       // NOTE(review): the "_greater_than_zero_*" layer names below are
       // historical misnomers — the comparison is kLESS. Names are kept
       // unchanged so debug output stays stable.
       auto self_mask = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kLESS,
           self,
           ZERO,
           util::node_info(n) + "_greater_than_zero_self_mask");
       TORCHTRT_CHECK(self_mask, "Unable to create less layer from node: " << *n);

       // Apply 2 * x - 1 to translate mask from {0, 1} to {-1, 1}
       self_mask = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kPROD,
           self_mask->getOutput(0),
           TWO,
           util::node_info(n) + "_greater_than_zero_times_two_self_mask");
       TORCHTRT_CHECK(self_mask, "Unable to create prod layer from node: " << *n);
       self_mask = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kSUB,
           self_mask->getOutput(0),
           ONE,
           util::node_info(n) + "_greater_than_zero_normalized_self_mask");
       TORCHTRT_CHECK(self_mask, "Unable to create sub layer from node: " << *n);

       // Multiply mask by pi
       self_mask = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kPROD,
           self_mask->getOutput(0),
           PI,
           util::node_info(n) + "_greater_than_zero_times_pi_self_mask");
       TORCHTRT_CHECK(self_mask, "Unable to create prod layer from node: " << *n);

       // Take product of masks to generate correction term
       auto correction_term = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kPROD,
           other_mask->getOutput(0),
           self_mask->getOutput(0),
           util::node_info(n) + "_correction_term");
       TORCHTRT_CHECK(correction_term, "Unable to create prod layer from node: " << *n);

       // Subtract correction term from atan(self/other) to obtain atan2(self, other)
       auto atan2 = add_elementwise(
           ctx,
           nvinfer1::ElementWiseOperation::kSUB,
           atan2_intermediate->getOutput(0),
           correction_term->getOutput(0),
           util::node_info(n) + "_corrected_atan2");

       TORCHTRT_CHECK(atan2, "Unable to create atan2 layer from node: " << *n);

       atan2->setName(util::node_info(n).c_str());
       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], atan2->getOutput(0));

       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
       return true;
     }});
808897

809898
} // namespace

tests/core/conversion/converters/test_element_wise.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,72 @@ TEST(Converters, ATenRemainderWithScalarConvertsCorrectly) {
538538

539539
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
540540
}
541+
542+
TEST(Converters, ATenAtan2ConvertsCorrectly) {
  // Compare TRT-converted aten::atan2 against the JIT interpreter on random
  // inputs spanning all four quadrants.
  const auto graph_ir = R"IR(
      graph(%x.0 : Tensor, %x.1 : Tensor):
        %2 : Tensor = aten::atan2(%x.0, %x.1)
        return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // Resize range to [-1, 1] to span multiple quadrants
  auto lhs = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;
  auto rhs = -2 * at::rand({2, 3, 5, 5}, {at::kCUDA}) + 1;

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(g, params, {lhs, rhs});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(g, params, {lhs, rhs});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}
564+
565+
TEST(Converters, ATenAtan2ManagesPosInfCorrectly) {
  // self > 0, other == 0: the division inside the converter produces +inf,
  // and atan(+inf) should yield PI/2 with no quadrant correction.
  const auto graph_ir = R"IR(
      graph(%x.0 : Tensor, %x.1 : Tensor):
        %2 : Tensor = aten::atan2(%x.0, %x.1)
        return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // Expecting PI/2
  auto numerator = at::ones({4, 1, 7, 8}, {at::kCUDA});
  auto denominator = at::zeros({4, 1, 7, 8}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(g, params, {numerator, denominator});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(g, params, {numerator, denominator});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}
587+
588+
TEST(Converters, ATenAtan2ManagesNegInfCorrectly) {
  // self < 0, other == 0: the division inside the converter produces -inf,
  // and atan(-inf) should yield -PI/2 with no quadrant correction.
  const auto graph_ir = R"IR(
      graph(%x.0 : Tensor, %x.1 : Tensor):
        %2 : Tensor = aten::atan2(%x.0, %x.1)
        return (%2))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // Expecting -PI/2
  auto numerator = -1 * at::ones({4, 1, 7, 8}, {at::kCUDA});
  auto denominator = at::zeros({4, 1, 7, 8}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(g, params, {numerator, denominator});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(g, params, {numerator, denominator});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

0 commit comments

Comments
 (0)