From 403d112f3dba87ae6d82dc6edb4fa0b9cd3a821e Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 4 Aug 2022 16:46:47 -0700 Subject: [PATCH] Revert "Revert "Revert of #3822 "Revert "implementing sym_numel_custom"" (#3826)" (#3835)" This reverts commit 3af087374c2691ec753fcacd0de4d6c6013d773a. --- torch_xla/csrc/tensor_impl.cpp | 9 +++++++++ torch_xla/csrc/tensor_impl.h | 1 + 2 files changed, 10 insertions(+) diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index d79b0610c3a6..b968a0569a32 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -120,6 +120,15 @@ c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const { sizes.size()); } +c10::SymInt XLATensorImpl::sym_numel_custom() const { + auto sym_sizes = sym_sizes_custom(); + c10::SymInt prod{1}; + for (auto s : sym_sizes) { + prod *= s; + } + return prod; +} + c10::SymIntArrayRef XLATensorImpl::sym_sizes() const { // it isn't strictly necessary to delegate to `sym_sizes_custom` // however, it's consistent with pytorch core diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h index d3665431a4d3..ce8be9389594 100644 --- a/torch_xla/csrc/tensor_impl.h +++ b/torch_xla/csrc/tensor_impl.h @@ -35,6 +35,7 @@ class XLATensorImpl : public c10::TensorImpl { at::IntArrayRef sizes_custom() const override; c10::SymIntArrayRef sym_sizes() const override; c10::SymIntArrayRef sym_sizes_custom() const override; + c10::SymInt sym_numel_custom() const override; at::IntArrayRef strides_custom() const override; int64_t dim_custom() const override;