Skip to content

Commit 205ae57

Browse files
authored
Add regression test for torch.expand (#3785)
1 parent d1d693b commit 205ae57

File tree

4 files changed

+51
-3
lines changed

4 files changed

+51
-3
lines changed

.circleci/common.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ function install_deps_pytorch_xla() {
5858
pip install hypothesis
5959
pip install cloud-tpu-client
6060
pip install absl-py
61-
pip install --upgrade numpy>=1.18.5
61+
pip install --upgrade "numpy>=1.18.5"
6262
pip install --upgrade numba
6363

6464
# Using the Ninja generator requires CMake version 3.13 or greater
65-
pip install cmake>=3.13 --upgrade
65+
pip install "cmake>=3.13" --upgrade
6666

6767
sudo apt-get -qq update
6868

.circleci/docker/install_conda.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function install_and_setup_conda() {
3939
/usr/bin/yes | pip install cloud-tpu-client
4040
/usr/bin/yes | pip install expecttest==0.1.3
4141
/usr/bin/yes | pip install ninja # Install ninja to speedup the build
42-
/usr/bin/yes | pip install cmake>=3.13 --upgrade # Using Ninja requires CMake>=3.13
42+
/usr/bin/yes | pip install "cmake>=3.13" --upgrade # Using Ninja requires CMake>=3.13
4343
/usr/bin/yes | pip install absl-py
4444
# Additional PyTorch requirements
4545
/usr/bin/yes | pip install scikit-image scipy==1.1.0 # >1.1.0 breaks PyTorch tests

test/run_tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ function run_op_tests {
111111
run_xla_ir_debug python3 "$CDIR/test_env_var_mapper.py"
112112
run_pjrt python3 "$CDIR/pjrt/test_experimental_pjrt.py"
113113
run_pjrt python3 "$CDIR/pjrt/test_experimental_tpu.py"
114+
run_test python3 "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
114115
}
115116

116117
function run_mp_op_tests {

test/test_operations_hlo.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Parse local options first, and rewrite the sys.argv[].
2+
# We need to do that before import "common", as otherwise we get an error for
3+
# unrecognized arguments.
4+
import argparse
5+
import sys
6+
7+
parser = argparse.ArgumentParser(add_help=False)
8+
parser.add_argument('--replicated', action='store_true')
9+
parser.add_argument('--long_test', action='store_true')
10+
parser.add_argument('--max_diff_count', type=int, default=25)
11+
parser.add_argument('--verbosity', type=int, default=0)
12+
FLAGS, leftovers = parser.parse_known_args()
13+
sys.argv = [sys.argv[0]] + leftovers
14+
15+
# Normal imports section starts here.
16+
import torch
17+
import torch_xla
18+
import torch_xla.utils.utils as xu
19+
import torch_xla.core.xla_model as xm
20+
import torch_xla.debug.metrics as met
21+
import unittest
22+
23+
24+
class TestOperationsHlo(unittest.TestCase):
25+
26+
def setUp(self):
27+
super(TestOperationsHlo, self).setUp()
28+
29+
def tearDown(self):
30+
super(TestOperationsHlo, self).tearDown()
31+
32+
def test_expand(self):
33+
a = torch.rand(1, 5, device=xm.xla_device())
34+
b = a.expand(5, 5)
35+
hlo_text = torch_xla._XLAC._get_xla_tensors_text([b])
36+
assert 'aten::expand' in hlo_text
37+
38+
39+
if __name__ == '__main__':
40+
torch.set_default_tensor_type('torch.FloatTensor')
41+
torch.manual_seed(42)
42+
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
43+
use_full_mat_mul_precision=True)
44+
test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
45+
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
46+
print(met.metrics_report())
47+
sys.exit(0 if test.result.wasSuccessful() else 1)

0 commit comments

Comments
 (0)