|
1 |
| -#include <torch/torch.h> |
| 1 | +#include "c10/util/MathConstants.h" |
2 | 2 | #include "core/conversion/converters/converter_util.h"
|
3 | 3 | #include "core/conversion/converters/converters.h"
|
4 | 4 | #include "core/util/prelude.h"
|
| 5 | +#include "torch/torch.h" |
5 | 6 |
|
6 | 7 | namespace torch_tensorrt {
|
7 | 8 | namespace core {
|
@@ -804,6 +805,94 @@ auto element_wise_registrations TORCHTRT_UNUSED =
|
804 | 805 |
|
805 | 806 | LOG_DEBUG("Output tensor shape: " << out->getDimensions());
|
806 | 807 | return true;
|
| 808 | + }}) |
| 809 | + .pattern( |
| 810 | + {"aten::atan2(Tensor self, Tensor other) -> (Tensor)", |
| 811 | + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { |
| 812 | + // Element-wise divide input Tensors, apply atan unary, apply quadrant correction |
| 813 | + auto self = args[0].ITensorOrFreeze(ctx); |
| 814 | + auto other = args[1].ITensorOrFreeze(ctx); |
| 815 | + |
| 816 | + // atan(self / other) |
| 817 | + auto intermediate_div = add_elementwise( |
| 818 | + ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_intermediate_div"); |
| 819 | + auto atan2_intermediate = |
| 820 | + ctx->net->addUnary(*intermediate_div->getOutput(0), nvinfer1::UnaryOperation::kATAN); |
| 821 | + |
| 822 | + // Constant tensors used for quadrant correction |
| 823 | + auto ZERO = tensor_to_const(ctx, torch::tensor({0.})); |
| 824 | + auto ONE = tensor_to_const(ctx, torch::tensor({1.})); |
| 825 | + auto TWO = tensor_to_const(ctx, torch::tensor({2.})); |
| 826 | + // Using PI float for TRT compatibility, however double is preferred for PyTorch |
| 827 | + auto PI = tensor_to_const(ctx, torch::tensor({c10::pi<float>})); |
| 828 | + |
| 829 | + // Quadrant correction is only needed when (other < 0) (elementwise) |
| 830 | + // In this scenario, the correction is +/- pi, depending on the sign of self (elementwise) |
| 831 | + |
| 832 | + // Full atan2 Formula is given by: |
| 833 | + // atan2(self, other) = atan(self / other) - (other < 0) * (2 * (self < 0) - 1) * pi |
| 834 | + |
| 835 | + // Mask of (other < 0) |
| 836 | + auto other_mask = add_elementwise( |
| 837 | + ctx, |
| 838 | + nvinfer1::ElementWiseOperation::kLESS, |
| 839 | + other, |
| 840 | + ZERO, |
| 841 | + util::node_info(n) + "_less_than_zero_other_mask"); |
| 842 | + |
| 843 | + // Mask of (self < 0) |
| 844 | + auto self_mask = add_elementwise( |
| 845 | + ctx, |
| 846 | + nvinfer1::ElementWiseOperation::kLESS, |
| 847 | + self, |
| 848 | + ZERO, |
| 849 | + util::node_info(n) + "_greater_than_zero_self_mask"); |
| 850 | + |
| 851 | + // Apply 2 * x - 1 to translate mask from {0, 1} to {-1, 1} |
| 852 | + self_mask = add_elementwise( |
| 853 | + ctx, |
| 854 | + nvinfer1::ElementWiseOperation::kPROD, |
| 855 | + self_mask->getOutput(0), |
| 856 | + TWO, |
| 857 | + util::node_info(n) + "_greater_than_zero_times_two_self_mask"); |
| 858 | + self_mask = add_elementwise( |
| 859 | + ctx, |
| 860 | + nvinfer1::ElementWiseOperation::kSUB, |
| 861 | + self_mask->getOutput(0), |
| 862 | + ONE, |
| 863 | + util::node_info(n) + "_greater_than_zero_normalized_self_mask"); |
| 864 | + |
| 865 | + // Multiply mask by pi |
| 866 | + self_mask = add_elementwise( |
| 867 | + ctx, |
| 868 | + nvinfer1::ElementWiseOperation::kPROD, |
| 869 | + self_mask->getOutput(0), |
| 870 | + PI, |
| 871 | + util::node_info(n) + "_greater_than_zero_times_pi_self_mask"); |
| 872 | + |
| 873 | + // Take product of masks to generate correction term |
| 874 | + auto correction_term = add_elementwise( |
| 875 | + ctx, |
| 876 | + nvinfer1::ElementWiseOperation::kPROD, |
| 877 | + other_mask->getOutput(0), |
| 878 | + self_mask->getOutput(0), |
| 879 | + util::node_info(n) + "_correction_term"); |
| 880 | + |
| 881 | + // Add correction term to atan(self/other) to obtain atan2(self, other) |
| 882 | + auto atan2 = add_elementwise( |
| 883 | + ctx, |
| 884 | + nvinfer1::ElementWiseOperation::kSUB, |
| 885 | + atan2_intermediate->getOutput(0), |
| 886 | + correction_term->getOutput(0), |
| 887 | + util::node_info(n) + "_corrected_atan2"); |
| 888 | + |
| 889 | + TORCHTRT_CHECK(atan2, "Unable to create atan2 layer from node: " << *n); |
| 890 | + |
| 891 | + atan2->setName(util::node_info(n).c_str()); |
| 892 | + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], atan2->getOutput(0)); |
| 893 | + |
| 894 | + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); |
| 895 | + return true; |
807 | 896 | }});
|
808 | 897 |
|
809 | 898 | } // namespace
|
|
0 commit comments