Skip to content

Commit 0a600c3

Browse files
authored
[mlir][nvgpu] Make phaseParity of mbarrier.try_wait i1 (#81460)
Currently, the `phaseParity` argument of `nvgpu.mbarrier.try_wait.parity` has type `index`. This can cause a problem if any value other than 0 or 1 is passed, because the PTX instruction accepts only an even or odd phase. This PR changes the `phaseParity` argument to `i1` to avoid such misuse. Here is the relevant information from the PTX documentation: ``` The .parity variant of the instructions test for the completion of the phase indicated by the operand phaseParity, which is the integer parity of either the current phase or the immediately preceding phase of the mbarrier object. An even phase has integer parity 0 and an odd phase has integer parity of 1. So the valid values of phaseParity operand are 0 and 1. ``` See for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
1 parent 05ad0d4 commit 0a600c3

File tree

10 files changed

+25
-16
lines changed

10 files changed

+25
-16
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,14 +609,16 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
609609
phase. Suspended thread resumes execution when the specified phase completes
610610
OR before the phase completes following a system-dependent time limit.
611611

612+
The `$phaseParity` specifies either even phase (0) or odd phase (1) to
613+
wait.
614+
612615
Example:
613616
```mlir
614-
nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
617+
nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
615618
```
616-
617619
}];
618-
let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$phase, Index:$ticks, Index:$mbarId);
619-
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";
620+
let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
621+
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";
620622
}
621623

622624
def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,8 @@ struct NVGPUMBarrierTryWaitParityLowering
956956
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
957957
adaptor.getMbarId(), rewriter);
958958
Value ticks = truncToI32(b, adaptor.getTicks());
959-
Value phase = truncToI32(b, adaptor.getPhase());
959+
Value phase =
960+
b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
960961

961962
if (isMbarrierShared(op.getBarriers().getType())) {
962963
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,8 @@ void HopperBuilder::buildBarrierArriveTx(
10101010

10111011
void HopperBuilder::buildTryWaitParity(
10121012
TypedValue<nvgpu::MBarrierGroupType> barrier) {
1013-
Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1013+
Type i1 = rewriter.getI1Type();
1014+
Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
10141015
// 10M is an arbitrary, not too small or too big number to specify the number
10151016
// of ticks before retry.
10161017
// TODO: hoist this in a default dialect constant.

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -590,12 +590,12 @@ func.func @mbarrier_txcount() {
590590
}
591591

592592

593-
%phase = arith.constant 0 : index
593+
%phase_c0 = arith.constant 0 : i1
594594
%ticks = arith.constant 10000000 : index
595595
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
596596
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
597597
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
598-
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
598+
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
599599

600600
func.return
601601
}
@@ -626,12 +626,12 @@ func.func @mbarrier_txcount_pred() {
626626
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
627627
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
628628

629-
%phase = arith.constant 0 : index
629+
%phase_c0 = arith.constant 0 : i1
630630
%ticks = arith.constant 10000000 : index
631631
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
632632
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
633633
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
634-
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
634+
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
635635

636636
func.return
637637
}

mlir/test/Dialect/NVGPU/tmaload-transform.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func.func @main() {
6262
// CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
6363
// CHECK: }
6464
//
65-
// CHECK: %[[c0_6:.*]] = arith.constant 0 : index
65+
// CHECK: %[[c0_6:.*]] = llvm.mlir.constant(false) : i1
6666
// CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
6767
// CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
6868

mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ func.func @main() {
197197
{
198198
%ticks = arith.constant 10000000 : index
199199
// TMA wait
200-
nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
200+
%phase_c0 = arith.constant 0 : i1
201+
nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
201202
%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
202203
%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
203204
// Descriptor WGMMA

mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ func.func @main() {
206206
{
207207
%ticks = arith.constant 10000000 : index
208208
// TMA wait
209-
nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
209+
%phase_c0 = arith.constant 0 : i1
210+
nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
210211
%lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
211212
%rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
212213
// Descriptor WGMMA

mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ module @mymod {
9393
}
9494

9595
// Step 8. Wait until TMA is done
96-
nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
96+
%phase_c0 = arith.constant 0 : i1
97+
nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : !barrierType
9798

9899
// Step 9. Print loaded data in 128b swizzled
99100
scf.if %10 {

mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ module @mymod {
119119
}
120120

121121
// Step 7. Wait until TMA is done
122-
nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
122+
%phase_c0 = arith.constant 0 : i1
123+
nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : !barrierType
123124

124125
// Step 8. Print loaded data in 128b swizzled
125126
scf.if %10 {

mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ module @mymod {
9696
} else {
9797
nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : <memorySpace = #gpu.address_space<workgroup>>
9898
}
99-
nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
99+
%phase_c0 = arith.constant 0 : i1
100+
nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
100101
scf.if %10 {
101102
%11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
102103
%12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>

0 commit comments

Comments
 (0)