Commit 41c9913

9080 expose mat mul precision (#9081)
1 parent 24ab7c2 commit 41c9913

7 files changed, +259 -0 lines changed

test/test_mat_mul_precision.py

Lines changed: 100 additions & 0 deletions (new file)

"""Numeric tests for default precision of mat mul."""

import unittest

import torch
import torch_xla
import torch_xla.backends

import test_utils


class TestMatMulPrecision(unittest.TestCase):

  def _make_input(self):
    eye = torch.eye(1024, device='cpu', dtype=torch.float64)
    rand_ = torch.testing.make_tensor((1024, 1024),
                                      dtype=torch.float64,
                                      device="cpu",
                                      low=0.99,
                                      high=1.01)
    return eye * rand_

  # TODO: Figure out why either PT/XLA or unittest
  # is unable to successfully run this test in a parameterized way.
  # https://github.com/pytorch/xla/issues/9129
  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  @unittest.expectedFailure
  def test_all(self):
    # The number of bits of precise mantissa expected in the result.
    parameters = [
        ('highest', 22),
        ('high', 14),
        ('default', 8),
    ]
    # Although pytest has a slightly more elegant parameterized testing function,
    # all TPU tests use unittest.
    for i, (precision, bits) in enumerate(parameters):
      with self.subTest(precision=precision, bits=bits):
        self._test_parameterized(precision, bits)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_highest(self):
    self._test_parameterized('highest', 22)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_high(self):
    self._test_parameterized('high', 14)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_default(self):
    self._test_parameterized('default', 8)

  # DO NOT add epsilons to this test. These tests must be numerically exact.
  def _test_parameterized(self, precision, bits):
    # Arrange
    torch_xla.backends.set_mat_mul_precision(precision)

    # Diagonal matrices force mat mul through the MXU
    # but require only one non-zero accumulation.
    x = self._make_input()
    y = self._make_input()
    reference_float64 = torch.matmul(x, y)

    # TODO: Justify this logic. Why is it not
    # 1 - ((2**8 - 1) / 2**8)**2 (the equation stated by a TPU expert)?
    widest_atol = torch.tensor(
        -1 + ((2**bits + 1) / 2**bits)**2, dtype=torch.float64)

    narrowest_atol = widest_atol / 4.0

    x = x.to(torch.float32).to('xla')
    y = y.to(torch.float32).to('xla')

    # Act
    actual = torch.matmul(x, y).to('cpu').to(torch.float64)

    # Disable rtol; we know exactly the atol for default, high, and highest.
    torch.testing.assert_close(
        actual,
        reference_float64,
        rtol=0.0,
        atol=widest_atol,
    )

    with self.assertRaises(AssertionError):
      torch.testing.assert_close(
          actual,
          reference_float64,
          rtol=0.0,
          atol=narrowest_atol,
      )

    assert not torch.equal(actual, reference_float64), (
        "Actual product and reference product should not be exactly equal, "
        f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
    )


# There is no main function. This is designed to be run from
# python -m unittest ...
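
For reference, the tolerance expression used by _test_parameterized can be evaluated directly. The following is an illustrative sketch, not part of the commit; it prints the widest and narrowest atol the test uses for each precision level:

# Illustrative sketch (not part of this commit): evaluate the tolerances
# used by _test_parameterized for each precision level.
# (1 + 2**-bits)**2 - 1 == 2**(1 - bits) + 2**(-2 * bits), i.e. roughly
# double the unit in the last place of a `bits`-bit mantissa.
for precision, bits in [('highest', 22), ('high', 14), ('default', 8)]:
  widest_atol = -1 + ((2**bits + 1) / 2**bits)**2
  print(f"{precision:>7}: {bits} bits -> widest_atol = {widest_atol:.3e}, "
        f"narrowest_atol = {widest_atol / 4.0:.3e}")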
test/test_mat_mul_precision_get_and_set.py

Lines changed: 62 additions & 0 deletions (new file)

"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

import sys
import unittest

import torch
import torch_xla
import torch_xla.backends


class TestMatMulPrecisionGetAndSet(unittest.TestCase):

  def setUp(self):
    self._original = torch_xla.backends.get_mat_mul_precision()
    torch.set_printoptions(precision=20)
    torch_xla.sync()

  def tearDown(self):
    torch_xla.backends.set_mat_mul_precision(self._original)
    torch.set_printoptions(profile="default")
    torch_xla.sync()

  def test_set_mat_mul_precision_error(self):
    # Assert
    with self.assertRaises(ValueError):
      # Act
      torch_xla.backends.set_mat_mul_precision('BAD VALUE')

  def test_get_and_set_mat_mul_precision_default(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('default')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'default')

  def test_get_and_set_mat_mul_precision_high(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('high')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'high')

  def test_get_and_set_mat_mul_precision_highest(self):
    # Arrange
    torch_xla.backends.set_mat_mul_precision('highest')

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, 'highest')


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_utils.py

Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
+import torch_xla.runtime as xr


 def _set_rng_seed(seed):
@@ -420,3 +421,8 @@ def temporary_env(**kwargs):
     else:
       # Restore the original value
       os.environ[key] = old_value
+
+
+# Taken from test_operations.py
+def is_on_tpu():
+  return 'XRT_TPU_CONFIG' in os.environ or xr.device_type() == 'TPU'

test/tpu/run_tests.sh

Lines changed: 5 additions & 0 deletions

@@ -6,6 +6,11 @@ TEST_CDIR="$(dirname "$CDIR")"
 source "${TEST_CDIR}/utils/run_tests_utils.sh"

 # TODO: merge with other run_tests
+(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_high)
+(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_default)
+(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_highest)
+(cd $TEST_CDIR && python3 -m unittest test_mat_mul_precision.TestMatMulPrecision.test_all)
+python3 "$TEST_CDIR/test_mat_mul_precision_get_and_set.py"
 python3 "$TEST_CDIR/test_operations.py" -v
 python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
 python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"

torch_xla/backends/__init__.py

Lines changed: 78 additions & 0 deletions (new file)

"""torch_xla.backends controls the behavior of the XLA backend.

This subpackage parallels the torch.backends.{cuda, cpu, mps, etc.}
subpackages in PyTorch.
"""

# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py
# for an example of how backends are implemented in PyTorch
# in the __init__.py file, despite general style guidelines against this.

# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorch/XLA.
from typing import Final, Literal, TypeAlias

import torch_xla

__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

# Valid values for get_mat_mul_precision/set_mat_mul_precision
# Note: it is idiomatic in PyTorch to use strings rather than enums.
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9
_DEFAULT: Final = "default"
_HIGH: Final = "high"
_HIGHEST: Final = "highest"

# Using variables with a Final type hint instead of literals is valid.
_PrecisionType: TypeAlias = Literal[
    _DEFAULT, _HIGH, _HIGHEST]  # pyright: ignore[reportInvalidTypeForm]


# Some of this description is adapted from the JAX documentation.
# TODO: Once the numerics tutorial is released, link to it from this docstring.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
  """Controls the default matmul and convolution precision for 32-bit inputs.

  Some platforms, like TPU, offer configurable precision levels for
  matrix multiplication and convolution computations,
  trading off accuracy for speed.

  This option controls the default precision level for
  computations involved in matrix multiplication and convolution on
  32-bit inputs. The levels describe the precision at
  which scalar products are computed.

  On a TPU:
  * `default` is the fastest and least precise,
    downcasting an FP32 to BF16 before multiplying.

  * `high` takes three passes and generates approximately 14 bits of
    precision.

  * `highest` is the most precise, and the slowest. It takes six
    passes and generates approximately 22 bits of precision.

  Args:
    precision (str): The precision to set for matrix multiplication.
      Must be one of 'default', 'high', or 'highest'.
  """
  if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
    raise ValueError(f"Invalid precision: {precision}. "
                     f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

  torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
  """Gets the current matmul precision for 32-bit inputs.

  Returns:
    str: The current precision setting for matrix multiplication,
      one of 'default', 'high', or 'highest'.
  """
  precision = torch_xla._XLAC._xla_get_mat_mul_precision()
  assert precision in [_DEFAULT, _HIGH, _HIGHEST
                      ], (f"Invalid precision: {precision}. "
                          f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")
  return precision
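
As a usage sketch (assuming a TPU-backed 'xla' device, as in the tests above), the new public API can be exercised like this; the save/restore pattern mirrors the setUp/tearDown in test_mat_mul_precision_get_and_set.py:

import torch
import torch_xla
import torch_xla.backends

# Save, set, and restore the precision around a matmul on the XLA device.
previous = torch_xla.backends.get_mat_mul_precision()
torch_xla.backends.set_mat_mul_precision('high')
try:
  x = torch.randn(1024, 1024).to('xla')
  y = torch.randn(1024, 1024).to('xla')
  product = torch.matmul(x, y)  # computed with ~14 bits of mantissa on TPU
finally:
  torch_xla.backends.set_mat_mul_precision(previous)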

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 4 additions & 0 deletions

@@ -2124,6 +2124,10 @@ void InitXlaModuleBindings(py::module m) {
         ConsumeValue(xla::StringToPrecision(mat_mul_precision));
     XlaHelpers::set_mat_mul_precision(precision);
   });
+  m.def("_xla_get_mat_mul_precision", []() {
+    xla::PrecisionConfig::Precision precision = XlaHelpers::mat_mul_precision();
+    return xla::PrecisionToString(precision);
+  });

   py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
   py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
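
For illustration, the public helpers in torch_xla/backends round-trip through these private bindings. A minimal sketch of the equivalence (assuming torch_xla is importable):

import torch_xla
import torch_xla.backends

# The public getter wraps the new private binding, so the two agree.
torch_xla.backends.set_mat_mul_precision('highest')
assert torch_xla._XLAC._xla_get_mat_mul_precision() == 'highest'
assert torch_xla.backends.get_mat_mul_precision() == 'highest'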

torch_xla/csrc/xla_op_builder.cpp

Lines changed: 4 additions & 0 deletions

@@ -208,6 +208,10 @@ xla::PrecisionConfig DotPrecisonConfig(py::dict args) {
       precision = xla::PrecisionConfig::HIGH;
     } else if (*arg_precision_config == "highest") {
       precision = xla::PrecisionConfig::HIGHEST;
+    } else {
+      XLA_ERROR() << "Invalid precision config in args: "
+                  << *arg_precision_config
+                  << " (valid values: default, high, highest)";
     }
   }
   return XlaHelpers::BuildPrecisionConfig(precision);
