[CXX11ABI] torch 2.6.0-cu126 and cu124 have different exported symbols #152790

Open
vadimkantorov opened this issue May 4, 2025 · 19 comments
Labels: module: binaries · module: cuda · needs reproduction · triaged

Comments

@vadimkantorov (Contributor) commented May 4, 2025

🐛 Describe the bug

The symbol _ZN3c105ErrorC2ENS_14SourceLocationESs is exported by the cu124 build but missing from cu126: see the nm outputs in Dao-AILab/flash-attention#1644.

I understand that flash_attention has stopped working with torch 2.7 because of the missing symbols. But it was a bit surprising that the exported symbols differ between the cu124 and cu126 builds of the same release...

Another question is why torch exported _ZN3c105ErrorC2ENS_14SourceLocationESs in the first place, and why flash_attention depends on it...
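For reference, a rough sketch of how to diff the exported c10::Error constructors between the two wheels (the site-packages paths below are placeholders for wherever the cu124 and cu126 wheels are installed, and binutils' nm is assumed to be on PATH):

import subprocess

def exported_error_symbols(libc10_path):
    # nm -D --defined-only lists the symbols the shared library actually exports
    out = subprocess.run(
        ["nm", "-D", "--defined-only", libc10_path],
        capture_output=True, text=True, check=True,
    ).stdout
    return {line.split()[-1] for line in out.splitlines() if "_ZN3c105Error" in line}

# placeholder paths: point these at torch/lib/libc10.so inside each environment
cu124 = exported_error_symbols("venv-cu124/lib/python3.10/site-packages/torch/lib/libc10.so")
cu126 = exported_error_symbols("venv-cu126/lib/python3.10/site-packages/torch/lib/libc10.so")
print("only in cu124:", sorted(cu124 - cu126))
print("only in cu126:", sorted(cu126 - cu124))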

@malfet

Versions

torch 2.6.0-cu126 and cu124

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @seemethere @malfet @osalpekar @atalman @ptrblck @eqy @jerryzh168

@malfet added the high priority, module: binaries, module: cuda, and module: regression labels on May 4, 2025
@malfet (Contributor) commented May 4, 2025

I agree this shouldn't be the case...

@vadimkantorov (Contributor, Author) commented May 4, 2025

Maybe it's worth adding some smoke tests for which symbols actually get exported, and comparing them between versions - especially if symbols are now not supposed to be auto-exported, to reduce symbol pollution...

I also wonder how/why flash_attention depends on _ZN3c105ErrorC2ENS_14SourceLocationESs, since that is what broke it with torch 2.7.0.

Also, xformers has hidden all 2.6.0-compatible installation recipes, despite the fact that vllm hasn't released a 2.7.0-compatible version yet and flash_attention seems broken with 2.7.0...

@malfet (Contributor) commented May 5, 2025

Maybe it's worth adding some smoke tests for which symbols actually get exported, and comparing them between versions - especially if symbols are now not supposed to be auto-exported, to reduce symbol pollution...

We do have smoke tests for this in https://github.com/pytorch/pytorch/blob/main/.ci/pytorch/smoke_test/check_binary_symbols.py, but they were mostly focused on not using the CXX11 ABI, from before it was enabled by default.

I also wonder how/why flash_attention depends on _ZN3c105ErrorC2ENS_14SourceLocationESs, since that is what broke it with torch 2.7.0.

c++filt _ZN3c105ErrorC2ENS_14SourceLocationESs -> c10::Error::Error(c10::SourceLocation, std::basic_string<char, std::char_traits<char>, std::allocator<char> >), i.e. I suspect it's the constructor invoked by any TORCH_CHECK call...

@vadimkantorov (Contributor, Author) commented May 5, 2025

I also wonder whether this symbol could be brought back in 2.7.1, just to make sure that flash_attention's pip-distributed pre-compiled binaries work with torch 2.7.1 out of the box?..

Maybe this symbol is needed for building PyTorch C++ extensions that use TORCH_CHECK?

@malfet (Contributor) commented May 5, 2025

Ok, I sometimes fail to read issues: this is expected behavior for 2.6.0 and is indeed associated with the CXX11 ABI migration: the 12.4 binaries were built with the pre-CXX11 ABI and the 12.6 ones with it, which should have been reflected in the release notes.

@vadimkantorov (Contributor, Author) commented May 5, 2025

the 12.4 binaries were built with the pre-CXX11 ABI and the 12.6 ones with it

I see - it's just a bit strange that this happened without a version bump, within the same 2.6.0 release.

And if it happened during the 2.6.0 release, flash_attention's own releases seem to have been falling behind for quite a few months now...

@malfet (Contributor) commented May 5, 2025

I.e. the symbol name in the 2.6 CXX11-ABI builds (and in all 2.7.0 builds) is _ZN3c105ErrorC1ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE -> c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >).
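For anyone comparing the two names: the C1/C2 suffix only distinguishes the complete- vs base-object constructor; the substantive difference is the std::__cxx11 string parameter. A rough way to demangle both side by side (assuming binutils' c++filt is on PATH):

import subprocess

symbols = [
    # pre-CXX11-ABI constructor exported by the 2.6.0 cu124 build
    "_ZN3c105ErrorC2ENS_14SourceLocationESs",
    # CXX11-ABI constructor exported by the cu126 / 2.7 builds
    "_ZN3c105ErrorC1ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE",
]
# c++filt accepts mangled names as arguments and prints one demangled name per line
print(subprocess.run(["c++filt", *symbols], capture_output=True, text=True, check=True).stdout)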

@vadimkantorov (Contributor, Author)

So it is expected that most C++ extensions will need to rebuild and push new binaries, right?

@malfet (Contributor) commented May 5, 2025

@vadimkantorov just checking: are you saying you can't build flash attention from source, or that their binaries are incompatible?
If the former, I'll try to submit a PR (and ask @drisspg for help with review) to make them respect the flag set by torch.

@malfet (Contributor) commented May 5, 2025

So it is expected that most C++ extensions will need to rebuild and push new binaries, right?

Until 2.7.0 PyTorch had no stable ABI, so rebuilding was always the recommendation. But I think some packages might have hardcoded the ABI flags in their build scripts instead of querying the torch build system (I forget the exact API name, but it should be something like torch.utils.cpp_extension.get_abi_flags()).
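For the query itself, a minimal sketch from the Python side (torch.compiled_with_cxx11_abi() is the call I would reach for; worth double-checking the exact name against the installed torch version):

import torch

# Should be True for the CXX11-ABI builds (2.6.0 cu126, 2.7.x) and False for the
# pre-CXX11 ones (2.6.0 cu124), per the discussion above.
print(torch.compiled_with_cxx11_abi())
print(torch.__version__, torch.version.cuda)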

@vadimkantorov (Contributor, Author)

I haven't tried compiling flash_attention from source, since I discovered that I need to downgrade to cu124 to get it working for now.

But I expect more people will file issues with flash_attention about this soon, since 2.7.0 is a new release (and maybe far fewer people upgraded 2.6.0 from cu124 to cu126).

@vadimkantorov (Contributor, Author) commented May 5, 2025

Maybe for extensions like flash_attention (which don't use other libtorch functions to do computations with tensors or to benefit from autograd), simply providing functions via a C FFI and accepting tensors via raw pointers / DLPack would be sufficient - and wouldn't require an upgrade whenever PyTorch changes its ABI (I also wish load_inline were used more often, so the C++ extension gets built on the user's machine)?
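A rough sketch of the consumer side of that idea (the shared library name and the function it exports are hypothetical; the point is that only a plain C symbol and raw pointers cross the boundary, so no libtorch C++ ABI is involved):

import ctypes
import torch

# Hypothetical extension built as a plain C shared library exporting, e.g.:
#   extern "C" void multiply_by_two_f32(const float* in, float* out, int64_t n);
lib = ctypes.CDLL("./libmykernels.so")

x = torch.arange(4, dtype=torch.float32)
y = torch.empty_like(x)
lib.multiply_by_two_f32(
    ctypes.c_void_p(x.data_ptr()),   # pass raw data pointers, not torch::Tensor objects
    ctypes.c_void_p(y.data_ptr()),
    ctypes.c_int64(x.numel()),
)
print(y)  # expected: tensor([0., 2., 4., 6.])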

@vadimkantorov (Contributor, Author) commented May 5, 2025

@malfet maybe a good "smoke test" would be a test suite in PyTorch that tries building the most popular C++ pytorch extensions from source, to check whether there are any errors and to raise issues with the authors before the PyTorch release.
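Even a cheap import-level variant of that idea would catch the undefined-symbol breakage; a rough sketch (the package list is purely illustrative):

import importlib

for name in ["flash_attn", "xformers", "vllm"]:
    try:
        importlib.import_module(name)
        print(f"{name}: OK")
    except Exception as exc:  # missing symbols surface as ImportError/OSError at import time
        print(f"{name}: FAILED ({type(exc).__name__}: {exc})")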

Or even to have a Known Issues section in the release notes covering other tools: flash_attention's binaries are broken and need a new release, same for vllm, etc.

I guess the thing with flash_attention is that they don't have a huge team for fast support, yet the package is depended on by a lot of far more popular packages...

@malfet changed the title from "torch 2.6.0-cu126 and cu124 have different exported symbols" to "[CXX11ABI] torch 2.6.0-cu126 and cu124 have different exported symbols" on May 5, 2025
@atalman (Contributor) commented May 5, 2025

Hi @vadimkantorov, please go ahead and create an issue in https://github.com/pytorch/test-infra/issues with a list of extensions you think would be useful for us to test against.

We do have smoke tests running on a nightly basis: https://github.com/pytorch/test-infra/blob/main/.github/workflows/validate-nightly-binaries.yml - so we can look into extending the validation framework and adding extra tests there.

@malfet added the triaged label and removed the module: regression and triage review labels on May 5, 2025
@malfet (Contributor) commented May 6, 2025

Some followups:

  • One can easily restore the symbol by adding, say, a PreCxx11ExceptionsWrapper.cpp to c10/util that looks something like the following:
#define _GLIBCXX_USE_CXX11_ABI 0
#include <c10/util/Exception.h>
#include <iostream>

// Placeholder body: compiling this TU with the old ABI re-emits the pre-CXX11
// mangled constructor symbol (_ZN3c105ErrorC2ENS_14SourceLocationESs).
c10::Error::Error(SourceLocation source_location, std::string msg) {
  std::cout << "Ha" << std::endl;
}

But it is much harder to redispatch strings from one ABI to the other, and even if one achieves something along those lines with C shims, it results in segfaults when the inlined destructor is called, even for extensions as simple as:

#define _GLIBCXX_USE_CXX11_ABI 0
#include <torch/extension.h>

// A simple function that multiplies each element by 2
torch::Tensor multiply_by_two(torch::Tensor input) {
    TORCH_CHECK_VALUE(input.numel() < 10, "Too many elements");
    return input * 2;
}

PYBIND11_MODULE(simple_extension, m) {
    m.def("multiply_by_two", &multiply_by_two, "Multiply each element by 2");
}

This works with the placeholder wrapper above, but crashes if the wrapper redispatches to a C-like constructor:

* thread #1, name = 'python', stop reason = signal SIGSEGV
  * frame #0: 0x000077d8602f1efa libc.so.6`__libc_free + 26
    frame #1: 0x000077d82e0eb7cd simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() [inlined] std::__new_allocator<std::string>::deallocate(__n=<unavailable>, __p=<unavailable>, this=<unavailable>) at new_allocator.h:158:26
    frame #2: 0x000077d82e0eb7c8 simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() [inlined] std::allocator_traits<std::allocator<std::string> >::deallocate(__n=<unavailable>, __p=<unavailable>, __a=<unavailable>) at alloc_traits.h:496:23
    frame #3: 0x000077d82e0eb7c8 simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() at stl_vector.h:387:19
    frame #4: 0x000077d82e0eb7c8 simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() [inlined] std::_Vector_base<std::string, std::allocator<std::string> >::~_Vector_base(this=<unavailable>, __in_chrg=<unavailable>) at stl_vector.h:366:15
    frame #5: 0x000077d82e0eb7b5 simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() [inlined] std::vector<std::string, std::allocator<std::string> >::~vector(this=<unavailable>, __in_chrg=<unavailable>) at stl_vector.h:733:7
    frame #6: 0x000077d82e0eb78e simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError() [inlined] c10::Error::~Error(this=0x0000563de48bd3b0, __in_chrg=<unavailable>) at Exception.h:30:15
    frame #7: 0x000077d82e0eb74d simple_extension.cpython-311-x86_64-linux-gnu.so`c10::ValueError::~ValueError(this=0x0000563de48bd3b0) at Exception.h:254:15
    frame #8: 0x000077d8511c1023 libstdc++.so.6`___lldb_unnamed_symbol7648 + 35

@malfet (Contributor) commented May 9, 2025

@vadimkantorov can you explain in a bit more detail how you ended up in that situation?

I ran

curl -OL https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
python3.10 -mpip install torch
python3.10 -mpip install flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl 

and it seems to pass smoke tests like

python3 -mpytest test_rotary.py  -v

@malfet added the needs reproduction label on May 9, 2025
@vadimkantorov (Contributor, Author)

I probably first installed unsloth, which brought in PyTorch 2.6.0-cu124 and flash_attn with the old ABI, and then I somehow updated cuda on the machine and tried to upgrade PyTorch manually. Maybe pip then always fetches the old-ABI binary of flash_attn from the cache? Or maybe the new-ABI binary of flash_attn is not discovered by pip install and must be installed manually via the whl link?

@vadimkantorov (Contributor, Author) commented May 12, 2025

So the full installation recipe for pytorch 2.7.0 and flash_attn for cu126 seems to be:

pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu126
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

# pip index versions vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
# pip install --upgrade vllm==0.8.5.dev600+g7ea6cb28b --pre --extra-index-url https://wheels.vllm.ai/nightly

Although xformers has now announced they dropped precompiled-binary support for 2.6.0 :( which seems a bit too fast :( - e.g. vllm's support for 2.7.0 is not released yet (expected in 0.9.0), and the latest vllm, which depends on 2.6.0, still requires xformers, which itself now depends on 2.7.0.

@vadimkantorov (Contributor, Author)

@malfet I think pip will sometimes also ignore --index-url if it finds a "suitable" package in the local cache. E.g. this can lead to ignoring the index-url with nightlies...

I found this with vllm.

But it likely affects pytorch itself as well :(
