Skip to content

Commit 24dfcc0

Browse files
authored
[flang][cuda] Use the nvvm.vote.sync op for all and any (#134433)
The NVVM `nvvm.vote.sync` operation is now available for the `all` and `any` kinds as well as `ballot`. Use the op for all of them and clean up the generation function so that it handles all three vote sync kinds.
1 parent 428fc2c commit 24dfcc0

File tree

3 files changed

+19
-46
lines changed

3 files changed

+19
-46
lines changed

Diff for: flang/include/flang/Optimizer/Builder/IntrinsicCall.h

+3 -3
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Runtime/iostat-consts.h"
2020
#include "mlir/Dialect/Complex/IR/Complex.h"
2121
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2223
#include "mlir/Dialect/Math/IR/Math.h"
2324
#include <optional>
2425

@@ -450,9 +451,8 @@ struct IntrinsicLibrary {
450451
llvm::ArrayRef<fir::ExtendedValue> args);
451452
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
452453
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
453-
mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
454-
mlir::Value genVoteAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
455-
mlir::Value genVoteBallotSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
454+
template <mlir::NVVM::VoteSyncKind kind>
455+
mlir::Value genVoteSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
456456

457457
/// Implement all conversion functions like DBLE, the first argument is
458458
/// the value to convert. There may be an additional KIND arguments that

Diff for: flang/lib/Optimizer/Builder/IntrinsicCall.cpp

+14 -41
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,6 @@
4848
#include "mlir/Dialect/Complex/IR/Complex.h"
4949
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5050
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
51-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
5251
#include "mlir/Dialect/Math/IR/Math.h"
5352
#include "mlir/Dialect/Vector/IR/VectorOps.h"
5453
#include "llvm/Support/CommandLine.h"
@@ -262,7 +261,7 @@ static constexpr IntrinsicHandler handlers[]{
262261
{{{"mask", asAddr}, {"dim", asValue}}},
263262
/*isElemental=*/false},
264263
{"all_sync",
265-
&I::genVoteAllSync,
264+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::all>,
266265
{{{"mask", asValue}, {"pred", asValue}}},
267266
/*isElemental=*/false},
268267
{"allocated",
@@ -275,7 +274,7 @@ static constexpr IntrinsicHandler handlers[]{
275274
{{{"mask", asAddr}, {"dim", asValue}}},
276275
/*isElemental=*/false},
277276
{"any_sync",
278-
&I::genVoteAnySync,
277+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::any>,
279278
{{{"mask", asValue}, {"pred", asValue}}},
280279
/*isElemental=*/false},
281280
{"asind", &I::genAsind},
@@ -341,7 +340,7 @@ static constexpr IntrinsicHandler handlers[]{
341340
{"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
342341
{"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
343342
{"ballot_sync",
344-
&I::genVoteBallotSync,
343+
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
345344
{{{"mask", asValue}, {"pred", asValue}}},
346345
/*isElemental=*/false},
347346
{"bessel_jn",
@@ -6583,46 +6582,20 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
65836582
return value;
65846583
}
65856584

6586-
static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
6587-
llvm::StringRef funcName, mlir::Type resTy,
6588-
llvm::ArrayRef<mlir::Value> args) {
6589-
mlir::MLIRContext *context = builder.getContext();
6590-
mlir::Type i32Ty = builder.getI32Type();
6591-
mlir::Type i1Ty = builder.getI1Type();
6592-
mlir::FunctionType ftype =
6593-
mlir::FunctionType::get(context, {i32Ty, i1Ty}, {resTy});
6594-
auto funcOp = builder.createFunction(loc, funcName, ftype);
6595-
llvm::SmallVector<mlir::Value> filteredArgs;
6596-
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
6597-
}
6598-
6599-
// ALL_SYNC
6600-
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
6601-
llvm::ArrayRef<mlir::Value> args) {
6602-
assert(args.size() == 2);
6603-
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync",
6604-
builder.getI1Type(), args);
6605-
}
6606-
6607-
// ANY_SYNC
6608-
mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
6609-
llvm::ArrayRef<mlir::Value> args) {
6610-
assert(args.size() == 2);
6611-
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync",
6612-
builder.getI1Type(), args);
6613-
}
6614-
6615-
// BALLOT_SYNC
6616-
mlir::Value
6617-
IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
6618-
llvm::ArrayRef<mlir::Value> args) {
6585+
// ALL_SYNC, ANY_SYNC, BALLOT_SYNC
6586+
template <mlir::NVVM::VoteSyncKind kind>
6587+
mlir::Value IntrinsicLibrary::genVoteSync(mlir::Type resultType,
6588+
llvm::ArrayRef<mlir::Value> args) {
66196589
assert(args.size() == 2);
66206590
mlir::Value arg1 =
66216591
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
6622-
return builder
6623-
.create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
6624-
mlir::NVVM::VoteSyncKind::ballot)
6625-
.getResult();
6592+
mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot
6593+
? builder.getI32Type()
6594+
: builder.getI1Type();
6595+
auto voteRes =
6596+
builder.create<mlir::NVVM::VoteSyncOp>(loc, resTy, args[0], arg1, kind)
6597+
.getResult();
6598+
return builder.create<fir::ConvertOp>(loc, resultType, voteRes);
66266599
}
66276600

66286601
// MATCH_ANY_SYNC

Diff for: flang/test/Lower/CUDA/cuda-device-proc.cuf

+2 -2
Original file line number | Diff line number | Diff line change
@@ -301,8 +301,8 @@ attributes(device) subroutine testVote()
301301
end subroutine
302302

303303
! CHECK-LABEL: func.func @_QPtestvote()
304-
! CHECK: fir.call @llvm.nvvm.vote.all.sync
305-
! CHECK: fir.call @llvm.nvvm.vote.any.sync
304+
! CHECK: %{{.*}} = nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
305+
! CHECK: %{{.*}} = nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
306306
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
307307

308308
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

0 commit comments

Comments
 (0)