Skip to content

Commit 15f1239

Browse files
authored
Revert "Revert "Support torch.mean with dim=None (#3752)" (#3757)" (#3782)
* Revert "Revert "Support torch.mean with dim=None (#3752)" (#3757)" This reverts commit ce1bd4e. * Add torch pin * Delete .torch_pin
1 parent 787b846 commit 15f1239

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

torch_xla/csrc/aten_xla_type.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -1975,13 +1975,16 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
19751975
/*keep_reduced_dimensions=*/false, dtype));
19761976
}
19771977

1978-
at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::IntArrayRef dim,
1979-
bool keepdim,
1978+
at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
1979+
at::OptionalIntArrayRef dim, bool keepdim,
19801980
c10::optional<at::ScalarType> dtype) {
19811981
XLA_FN_COUNTER("xla::");
1982+
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
19821983
return bridge::AtenFromXlaTensor(XLATensor::mean(
1983-
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(dim),
1984-
/*keep_reduced_dimensions=*/keepdim, dtype));
1984+
self_tensor,
1985+
dim ? torch::lazy::ToVector<int64_t>(*dim)
1986+
: torch::lazy::Iota<int64_t>(self_tensor->shape().get().rank()),
1987+
keepdim, dtype));
19851988
}
19861989

19871990
at::Tensor XLANativeFunctions::min(const at::Tensor& self) {

0 commit comments

Comments
 (0)