
bugfix: bugfix to #949 #951

Merged · 11 commits · Mar 17, 2025
Conversation

@yzh119 (Collaborator) commented Mar 17, 2025

The sm86/sm89 version of the MLA kernel was not tested after change #942; this PR fixes the issue.

This PR also makes the following changes:

  1. adds the MLA unittest to CI (on an a10g node).
  2. shrinks the MLA unittest so that CI can finish in a reasonable time.
  3. changes `is_sm90a_supported(torch.device("cuda"))` to `backend == "fa3" and not is_sm90a_supported(torch.device("cuda"))` for non-Hopper GPUs, as pointed out by @Atream (see the sketch after this list).
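
For clarity, here is a minimal sketch of what the corrected check looks like in a test; the import path and the helper name `skip_if_backend_unsupported` are assumptions for illustration, not the exact diff.

```python
# Sketch only: assumes is_sm90a_supported is importable from flashinfer.utils
# and that the MLA tests are parametrized by a `backend` string.
import pytest
import torch

from flashinfer.utils import is_sm90a_supported


def skip_if_backend_unsupported(backend: str) -> None:
    # Only the "fa3" backend requires sm90a (Hopper). Other backends should
    # still run on non-Hopper GPUs such as sm86/sm89 (e.g. the a10g CI node).
    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("fa3 backend requires sm90a; skipping on this GPU")
```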

@yzh119 (Collaborator, Author) commented Mar 17, 2025

CI has some issues similar to https://discuss.pytorch.org/t/torch-pytest-leads-to-memory-fragmentation-how-to-do-proper-integration-testing-of-a-lot-of-torch-models/201231: running many torch tests in a single pytest process fragments GPU memory and eventually triggers OOM.

The OOM hook below didn't totally avoid the issue:

```python
import pytest
import torch


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # Skip tests that hit a CUDA out-of-memory error instead of failing the run.
    try:
        item.runtest()
    except (torch.OutOfMemoryError, RuntimeError) as e:
        if isinstance(e, torch.OutOfMemoryError) or "CUDA error: out of memory" in str(e):
            pytest.skip("Skipping due to OOM")
        else:
            raise
```
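
(A `pytest_runtest_call` hook like this would normally live in the test suite's conftest.py so pytest discovers it automatically; the exact file isn't shown in this thread.)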

@yzh119 (Collaborator, Author) commented Mar 17, 2025

A temporary solution is to set the environment variable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` to reduce fragmentation.
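
For reference, a sketch of one way to apply this from within the test process itself (the actual CI more likely sets the variable in the workflow environment):

```python
# conftest.py (assumed file; sketch only): the setting must be in place before
# the CUDA caching allocator is first initialized, i.e. before any CUDA allocation.
import os

os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```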

@yzh119 merged commit 30b2838 into flashinfer-ai:main on Mar 17, 2025. 2 checks passed.
yyihuang pushed a commit to yyihuang/flashinfer that referenced this pull request Mar 17, 2025