## 🚀 Feature

Improve the `_bincount` utility to avoid an unnecessary fallback on CUDA when deterministic mode is enabled and conditions are safe for native `torch.bincount` use (PyTorch ≥ 2.1).

## Motivation
The current `_bincount` implementation in TorchMetrics falls back to a slower and more memory-intensive workaround whenever `torch.are_deterministic_algorithms_enabled()` returns `True`, regardless of the PyTorch version or backend.
However, since PyTorch v2.1, `torch.bincount` is allowed in deterministic mode on CUDA as long as:

- No weights are passed
- Gradients are not required
Avoiding the fallback in this case would improve performance and reduce memory usage.
This is particularly relevant when running large-scale evaluations on modern GPU systems.
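For a sense of the scale of that overhead, here is a minimal sketch of the broadcast-compare style of deterministic workaround (illustrative only, not necessarily TorchMetrics' exact code): it materializes an intermediate of shape `(len(x), minlength)`, so memory grows with both the input size and the number of bins.

```python
import torch

def deterministic_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # Compare every element against every bin index, then reduce.
    # The (len(x), minlength) boolean intermediate is what makes this
    # workaround memory-hungry compared to native torch.bincount.
    bins = torch.arange(minlength, device=x.device)
    return (x.unsqueeze(-1) == bins).sum(dim=0)
```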
## Pitch
Update the `_bincount` utility logic as follows (a code sketch follows the list):

- Use native `torch.bincount` if:
  - `x.is_cuda` is `True`
  - `torch.__version__ >= 2.1`
  - no weights are involved
  - gradients are not required
- Only fall back when:
  - `x.is_mps`, or
  - the XLA backend is detected, or
  - the PyTorch version is < 2.1 and deterministic algorithms are enabled
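A minimal sketch of that gating, under stated assumptions: the `_is_xla` helper and the `packaging`-based version guard below are illustrative placeholders, not TorchMetrics' actual internals.

```python
from typing import Optional

import torch
from packaging.version import Version

_TORCH_GE_2_1 = Version(torch.__version__) >= Version("2.1.0")


def _is_xla(x: torch.Tensor) -> bool:
    # Hypothetical placeholder; TorchMetrics has its own XLA detection.
    return x.device.type == "xla"


def _bincount(x: torch.Tensor, minlength: Optional[int] = None) -> torch.Tensor:
    """Sketch of the proposed gating, not the final implementation."""
    if minlength is None:
        minlength = int(x.max().item()) + 1 if x.numel() > 0 else 0
    # Native path: CUDA + PyTorch >= 2.1 is deterministic-safe as long as
    # no weights are passed and no gradients are required (PR #105244).
    if x.is_cuda and _TORCH_GE_2_1 and not x.requires_grad:
        return torch.bincount(x, minlength=minlength)
    # Fall back only for MPS, XLA, or older PyTorch under deterministic mode.
    if (
        x.is_mps
        or _is_xla(x)
        or (not _TORCH_GE_2_1 and torch.are_deterministic_algorithms_enabled())
    ):
        bins = torch.arange(minlength, device=x.device)
        return (x.unsqueeze(-1) == bins).sum(dim=0)
    return torch.bincount(x, minlength=minlength)
```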
## Alternatives
Continue using the current fallback unconditionally under deterministic mode, but this leads to unnecessary compute and memory overhead on newer CUDA-enabled systems.
## Additional context
This proposed change aligns with the improvements introduced in PyTorch PR #105244, which enabled deterministic torch.bincount on CUDA under safe conditions starting from v2.1.
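As a quick illustration of that behavioural change (assuming a CUDA machine; this snippet is a sanity check, not part of the proposal):

```python
import torch

torch.use_deterministic_algorithms(True)

x = torch.randint(0, 10, (1_000,), device="cuda")
# On PyTorch >= 2.1 this runs natively and deterministically; on earlier
# versions it raises a RuntimeError because torch.bincount had no
# deterministic CUDA implementation.
counts = torch.bincount(x, minlength=10)
print(counts)
```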
A PR will follow shortly to implement this enhancement.