File tree 2 files changed +10
-0
lines changed
2 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -120,6 +120,15 @@ c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
120
120
sizes.size ());
121
121
}
122
122
123
+ c10::SymInt XLATensorImpl::sym_numel_custom () const {
124
+ auto sym_sizes = sym_sizes_custom ();
125
+ c10::SymInt prod{1 };
126
+ for (auto s : sym_sizes) {
127
+ prod *= s;
128
+ }
129
+ return prod;
130
+ }
131
+
123
132
c10::SymIntArrayRef XLATensorImpl::sym_sizes () const {
124
133
// it isn't strictly necessary to delegate to `sym_sizes_custom`
125
134
// however, it's consistent with pytorch core
Original file line number Diff line number Diff line change @@ -35,6 +35,7 @@ class XLATensorImpl : public c10::TensorImpl {
35
35
at::IntArrayRef sizes_custom () const override ;
36
36
c10::SymIntArrayRef sym_sizes () const override ;
37
37
c10::SymIntArrayRef sym_sizes_custom () const override ;
38
+ c10::SymInt sym_numel_custom () const override ;
38
39
at::IntArrayRef strides_custom () const override ;
39
40
40
41
int64_t dim_custom () const override ;
You can’t perform that action at this time.
0 commit comments