Skip to content

Commit 7b6747d

Browse files
authored
Respect default dtype (#3798)
* Respect default dtype Prior to the linked PR, the python API never actually passed nullopt and so the `OrFloat` path was never tested. * Trigger CI * Remove .torch_pin
1 parent 774df50 commit 7b6747d

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

torch_xla/csrc/aten_xla_type.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ torch::lazy::BackendDevice GetXlaDeviceOrCurrent(
8484
return xla_device_opt ? *xla_device_opt : GetCurrentDevice();
8585
}
8686

87-
at::ScalarType GetScalarTypeOrFloat(c10::optional<at::ScalarType> scalar_type) {
88-
return scalar_type ? *scalar_type : at::ScalarType::Float;
89-
}
90-
9187
bool IsOperationOnType(const c10::optional<at::ScalarType>& opt_dtype,
9288
at::ScalarType tensor_type, at::ScalarType type) {
9389
if (opt_dtype && *opt_dtype == type) {
@@ -1286,7 +1282,7 @@ at::Tensor XLANativeFunctions::empty(
12861282
// s_copy_().
12871283
return bridge::AtenFromXlaTensor(XLATensor::full(
12881284
XlaHelpers::I64List(size), 0, GetXlaDeviceOrCurrent(device),
1289-
GetScalarTypeOrFloat(dtype)));
1285+
at::dtype_or_default(dtype)));
12901286
}
12911287

12921288
at::Tensor XLANativeFunctions::empty_symint(
@@ -1689,7 +1685,7 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start,
16891685
}
16901686

16911687
return bridge::AtenFromXlaTensor(
1692-
XLATensor::linspace(start, end, steps, GetScalarTypeOrFloat(dtype),
1688+
XLATensor::linspace(start, end, steps, at::dtype_or_default(dtype),
16931689
GetXlaDeviceOrCurrent(device)));
16941690
}
16951691

0 commit comments

Comments
 (0)