
Fix TPU testing and collect all tests #11098

Merged: 18 commits, Jul 27, 2022

9 changes: 9 additions & 0 deletions .azure/gpu-tests.yml
@@ -116,6 +116,15 @@ jobs:
    timeoutInMinutes: "35"
    condition: eq(variables['continue'], '1')

  - bash: bash run_standalone_tasks.sh
    workingDirectory: tests/tests_pytorch
    env:
      PL_USE_MOCKED_MNIST: "1"
      PL_RUN_CUDA_TESTS: "1"
    displayName: 'Testing: PyTorch standalone tasks'
    timeoutInMinutes: "10"
    condition: eq(variables['continue'], '1')

  - bash: |
      python -m coverage report
      python -m coverage xml
10 changes: 3 additions & 7 deletions .circleci/config.yml
@@ -81,6 +81,8 @@ references:
job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet | kubectl create -f -) && \
job_name=${job_name#job.batch/}
job_name=${job_name% created}
pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}')
echo "GKE pod name: $pod_name"
echo "Waiting on kubernetes job: $job_name"
i=0 && \
# N checks spaced 30s apart = 900s total.
@@ -92,8 +94,6 @@
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \
echo "GKE pod name: $pod_name" && \
kubectl logs -f $pod_name --container=train > /tmp/full_output.txt
if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
# First portion is the test logs. Print these to Github Action stdout.
@@ -106,10 +106,6 @@
      name: Statistics
      command: |
        mv ./xx01 coverage.xml
        # TODO: add human readable report
        cat coverage.xml
        sudo pip install pycobertura
        pycobertura show coverage.xml

jobs:

@@ -119,7 +115,7 @@ jobs:
    environment:
      - XLA_VER: 1.9
      - PYTHON_VER: 3.7
      - MAX_CHECKS: 240
      - MAX_CHECKS: 1000
      - CHECK_SPEEP: 5
    steps:
      - checkout
15 changes: 5 additions & 10 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -8,7 +8,7 @@ local tputests = base.BaseTest {
  mode: 'postsubmit',
  configMaps: [],

  timeout: 1200, # 20 minutes, in seconds.
  timeout: 6000, # 100 minutes, in seconds.

  image: 'pytorchlightning/pytorch_lightning',
  imageTag: 'base-xla-py{PYTHON_VERSION}-torch{PYTORCH_VERSION}',
@@ -34,16 +34,11 @@
    pip install -e .[test]
    echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
    export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
    export PL_RUN_TPU_TESTS=1
    cd tests/tests_pytorch
    echo $PWD
    # TODO (@kaushikb11): Add device stats tests here
    coverage run --source pytorch_lightning -m pytest -v --capture=no \
      strategies/test_tpu_spawn.py \
      profilers/test_xla_profiler.py \
      accelerators/test_tpu.py \
      models/test_tpu.py \
      plugins/environments/test_xla_environment.py \
      utilities/test_xla_device_utils.py
    coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
    echo "\n||| Running standalone tests |||\n"
    bash run_standalone_tests.sh -b 1
    test_exit_code=$?
    echo "\n||| END PYTEST LOGS |||\n"
    coverage xml
@@ -18,7 +18,7 @@
class SingleTPUPlugin(SingleTPUStrategy):
    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        rank_zero_deprecation(
            "The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in."
            "The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in"
            " v1.8. Use `pl.strategies.single_tpu.SingleTPUStrategy` instead."
        )
        super().__init__(*args, **kwargs)
44 changes: 33 additions & 11 deletions src/pytorch_lightning/strategies/launchers/xla.py
@@ -13,10 +13,12 @@
# limitations under the License.
import os
import time
from functools import wraps
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING

import torch.multiprocessing as mp
from torch.multiprocessing import ProcessContext

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
@@ -26,9 +28,10 @@
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
else:
    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
    xm, xmp = None, None

if TYPE_CHECKING:
    from pytorch_lightning.strategies import Strategy
@@ -72,7 +75,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
"""
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
xmp.spawn(
_save_spawn(
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
nprocs=len(self._strategy.parallel_devices),
@@ -103,14 +106,6 @@ def _wrapping_function(
        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        self._strategy.barrier("end-process")

        # Ensure that the rank 0 process is the one exiting last
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if self._strategy.local_rank == 0:
            time.sleep(2)

    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback
Expand Down Expand Up @@ -138,3 +133,30 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
        self.add_to_queue(trainer, extra)

        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)


def _save_spawn(
    fn: Callable,
    args: Tuple = (),
    nprocs: Optional[int] = None,
    join: bool = True,
    daemon: bool = False,
    start_method: str = "spawn",
) -> Optional[ProcessContext]:
    """Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
    processes."""

    @wraps(fn)
    def wrapped(rank: int, *_args: Any) -> None:
        fn(rank, *_args)

        # Make all processes wait for each other before joining
        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        xm.rendezvous("end-process")

        # Ensure that the rank 0 process is the one exiting last
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if rank == 0:
            time.sleep(1)

    return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
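
For orientation, a hedged usage sketch of the wrapper above; the worker function, its arguments, and the process count are hypothetical, and it assumes a host where `torch_xla` is importable:

```python
import torch_xla.core.xla_model as xm


def _train_worker(rank: int, message: str) -> None:
    # Each spawned worker gets its own XLA device, exactly as with xmp.spawn.
    device = xm.xla_device()
    print(f"rank {rank} on {device}: {message}")


# Same call signature as xmp.spawn, but every worker now syncs at the
# "end-process" rendezvous before exiting, and rank 0 exits last.
_save_spawn(_train_worker, args=("hello",), nprocs=8, start_method="fork")
```
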
15 changes: 10 additions & 5 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -74,6 +74,7 @@ def __init__(
            start_method="fork",
        )
        self.debug = debug
        self._launched = False

    @property
    def checkpoint_io(self) -> CheckpointIO:
@@ -90,6 +91,8 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:

    @property
    def root_device(self) -> torch.device:
        if not self._launched:
            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
        return xm.xla_device()

    @staticmethod
@@ -130,7 +133,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)
            os.environ["PT_XLA_DEBUG"] = "1"

        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
@@ -150,8 +153,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

    @property
    def is_distributed(self) -> bool:
        # HOST_WORLD_SIZE is None outside the xmp.spawn process
        return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1
        # HOST_WORLD_SIZE is not set outside the xmp.spawn process
        return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

    def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
        TPUSpawnStrategy._validate_dataloader(dataloader)
@@ -189,8 +192,9 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if invalid_reduce_op or invalid_reduce_op_str:
            raise MisconfigurationException(
                "Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
            raise ValueError(
                "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
                f" {reduce_op}"
            )

        output = xm.mesh_reduce("reduce", output, sum)
@@ -201,6 +205,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
        return output

    def _worker_setup(self, process_idx: int):
        self._launched = True
        reset_seed()
        self.set_world_ranks(process_idx)
        rank_zero_only.rank = self.global_rank
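
As a minimal illustration of what the new `_launched` guard buys, consider this sketch; the test name is hypothetical and it assumes `pytorch_lightning` is importable on the host:

```python
import pytest

from pytorch_lightning.strategies import TPUSpawnStrategy


def test_root_device_guard():
    strategy = TPUSpawnStrategy()
    # Before the workers have spawned, `_launched` is False, so touching the
    # XLA device fails fast instead of initializing XLA in the main process.
    with pytest.raises(RuntimeError, match="before processes have spawned"):
        _ = strategy.root_device
```
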
@@ -671,7 +671,7 @@ def test_devices_auto_choice_mps():

@pytest.mark.parametrize(
    ["parallel_devices", "accelerator"],
    [([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], ("tpu"))],
    [([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
)
def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_devices, accelerator):
    with pytest.raises(MisconfigurationException, match=r"parallel_devices set through"):
2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
@@ -602,7 +602,7 @@ def test_strategy_choice_ipu_plugin(tmpdir):


@RunIf(ipu=True)
def test_device_type_when_training_plugin_ipu_passed(tmpdir):
def test_device_type_when_ipu_strategy_passed(tmpdir):
    trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8)
    assert isinstance(trainer.strategy, IPUStrategy)
    assert isinstance(trainer.accelerator, IPUAccelerator)
21 changes: 9 additions & 12 deletions tests/tests_pytorch/accelerators/test_tpu.py
@@ -28,7 +28,6 @@
from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
from pytorch_lightning.utilities import find_shared_parameters
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import pl_multi_process_test


class WeightSharingModule(BoringModel):
@@ -46,8 +45,7 @@ def forward(self, x):
        return x


@RunIf(tpu=True)
@pl_multi_process_test
@RunIf(tpu=True, standalone=True)
def test_resume_training_on_cpu(tmpdir):
    """Checks if training can be resumed from a saved checkpoint on CPU."""
    # Train a model on TPU
@@ -65,11 +63,9 @@
    # Verify that training is resumed on CPU
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    trainer.fit(model, ckpt_path=model_path)
    assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_if_test_works_after_train(tmpdir):
    """Ensure that .test() works after .fit()"""

@@ -293,12 +289,14 @@ def test_xla_checkpoint_plugin_being_default():
    assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)


@RunIf(tpu=True)
@patch("pytorch_lightning.strategies.tpu_spawn.xm")
def test_mp_device_dataloader_attribute(_):
@patch("pytorch_lightning.strategies.tpu_spawn.MpDeviceLoader")
@patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock):
    dataset = RandomDataset(32, 64)
    dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
    assert dataloader.dataset == dataset
    dataloader = DataLoader(dataset)
    processed_dataloader = TPUSpawnStrategy().process_dataloader(dataloader)
    mp_loader_mock.assert_called_with(dataloader, root_device_mock)
    assert processed_dataloader.dataset == processed_dataloader._loader.dataset


@RunIf(tpu=True)
@@ -307,8 +305,7 @@ def test_warning_if_tpus_not_used():
    Trainer()


@RunIf(tpu=True)
@pl_multi_process_test
@RunIf(tpu=True, standalone=True)
@pytest.mark.parametrize(
["devices", "expected_device_ids"],
[
12 changes: 5 additions & 7 deletions tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -96,7 +96,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
    assert cpu_stats_mock.call_count == expected


@pytest.mark.skipif(True, reason="TODO (@kaushikb11): fix this test, timeout")
@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
    """Test TPU stats are logged using a logger."""
@@ -106,24 +105,23 @@ def test_device_stats_monitor_tpu(tmpdir):

    class DebugLogger(CSVLogger):
        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        def log_metrics(self, metrics, step=None) -> None:
            fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
            for f in fields:
                assert any(f in h for h in metrics)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        max_epochs=2,
        limit_train_batches=5,
        accelerator="tpu",
        devices=1,
        devices=8,
        log_every_n_steps=1,
        callbacks=[device_stats],
        logger=DebugLogger(tmpdir),
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    trainer.fit(model)


@@ -146,7 +144,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
    trainer.fit(model)


def test_prefix_metric_keys(tmpdir):
def test_prefix_metric_keys():
    """Test that metric key names are converted correctly."""
    metrics = {"1": 1.0, "2": 2.0, "3": 3.0}
    prefix = "foo"
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
@@ -180,6 +180,7 @@ def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.C
        min_cuda_gpus="PL_RUN_CUDA_TESTS",
        slow="PL_RUN_SLOW_TESTS",
        ipu="PL_RUN_IPU_TESTS",
        tpu="PL_RUN_TPU_TESTS",
    )
    if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
        # special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
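
For reference, the env-var gating this hook implements looks roughly like the sketch below; the real logic in `tests/tests_pytorch/conftest.py` has more branches (see the standalone special case above), so treat this as an assumption-laden outline rather than the actual implementation:

```python
import os
from typing import List

import pytest


def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
    options = {"tpu": "PL_RUN_TPU_TESTS"}  # RunIf kwarg -> gating environment variable
    for kwarg, env_var in options.items():
        if os.getenv(env_var, "0") == "1":
            # Keep only tests whose RunIf(...) call passed this kwarg; RunIf
            # forwards it into the skipif marker's kwargs (see runif.py below).
            items[:] = [
                item
                for item in items
                if any(m.name == "skipif" and m.kwargs.get(kwarg) for m in item.own_markers)
            ]
```
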
6 changes: 2 additions & 4 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
@@ -360,12 +360,10 @@ def test_v1_8_0_deprecated_single_device_plugin_class():
    SingleDevicePlugin("cpu")


@RunIf(tpu=True)
@RunIf(tpu=True, standalone=True)
def test_v1_8_0_deprecated_single_tpu_plugin_class():
    with pytest.deprecated_call(
        match=(
            "SingleTPUPlugin` is deprecated in v1.6 and will be removed in v1.8." " Use `.*SingleTPUStrategy` instead."
        )
        match="SingleTPUPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `.*SingleTPUStrategy` instead."
    ):
        SingleTPUPlugin(0)

2 changes: 2 additions & 0 deletions tests/tests_pytorch/helpers/runif.py
@@ -177,6 +177,8 @@ def __new__(
        if tpu:
            conditions.append(not _TPU_AVAILABLE)
            reasons.append("TPU")
            # used in conftest.py::pytest_collection_modifyitems
            kwargs["tpu"] = True

        if ipu:
            conditions.append(not _IPU_AVAILABLE)
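
Tying the two pieces together, a short assumed usage example: the new `tpu=True` kwarg on `RunIf` ends up in the skipif marker's kwargs, which the collection hook in `conftest.py` matches against `PL_RUN_TPU_TESTS`. The test body here is hypothetical:

```python
from tests_pytorch.helpers.runif import RunIf


@RunIf(tpu=True, standalone=True)
def test_something_on_tpu():
    # Collected only when PL_RUN_TPU_TESTS=1; the standalone kwarg is gated
    # analogously by its own environment variable.
    ...
```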