
SpearmanCorrCoef is very slow on large tensors with many duplicate elements #3102


Open
gratus907 opened this issue May 23, 2025 · 1 comment · May be fixed by #3103
Labels
enhancement New feature or request

Comments

@gratus907

🚀 Feature

More efficient implementation of SpearmanCorrCoef

Motivation

The current implementation of SpearmanCorrCoef is very slow on large tensors with many duplicate elements, because _rank_data resolves tied values by iterating over each element in a Python loop.

Pitch

Improve the implementation of the _rank_data function as follows:

import torch
from torch import Tensor


def _rank_data(data: Tensor) -> Tensor:
    """Rank the elements of ``data``, averaging the ranks of tied values."""
    n = data.numel()
    # Assign provisional ranks 1..n according to sorted order.
    rank = torch.empty_like(data, dtype=torch.int32)
    idx = data.argsort()
    rank[idx] = torch.arange(1, n + 1, dtype=torch.int32, device=data.device)

    # Group duplicates and replace each provisional rank by the mean
    # rank of its group, using only vectorized ops.
    uniq, inv, counts = torch.unique(
        data, sorted=True, return_inverse=True, return_counts=True
    )
    sum_ranks = torch.zeros_like(uniq, dtype=torch.int32)
    sum_ranks.scatter_add_(0, inv, rank)
    mean_ranks = sum_ranks / counts
    return mean_ranks[inv]

which uses torch.unique and scatter_add_ to average the ranks of tied values without any Python loop.

@gratus907 gratus907 added the enhancement New feature or request label May 23, 2025

Hi! Thanks for your contribution! Great first issue!

@gratus907 gratus907 linked a pull request May 23, 2025 that will close this issue