diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 97c07c7741aa..24b7fc5dcb81 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -84,10 +84,6 @@ torch::lazy::BackendDevice GetXlaDeviceOrCurrent(
   return xla_device_opt ? *xla_device_opt : GetCurrentDevice();
 }
 
-at::ScalarType GetScalarTypeOrFloat(c10::optional<at::ScalarType> scalar_type) {
-  return scalar_type ? *scalar_type : at::ScalarType::Float;
-}
-
 bool IsOperationOnType(const c10::optional<at::ScalarType>& opt_dtype,
                        at::ScalarType tensor_type, at::ScalarType type) {
   if (opt_dtype && *opt_dtype == type) {
@@ -1310,7 +1306,7 @@ at::Tensor XLANativeFunctions::empty(
   // s_copy_().
   return bridge::AtenFromXlaTensor(XLATensor::full(
       XlaHelpers::I64List(size), 0, GetXlaDeviceOrCurrent(device),
-      GetScalarTypeOrFloat(dtype)));
+      at::dtype_or_default(dtype)));
 }
 
 at::Tensor XLANativeFunctions::empty_symint(
@@ -1719,7 +1715,7 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start,
   }
 
   return bridge::AtenFromXlaTensor(
-      XLATensor::linspace(start, end, steps, GetScalarTypeOrFloat(dtype),
+      XLATensor::linspace(start, end, steps, at::dtype_or_default(dtype),
                           GetXlaDeviceOrCurrent(device)));
 }
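
For context, a minimal sketch of the behavioral difference this diff introduces (the function name `ShowDtypeFallback` below is mine, for illustration only, not part of the change): the removed `GetScalarTypeOrFloat` helper always fell back to `at::ScalarType::Float`, whereas `at::dtype_or_default` falls back to the process-wide default dtype, which Python code can change via `torch.set_default_dtype`.

```cpp
#include <ATen/ATen.h>

void ShowDtypeFallback() {
  c10::optional<at::ScalarType> unset;  // caller passed no dtype

  // Old behavior (removed helper): missing dtype always resolves to Float.
  at::ScalarType old_style = unset ? *unset : at::ScalarType::Float;

  // New behavior: missing dtype resolves to the current global default dtype
  // (Float unless overridden, e.g. torch.set_default_dtype(torch.float64)).
  at::ScalarType new_style = at::dtype_or_default(unset);

  TORCH_CHECK(old_style == at::ScalarType::Float);
  (void)new_style;
}
```

Assuming that is the intent of the change, factory ops such as `empty` and `linspace` should now honor the globally configured default dtype when no explicit dtype is supplied, instead of hard-coding Float.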