
torchmetrics.Accuracy doesn't support inference mode with distributed backend. #9431


Closed
WeichenXu123 opened this issue Sep 10, 2021 · 6 comments · Fixed by #9443
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments

WeichenXu123 commented Sep 10, 2021

🐛 Bug

torchmetrics.Accuracy doesn't support inference mode (introduced in #8813) with distributed backend.

To Reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing
from torchmetrics import Accuracy
import pytorch_lightning as pl
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, random_split, TensorDataset


class IrisClassificationBase(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()
        self.args = kwargs
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 3)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.01)


class IrisClassification(IrisClassificationBase):
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.train_acc(torch.argmax(logits, dim=1), y)
        self.log("train_acc", self.train_acc.compute(), on_step=False, on_epoch=True)
        self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.val_acc(torch.argmax(logits, dim=1), y)
        self.log("val_acc", self.val_acc.compute())
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        self.test_acc(torch.argmax(logits, dim=1), y)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc.compute())


class IrisDataModuleBase(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.columns = None

    def _get_iris_as_tensor_dataset(self):
        iris = load_iris()
        df = iris.data
        self.columns = iris.feature_names
        target = iris["target"]
        data = torch.Tensor(df).float()
        labels = torch.Tensor(target).long()
        data_set = TensorDataset(data, labels)
        return data_set

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            iris_full = self._get_iris_as_tensor_dataset()
            self.train_set, self.val_set = random_split(iris_full, [130, 20])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.train_set, self.test_set = random_split(self.train_set, [110, 20])


class IrisDataModule(IrisDataModuleBase):
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=4)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=4)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=4)


if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    model = IrisClassification()
    dm = IrisDataModule()
    dm.setup(stage="fit")
    trainer = pl.Trainer(max_epochs=20, accelerator="ddp_cpu", num_processes=4)
    trainer.fit(model, dm)
    dm.setup(stage='test')
    trainer.test(model, dm)  # raises the error below

The line trainer.test(model, dm) raises the following error:

distributed_backend=gloo
All DDP processes registered. Starting ddp with 4 processes
----------------------------------------------------------------------------------------------------

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:112: UserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Testing: 0it [00:00, ?it/s]Traceback (most recent call last):
  File "/Users/weichen.xu/work/projects/test1/v1.py", line 104, in <module>
    trainer.test(model, dm)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 512, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 765, in _test_impl
    results = self._run(model)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1014, in _run
    self._dispatch()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1081, in _dispatch
    self.accelerator.start_evaluating(self)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 90, in start_evaluating
    self.training_type_plugin.start_evaluating(trainer)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 169, in start_evaluating
    mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 210, in new_process
    results = trainer.run_stage()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1092, in run_stage
    return self._run_evaluate()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1151, in _run_evaluate
    eval_loop_results = self._evaluation_loop.run()
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 113, in run
    self.advance(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 107, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 113, in run
    self.advance(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 114, in advance
    output = self.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 158, in evaluation_step
    output = self.trainer.accelerator.test_step(step_kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 198, in test_step
    return self.training_type_plugin.test_step(*step_kwargs.values())
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 356, in test_step
    return self.lightning_module.test_step(*args, **kwargs)
  File "/Users/weichen.xu/work/projects/test1/v1.py", line 57, in test_step
    self.log("test_acc", self.test_acc.compute())
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/metric.py", line 364, in wrapped_func
    dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, should_unsync=self._should_unsync
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/contextlib.py", line 112, in __enter__
    return next(self.gen)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/metric.py", line 338, in sync_context
    distributed_available=distributed_available,
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/metric.py", line 286, in sync
    self._sync_dist(dist_sync_fn, process_group=process_group)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/metric.py", line 229, in _sync_dist
    group=process_group or self.process_group,
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/utilities/data.py", line 191, in apply_to_collection
    return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/utilities/data.py", line 191, in <dictcomp>
    return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/utilities/data.py", line 187, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/utilities/distributed.py", line 120, in gather_all_tensors
    return _simple_gather_all_tensors(result, group, world_size)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchmetrics/utilities/distributed.py", line 92, in _simple_gather_all_tensors
    torch.distributed.all_gather(gathered_result, result, group)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1914, in all_gather
    work.wait()
RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.

Expected behavior

Environment

Additional context

@WeichenXu123 added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Sep 10, 2021
@WeichenXu123 changed the title from "torchmetrics.Accuracy doesn't support inference model with distributed backend." to "torchmetrics.Accuracy doesn't support inference mode with distributed backend." on Sep 10, 2021
tchaton (Contributor) commented Sep 10, 2021

@tangbinh @ananthsub @SkafteNicki Mind looking into this?

@SkafteNicki (Member)

I assume it does not like these four lines (as they are the only in-place operations):
https://github.com/PyTorchLightning/metrics/blob/d9bae8bf08718080586557185bcc6f394933f15a/torchmetrics/classification/accuracy.py#L258-L261
@WeichenXu123 could you try changing them from the in-place form self.value += value to self.value = self.value + value and see if that fixes the problem? If it does, we can make the change in torchmetrics.
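
For concreteness, a minimal sketch of the suggested change, using a toy metric with made-up state names (correct, total) rather than the exact lines linked above:

import torch
from torchmetrics import Metric


class ToyAccuracy(Metric):
    # Illustrative only; the state names and logic are simplified, not the real Accuracy.
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # Out-of-place accumulation (self.correct = self.correct + ...) instead of
        # the in-place self.correct += ..., so no in-place write touches a tensor
        # that may have been created under torch.inference_mode().
        self.correct = self.correct + (preds == target).sum()
        self.total = self.total + target.numel()

    def compute(self):
        return self.correct.float() / self.total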

@tangbinh (Contributor)

@WeichenXu123 Thank you for providing the code; it's very helpful. @SkafteNicki I tried your suggestion, but it looks like the problem remains. It also doesn't go away when I strip out all the logic in torchmetrics/classification/accuracy.py and simply return torch.tensor(0.) in compute. Do you think it has something to do with how torch.distributed.all_gather works?

@ananthsub (Contributor)

@tchaton If it's safer, we can revert the inference-mode change to unbreak these use cases in the meantime.

ananthsub added a commit to ananthsub/pytorch-lightning that referenced this issue Sep 10, 2021
@WeichenXu123 (Author)

> @WeichenXu123 Thank you for providing the code; it's very helpful. @SkafteNicki I tried your suggestion, but it looks like the problem remains. It also doesn't go away when I strip out all the logic in torchmetrics/classification/accuracy.py and simply return torch.tensor(0.) in compute. Do you think it has something to do with how torch.distributed.all_gather works?

In my test, this issue only happens with a distributed backend, so I guess it is related to the metric "aggregation" (sync) stage.
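
Based on the traceback above, that aggregation step boils down to roughly the following (a simplified sketch of torchmetrics' gather_all_tensors path, not the actual code): every metric state tensor is gathered across ranks with torch.distributed.all_gather, which fills preallocated output tensors in place.

import torch
import torch.distributed as dist


def gather_state(state):
    # Simplified sketch of how a metric state is synced across processes.
    # all_gather writes into the preallocated `gathered` tensors in place; the
    # traceback above shows the "Inplace update to inference tensor" error
    # being raised from inside this all_gather call.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(state) for _ in range(world_size)]
    dist.all_gather(gathered, state)
    return gathered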

ananthsub (Contributor) commented Sep 24, 2021

@WeichenXu123 - were you running ddp_spawn on CPU? On CUDA/NCCL, this simple example works for me, so it looks like an issue with the gloo backend.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, backend="nccl"):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    with torch.inference_mode():
        zero = torch.zeros(1).cuda(rank)
        outputs = [torch.ones(1).cuda(rank)] * 2
        dist.all_gather(outputs, zero)
    print(outputs)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
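
For comparison, here is a gloo/CPU adaptation of the same script (not from the thread; only the backend and device are changed), targeting the combination reported to fail above:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, backend="gloo"):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    with torch.inference_mode():
        zero = torch.zeros(1)
        outputs = [torch.ones(1) for _ in range(size)]
        # On gloo + CPU, this all_gather is expected to hit the
        # "Inplace update to inference tensor" error reported above.
        dist.all_gather(outputs, zero)
    print(outputs)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()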
