Introduce apply_xla_patch_to_nn_linear and test that in a scan #8739
In order to propagate sharding annotations in 2D sharding, linear layers should be implemented with einsum instead of transposes/reshapes. Additionally, they need to continue to function inside `scan`/`scan_layers`.
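For intuition (this is not code from the PR), here is a minimal plain-PyTorch sketch contrasting the two formulations: the reshape/transpose style collapses the non-contracting dims before a matmul, while the einsum style keeps them as distinct dimensions, which is what lets per-dimension sharding annotations propagate.

```python
import torch

x = torch.randn(4, 16, 128)   # (batch, seq, in_features)
w = torch.randn(256, 128)     # (out_features, in_features)
b = torch.randn(256)

# Reshape/transpose-style linear: batch and seq are flattened into a single
# dimension before the matmul, so their identities (and any per-dim sharding
# annotations) are lost inside the op.
flat = x.reshape(-1, 128) @ w.t() + b
out_reshape = flat.reshape(4, 16, 256)

# Einsum-style linear: the non-contracting dims (batch, seq) survive end to
# end as separate dimensions of a single contraction.
out_einsum = torch.einsum('bsi,oi->bso', x, w) + b

assert torch.allclose(out_reshape, out_einsum, atol=1e-4)
```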
For this to work we need three pieces:
- `apply_xla_patch_to_nn_linear`: a function that replaces the implementation of `nn.Linear` with einsum (calling `XLAPatchedLinear`); a usage sketch follows this list.
- `@custom_op` to mark a function as opaque to AOTAutograd. Even with `@custom_op`, the einsum is still decomposed into transposes/reshapes due to #8713 ("torch.einsum is incorrectly decomposed when wrapped inside a custom op"). That's a bug/PyTorch limitation. To work around it, I added a `_xla_einsum` C++ function that directly builds an einsum given XLA tensors, skipping over any PyTorch dispatcher complexity.
- A test that demonstrates how `nn.Linear` layers by default flatten any non-contracting dims, and how we can avoid that with `apply_xla_patch_to_nn_linear`.
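To tie the pieces together, here is a hedged usage sketch, not a test from this PR. The import paths and signatures are my assumptions about the torch_xla API around the time of this change (`apply_xla_patch_to_nn_linear` exported from `torch_xla.distributed.spmd`, `scan_layers` from `torch_xla.experimental.scan_layers`, and the patch function returning the patched module); they may differ in your version.

```python
# Usage sketch only: import paths and signatures are assumptions and may
# vary across torch_xla versions.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import apply_xla_patch_to_nn_linear
from torch_xla.experimental.scan_layers import scan_layers

device = xm.xla_device()

# A homogeneous stack of layers, as scan_layers expects.
layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(4)]).to(device)

# Rewrite every nn.Linear forward to the einsum-based implementation
# (XLAPatchedLinear) so non-contracting dims are not flattened away.
layers = apply_xla_patch_to_nn_linear(layers)

x = torch.randn(8, 16, 128, device=device)  # (batch, seq, features)

# The patched einsum-based linears must keep working when traced inside scan.
out = scan_layers(layers, x)
print(out.shape)  # expected: torch.Size([8, 16, 128])
```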