
Support for enabling sparse gradients in EmbeddingBag #8719

Open
chandrasekhard2 opened this issue Feb 18, 2025 · 1 comment · May be fixed by #8905
Labels: enhancement (New feature or request), lowering (ATen Operation lowering)

Comments

@chandrasekhard2
Collaborator

🚀 Feature

Support for enabling sparse gradients in EmbeddingBag.

Motivation

Adding support for sparse gradients would make it possible to fit larger embedding tables on the TPU.

Pitch

I encountered the following error when setting sparse=True in the EmbeddingBag API.

NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXLA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XLA, Meta, SparseCPU, SparseCUDA, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastXLA, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
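For reference, a minimal sketch of the kind of code that triggers the error above. This is illustrative, not the original repro: it assumes a standard torch_xla setup and uses arbitrary table sizes and batch shapes.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# sparse=True asks autograd to produce a sparse COO gradient for the weight.
bag = nn.EmbeddingBag(num_embeddings=100_000, embedding_dim=128,
                      mode="sum", sparse=True).to(device)

indices = torch.randint(0, 100_000, (32, 16), device=device)
out = bag(indices)
out.sum().backward()  # backward constructs a sparse COO grad, which fails on the SparseXLA backend
xm.mark_step()
```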

Without this flag, memory consumption doubles during training (the embedding table plus a dense gradient of the same size). With support for sparse gradients, we could almost double the embedding_dim of any model on the same hardware (provided it does not exceed the HBM).
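As a rough back-of-the-envelope illustration (the sizes below are hypothetical, not taken from this issue): a dense gradient mirrors the weight shape and so duplicates the whole table, while a sparse gradient only materializes the rows touched in a batch.

```python
# Hypothetical sizes for illustration only.
num_embeddings, embedding_dim, bytes_per_elem = 10_000_000, 512, 4  # fp32 table

weight_gib = num_embeddings * embedding_dim * bytes_per_elem / 1024**3
dense_grad_gib = weight_gib                 # dense grad has the same shape as the weight
rows_in_batch = 8_192                       # assumed unique indices per training step
sparse_grad_mib = rows_in_batch * embedding_dim * bytes_per_elem / 1024**2

print(f"weight ~{weight_gib:.1f} GiB, dense grad ~{dense_grad_gib:.1f} GiB, "
      f"sparse grad ~{sparse_grad_mib:.1f} MiB")
# weight ~19.1 GiB, dense grad ~19.1 GiB, sparse grad ~16.0 MiB
```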

Alternatives

The alternative is to keep using EmbeddingBag with sparse=False and accept the dense-gradient memory overhead.

Additional context

@miladm
Collaborator

miladm commented Feb 19, 2025

@ysiraichi do we have bandwidth to get started on this work this week?

cc @qihqi

@ysiraichi ysiraichi added the enhancement New feature or request label Mar 19, 2025
@amjames amjames linked a pull request Mar 28, 2025 that will close this issue