
Avoid unnecessary fallback in _bincount when deterministic mode is enabled on CUDA (PyTorch ≥ 2.1) #3086


Closed
hyukkyukang opened this issue May 8, 2025 · 1 comment · Fixed by hyukkyukang/torchmetrics#1 · May be fixed by #3087
Labels
enhancement New feature or request

Comments

@hyukkyukang

🚀 Feature

Improve _bincount utility to avoid unnecessary fallback on CUDA when deterministic mode is enabled and conditions are safe for native torch.bincount use (PyTorch ≥ 2.1).

Motivation

The current _bincount implementation in TorchMetrics falls back to a slower, more memory-intensive workaround whenever torch.are_deterministic_algorithms_enabled() returns True, regardless of the PyTorch version or backend.

However, since PyTorch v2.1, torch.bincount is allowed in deterministic mode on CUDA as long as:

  • No weights are passed
  • Gradients are not required

Avoiding the fallback in this case would improve performance and reduce memory usage.
This is particularly relevant when running large-scale evaluations on modern GPU systems.
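To illustrate the cost: a typical deterministic workaround compares every element against every bin (a one-hot-style reduction), which scales with `len(x) * minlength` in work and, in tensor form, memory, whereas a native bincount is a single linear pass. The sketch below is a hypothetical pure-Python analogue of the two approaches; the function names are illustrative, not TorchMetrics' actual code.

```python
# Hypothetical pure-Python analogue of a one-hot-style deterministic
# bincount workaround (not TorchMetrics' actual implementation).
def bincount_fallback(values, minlength):
    # Compares every element against every bin:
    # O(len(values) * minlength) work, and in tensor form a
    # len(values) x minlength intermediate allocation.
    return [sum(1 for v in values if v == b) for b in range(minlength)]

def bincount_native(values, minlength):
    # What torch.bincount does conceptually: one pass, O(len(values)).
    counts = [0] * minlength
    for v in values:
        counts[v] += 1
    return counts
```

Both produce identical counts; only the cost differs, which is why skipping the fallback when it is safe matters for large-scale evaluation.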

Pitch

Update the _bincount utility logic to:

  • Use native torch.bincount if:
    • x.is_cuda is True
    • the PyTorch version is ≥ 2.1 (parsed as a version, not compared as the raw torch.__version__ string)
    • No weights are involved
    • Gradients are not required
  • Only fall back when:
    • x.is_mps or
    • XLA backend is detected or
    • PyTorch version is < 2.1 and deterministic algorithms are enabled
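The dispatch rules above can be sketched as a small predicate. This is a minimal sketch of the proposed logic only; the function name, parameters, and the assumption that non-CUDA deterministic use on PyTorch ≥ 2.1 is also safe are mine, not the actual TorchMetrics implementation.

```python
# Illustrative sketch of the proposed dispatch; names are hypothetical.
def use_native_bincount(is_cuda, is_mps, is_xla, torch_version,
                        deterministic, has_weights, requires_grad):
    """Return True when calling torch.bincount directly should be safe."""
    if is_mps or is_xla:
        return False  # backends where native bincount support is limited
    if deterministic:
        if torch_version < (2, 1):
            return False  # pre-2.1: deterministic bincount unsupported
        # From PyTorch 2.1, CUDA bincount is deterministic only when no
        # weights are passed and no gradient flows through the input.
        if is_cuda and (has_weights or requires_grad):
            return False
    return True
```

For example, a CUDA tensor on PyTorch 2.1 with deterministic mode on, no weights, and no gradient requirement would take the native path, while the same call on PyTorch 2.0 would still fall back.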

Alternatives

Continue using the current fallback unconditionally under deterministic mode, but this leads to unnecessary compute and memory overhead on newer CUDA-enabled systems.

Additional context

This proposed change aligns with the improvements introduced in PyTorch PR #105244, which enabled deterministic torch.bincount on CUDA under safe conditions starting from v2.1.

A PR will follow shortly to implement this enhancement.

@hyukkyukang hyukkyukang added the enhancement New feature or request label May 8, 2025

github-actions bot commented May 8, 2025

Hi! Thanks for your contribution! Great first issue!
