diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index d79b0610c3a6..b968a0569a32 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -120,6 +120,15 @@ c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
                               sizes.size());
 }
 
+c10::SymInt XLATensorImpl::sym_numel_custom() const {
+  auto sym_sizes = sym_sizes_custom();
+  c10::SymInt prod{1};
+  for (auto s : sym_sizes) {
+    prod *= s;
+  }
+  return prod;
+}
+
 c10::SymIntArrayRef XLATensorImpl::sym_sizes() const {
   // it isn't strictly necessary to delegate to `sym_sizes_custom`
   // however, it's consistent with pytorch core
diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h
index d3665431a4d3..ce8be9389594 100644
--- a/torch_xla/csrc/tensor_impl.h
+++ b/torch_xla/csrc/tensor_impl.h
@@ -35,6 +35,7 @@ class XLATensorImpl : public c10::TensorImpl {
   at::IntArrayRef sizes_custom() const override;
   c10::SymIntArrayRef sym_sizes() const override;
   c10::SymIntArrayRef sym_sizes_custom() const override;
+  c10::SymInt sym_numel_custom() const override;
   at::IntArrayRef strides_custom() const override;
   int64_t dim_custom() const override;
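For context (not part of the patch): `sym_numel_custom()` is the `c10::TensorImpl` hook that PyTorch core consults when a subclass opts into custom size policies, so XLA tensors with symbolic shapes can report an element count without materializing concrete sizes. The element count is simply the product of the dimension sizes, computed here over `c10::SymInt`s; the running product starts at 1 so a 0-d (scalar) tensor, whose size list is empty, correctly reports a numel of 1. Below is a minimal standalone sketch of the same product-of-sizes logic, assuming only the public c10 headers; the free-function name `sym_numel_from_sizes` is hypothetical and exists only for illustration.

#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>

// Hypothetical free-function version of the logic added in the patch:
// numel is the product of the (possibly symbolic) dimension sizes.
c10::SymInt sym_numel_from_sizes(c10::SymIntArrayRef sym_sizes) {
  // Multiplicative identity: an empty size list (0-d tensor) yields 1.
  c10::SymInt prod{1};
  for (const auto& s : sym_sizes) {
    prod *= s;  // SymInt arithmetic stays symbolic for dynamic dims
  }
  return prod;
}

Note that the override calls `sym_sizes_custom()` directly rather than `sym_sizes()`, which matches the pattern already in this file: as the existing comment in `sym_sizes()` notes, that method itself just forwards to `sym_sizes_custom()` for consistency with PyTorch core.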