Skip to content

Commit e50db7b

Browse files
authored
[CIR][CIRGen][Builtin][Neon] Lower neon_vshl_n_v and neon_vshlq_n_v (#965)
As title, but important step in this PR is to allow CIR ShiftOp to take vector of int type as input type. As result, I added a verifier to ShiftOp with 2 constraints 1. Input type either all vector or int type. This is consistent with LLVM::ShlOp, vector shift amount is expected. 2. In the spirit of C99 6.5.7.3, shift amount type must be the same as result type, the if vector type is used. (This is enforced in LLVM lowering for scalar int type).
1 parent 03154f8 commit e50db7b

File tree

10 files changed

+343
-118
lines changed

10 files changed

+343
-118
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,15 +1181,20 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
11811181
let summary = "Shift";
11821182
let description = [{
11831183
Shift `left` or `right`, according to the first operand. Second operand is
1184-
the shift target and the third the amount.
1184+
the shift target and the third the amount. Second and the thrid operand can
1185+
be either integer type or vector of integer type. However, they must be
1186+
either all vector of integer type, or all integer type. If they are vectors,
1187+
each vector element of the shift target is shifted by the corresponding
1188+
shift amount in the shift amount vector.
11851189

11861190
```mlir
11871191
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
1192+
%8 = cir.shift(left, %2 : !cir.vector<!s32i x 2>, %3 : !cir.vector<!s32i x 2>) -> !cir.vector<!s32i x 2>
11881193
```
11891194
}];
11901195

1191-
let results = (outs CIR_IntType:$result);
1192-
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
1196+
let results = (outs CIR_AnyIntOrVecOfInt:$result);
1197+
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
11931198
UnitAttr:$isShiftleft);
11941199

11951200
let assemblyFormat = [{
@@ -1200,8 +1205,7 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
12001205
`)` `->` type($result) attr-dict
12011206
}];
12021207

1203-
// Already covered by the traits
1204-
let hasVerifier = 0;
1208+
let hasVerifier = 1;
12051209
}
12061210

12071211
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,9 @@ def IntegerVector : Type<
537537
]>, "!cir.vector of !cir.int"> {
538538
}
539539

540+
// Constraints
541+
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
542+
540543
// Pointer to Arrays
541544
def ArrayPtr : Type<
542545
And<[

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,10 +2207,10 @@ static int64_t getIntValueFromConstOp(mlir::Value val) {
22072207
}
22082208

22092209
/// This function `buildCommonNeonCallPattern0` implements a common way
2210-
// to generate neon intrinsic call that has following pattern:
2211-
// 1. There is a need to cast result of the intrinsic call back to
2212-
// expression type.
2213-
// 2. Function arg types are given, not deduced from actual arg types.
2210+
/// to generate neon intrinsic call that has following pattern:
2211+
/// 1. There is a need to cast result of the intrinsic call back to
2212+
/// expression type.
2213+
/// 2. Function arg types are given, not deduced from actual arg types.
22142214
static mlir::Value
22152215
buildCommonNeonCallPattern0(CIRGenFunction &cgf, std::string &intrincsName,
22162216
llvm::SmallVector<mlir::Type> argTypes,
@@ -2224,6 +2224,23 @@ buildCommonNeonCallPattern0(CIRGenFunction &cgf, std::string &intrincsName,
22242224
return builder.createBitcast(res, resultType);
22252225
}
22262226

2227+
/// Build a constant shift amount vector of `vecTy` to shift a vector
2228+
/// Here `shitfVal` is a constant integer that will be splated into a
2229+
/// a const vector of `vecTy` which is the return of this function
2230+
static mlir::Value buildNeonShiftVector(CIRGenBuilderTy &builder,
2231+
mlir::Value shiftVal,
2232+
mlir::cir::VectorType vecTy,
2233+
mlir::Location loc, bool neg) {
2234+
int shiftAmt = getIntValueFromConstOp(shiftVal);
2235+
llvm::SmallVector<mlir::Attribute> vecAttr{
2236+
vecTy.getSize(),
2237+
// ConstVectorAttr requires cir::IntAttr
2238+
mlir::cir::IntAttr::get(vecTy.getEltType(), shiftAmt)};
2239+
mlir::cir::ConstVectorAttr constVecAttr = mlir::cir::ConstVectorAttr::get(
2240+
vecTy, mlir::ArrayAttr::get(builder.getContext(), vecAttr));
2241+
return builder.create<mlir::cir::ConstantOp>(loc, vecTy, constVecAttr);
2242+
}
2243+
22272244
mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
22282245
unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic,
22292246
const char *nameHint, unsigned modifier, const CallExpr *e,
@@ -2300,6 +2317,13 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
23002317
: "llvm.aarch64.neon.sqrdmulh.lane",
23012318
resTy, getLoc(e->getExprLoc()));
23022319
}
2320+
case NEON::BI__builtin_neon_vshl_n_v:
2321+
case NEON::BI__builtin_neon_vshlq_n_v: {
2322+
mlir::Location loc = getLoc(e->getExprLoc());
2323+
ops[1] = buildNeonShiftVector(builder, ops[1], vTy, loc, false);
2324+
return builder.create<mlir::cir::ShiftOp>(
2325+
loc, vTy, builder.createBitcast(ops[0], vTy), ops[1], true);
2326+
}
23032327
}
23042328

23052329
// This second switch is for the intrinsics that might have a more generic

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3939,6 +3939,23 @@ LogicalResult BinOp::verify() {
39393939
return mlir::success();
39403940
}
39413941

3942+
//===----------------------------------------------------------------------===//
3943+
// ShiftOp Definitions
3944+
//===----------------------------------------------------------------------===//
3945+
LogicalResult ShiftOp::verify() {
3946+
mlir::Operation *op = getOperation();
3947+
mlir::Type resType = getResult().getType();
3948+
bool isOp0Vec = mlir::isa<mlir::cir::VectorType>(op->getOperand(0).getType());
3949+
bool isOp1Vec = mlir::isa<mlir::cir::VectorType>(op->getOperand(1).getType());
3950+
if (isOp0Vec != isOp1Vec)
3951+
return emitOpError() << "input types cannot be one vector and one scalar";
3952+
if (isOp1Vec && op->getOperand(1).getType() != resType) {
3953+
return emitOpError() << "shift amount must have the type of the result "
3954+
<< "if it is vector shift";
3955+
}
3956+
return mlir::success();
3957+
}
3958+
39423959
//===----------------------------------------------------------------------===//
39433960
// LabelOp Definitions
39443961
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2773,24 +2773,40 @@ class CIRShiftOpLowering
27732773
auto cirAmtTy =
27742774
mlir::dyn_cast<mlir::cir::IntType>(op.getAmount().getType());
27752775
auto cirValTy = mlir::dyn_cast<mlir::cir::IntType>(op.getValue().getType());
2776+
2777+
// Operands could also be vector type
2778+
auto cirAmtVTy =
2779+
mlir::dyn_cast<mlir::cir::VectorType>(op.getAmount().getType());
2780+
auto cirValVTy =
2781+
mlir::dyn_cast<mlir::cir::VectorType>(op.getValue().getType());
27762782
auto llvmTy = getTypeConverter()->convertType(op.getType());
27772783
mlir::Value amt = adaptor.getAmount();
27782784
mlir::Value val = adaptor.getValue();
27792785

2780-
assert(cirValTy && cirAmtTy && "non-integer shift is NYI");
2781-
assert(cirValTy == op.getType() && "inconsistent operands' types NYI");
2786+
assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
2787+
"shift input type must be integer or vector type, otherwise NYI");
2788+
2789+
assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
2790+
"inconsistent operands' types NYI");
27822791

27832792
// Ensure shift amount is the same type as the value. Some undefined
27842793
// behavior might occur in the casts below as per [C99 6.5.7.3].
2785-
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
2786-
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
2787-
cirValTy.getWidth());
2794+
// Vector type shift amount needs no cast as type consistency is expected to
2795+
// be already be enforced at CIRGen.
2796+
if (cirAmtTy)
2797+
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
2798+
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
2799+
cirValTy.getWidth());
27882800

27892801
// Lower to the proper LLVM shift operation.
27902802
if (op.getIsShiftleft())
27912803
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
27922804
else {
2793-
if (cirValTy.isUnsigned())
2805+
bool isUnSigned =
2806+
cirValTy ? !cirValTy.isSigned()
2807+
: !mlir::cast<mlir::cir::IntType>(cirValVTy.getEltType())
2808+
.isSigned();
2809+
if (isUnSigned)
27942810
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
27952811
else
27962812
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);

0 commit comments

Comments
 (0)