Skip to content

Commit 774df50

Browse files
authored
Revert "Revert "Revert of #3822 "Revert "implementing sym_numel_custom"" (#3826)" (#3835)" (#3837)
This reverts commit 3af0873.
1 parent 3af0873 commit 774df50

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

torch_xla/csrc/tensor_impl.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
120120
sizes.size());
121121
}
122122

123+
c10::SymInt XLATensorImpl::sym_numel_custom() const {
124+
auto sym_sizes = sym_sizes_custom();
125+
c10::SymInt prod{1};
126+
for (auto s : sym_sizes) {
127+
prod *= s;
128+
}
129+
return prod;
130+
}
131+
123132
c10::SymIntArrayRef XLATensorImpl::sym_sizes() const {
124133
// it isn't strictly necessary to delegate to `sym_sizes_custom`
125134
// however, it's consistent with pytorch core

torch_xla/csrc/tensor_impl.h

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class XLATensorImpl : public c10::TensorImpl {
3535
at::IntArrayRef sizes_custom() const override;
3636
c10::SymIntArrayRef sym_sizes() const override;
3737
c10::SymIntArrayRef sym_sizes_custom() const override;
38+
c10::SymInt sym_numel_custom() const override;
3839
at::IntArrayRef strides_custom() const override;
3940

4041
int64_t dim_custom() const override;

0 commit comments

Comments
 (0)