[mlir][nvgpu] Allow TMA's last dim to be non-128B without swizzling #81499


Merged: 1 commit into llvm:main from the no-swizzling-tma branch, Feb 13, 2024

@grypp (Member) commented Feb 12, 2024

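Allow the last dimension of a TMA descriptor to be something other than 128 bytes when no swizzling mode is set. Previously the verifier enforced the 128-byte last dimension for every descriptor of rank greater than 1, regardless of swizzle mode.

The test tma_load_64x8_8x128_noswizzle.mlir currently fails in the verifier because of this check; this PR fixes it.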

@llvmbot (Member) commented Feb 12, 2024

@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Full diff: https://github.com/llvm/llvm-project/pull/81499.diff

1 file affected:

  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+2, -1)
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 4b6327479a219c..26f831f10a4e40 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -362,7 +362,8 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
                              << kMaxTMADimension << " but it is " << dim;
     }
   }
-  if (descMemref.getRank() > 1) {
+  if (descMemref.getRank() > 1 &&
+      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
     unsigned lastDimensionByte =
         descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
     if (lastDimensionByte != kMaxTMALastdimByte)
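
For readers without the surrounding source at hand, the effect of the new condition can be sketched in standalone C++. This is a minimal model and not the MLIR code itself: `TileDesc`, `SwizzleKind`, and `verifyLastDim` are hypothetical stand-ins for the descriptor type and the verifier, the error text is illustrative, and `kMaxTMALastdimByte` is assumed to be 128 based on the PR title.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for TensorMapSwizzleKind; not the real NVGPU enum.
enum class SwizzleKind { None, Swizzle32B, Swizzle64B, Swizzle128B };

// Assumed value, inferred from the "128B" wording in the PR title.
constexpr int64_t kMaxTMALastdimByte = 128;

// Hypothetical, simplified view of the memref carried by a TMA descriptor.
struct TileDesc {
  std::vector<int64_t> shape;  // tile shape, innermost dimension last
  unsigned elementBitWidth;    // e.g. 32 for a 32-bit element type
  SwizzleKind swizzle;         // swizzle mode requested by the descriptor
};

// Returns an error message if the last-dimension check fails, else nullopt.
// Mirrors the patched condition: the check only applies when rank > 1 AND
// a swizzle mode is set.
std::optional<std::string> verifyLastDim(const TileDesc &desc) {
  if (desc.shape.size() > 1 && desc.swizzle != SwizzleKind::None) {
    int64_t lastDimensionByte =
        static_cast<int64_t>(desc.elementBitWidth) * desc.shape.back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return "last dimension is " + std::to_string(lastDimensionByte) +
             " bytes, expected " + std::to_string(kMaxTMALastdimByte);
  }
  return std::nullopt;
}
```

Under this model, a rank-2 tile whose last dimension holds 8 elements of 32 bits (32 bytes) passes when the swizzle mode is `None` and is rejected as soon as any swizzle mode is set. Before the patch, the same tile would have been rejected in both cases.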

@grypp merged commit e892f32 into llvm:main on Feb 13, 2024
@grypp deleted the no-swizzling-tma branch on Feb 13, 2024, 08:17