Skip to content

Commit 0bae1fd

Browse files
authored
[CIR] Infer MLIR context in type builders when possible (#1570)
Add `TypeBuilderWithInferredContext` to each CIR type that supports MLIR context inference from its parameters.
1 parent 4991421 commit 0bae1fd

File tree

13 files changed

+96
-108
lines changed

13 files changed

+96
-108
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
302302

303303
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
304304
mlir::Value imag) {
305-
auto resultComplexTy = cir::ComplexType::get(getContext(), real.getType());
305+
auto resultComplexTy = cir::ComplexType::get(real.getType());
306306
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
307307
}
308308

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+38
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def CIR_ComplexType : CIR_Type<"Complex", "complex",
228228

229229
let parameters = (ins "mlir::Type":$elementTy);
230230

231+
let builders = [
232+
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementTy), [{
233+
return $_get(elementTy.getContext(), elementTy);
234+
}]>,
235+
];
236+
231237
let assemblyFormat = [{
232238
`<` $elementTy `>`
233239
}];
@@ -301,6 +307,14 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
301307
let parameters = (ins "mlir::Type":$memberTy,
302308
"cir::RecordType":$clsTy);
303309

310+
let builders = [
311+
TypeBuilderWithInferredContext<(ins
312+
"mlir::Type":$memberTy, "cir::RecordType":$clsTy
313+
), [{
314+
return $_get(memberTy.getContext(), memberTy, clsTy);
315+
}]>,
316+
];
317+
304318
let assemblyFormat = [{
305319
`<` $memberTy `in` $clsTy `>`
306320
}];
@@ -338,6 +352,14 @@ def CIR_ArrayType : CIR_Type<"Array", "array",
338352

339353
let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size);
340354

355+
let builders = [
356+
TypeBuilderWithInferredContext<(ins
357+
"mlir::Type":$eltType, "uint64_t":$size
358+
), [{
359+
return $_get(eltType.getContext(), eltType, size);
360+
}]>,
361+
];
362+
341363
let assemblyFormat = [{
342364
`<` $eltType `x` $size `>`
343365
}];
@@ -358,6 +380,14 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
358380

359381
let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size);
360382

383+
let builders = [
384+
TypeBuilderWithInferredContext<(ins
385+
"mlir::Type":$eltType, "uint64_t":$size
386+
), [{
387+
return $_get(eltType.getContext(), eltType, size);
388+
}]>,
389+
];
390+
361391
let assemblyFormat = [{
362392
`<` $eltType `x` $size `>`
363393
}];
@@ -452,6 +482,14 @@ def CIR_MethodType : CIR_Type<"Method", "method",
452482
let parameters = (ins "cir::FuncType":$memberFuncTy,
453483
"cir::RecordType":$clsTy);
454484

485+
let builders = [
486+
TypeBuilderWithInferredContext<(ins
487+
"cir::FuncType":$memberFuncTy, "cir::RecordType":$clsTy
488+
), [{
489+
return $_get(memberFuncTy.getContext(), memberFuncTy, clsTy);
490+
}]>,
491+
];
492+
455493
let assemblyFormat = [{
456494
`<` qualified($memberFuncTy) `in` $clsTy `>`
457495
}];

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+4-9
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,16 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
168168
// If the string is full of null bytes, emit a #cir.zero rather than
169169
// a #cir.const_array.
170170
if (lastNonZeroPos == llvm::StringRef::npos) {
171-
auto arrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize);
171+
auto arrayTy = cir::ArrayType::get(eltTy, finalSize);
172172
return getZeroAttr(arrayTy);
173173
}
174174
// We will use trailing zeros only if there are more than one zero
175175
// at the end
176176
int trailingZerosNum =
177177
finalSize > lastNonZeroPos + 2 ? finalSize - lastNonZeroPos - 1 : 0;
178178
auto truncatedArrayTy =
179-
cir::ArrayType::get(getContext(), eltTy, finalSize - trailingZerosNum);
180-
auto fullArrayTy = cir::ArrayType::get(getContext(), eltTy, finalSize);
179+
cir::ArrayType::get(eltTy, finalSize - trailingZerosNum);
180+
auto fullArrayTy = cir::ArrayType::get(eltTy, finalSize);
181181
return cir::ConstArrayAttr::get(
182182
getContext(), fullArrayTy,
183183
mlir::StringAttr::get(str.drop_back(trailingZerosNum),
@@ -407,8 +407,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
407407
bool isSigned = false) {
408408
auto elementTy = mlir::dyn_cast_or_null<cir::IntType>(vt.getEltType());
409409
assert(elementTy && "expected int vector");
410-
return cir::VectorType::get(getContext(),
411-
isExtended
410+
return cir::VectorType::get(isExtended
412411
? getExtendedIntTy(elementTy, isSigned)
413412
: getTruncatedIntTy(elementTy, isSigned),
414413
vt.getSize());
@@ -530,10 +529,6 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
530529
return getCompleteRecordTy(members, name, packed, padded, ast);
531530
}
532531

533-
cir::ArrayType getArrayType(mlir::Type eltType, unsigned size) {
534-
return cir::ArrayType::get(getContext(), eltType, size);
535-
}
536-
537532
bool isSized(mlir::Type ty) {
538533
if (mlir::isa<cir::PointerType, cir::RecordType, cir::ArrayType,
539534
cir::BoolType, cir::IntType, cir::CIRFPTypeInterface>(ty))

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

+31-60
Original file line numberDiff line numberDiff line change
@@ -1863,14 +1863,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
18631863
switch (TypeFlags.getEltType()) {
18641864
case NeonTypeFlags::Int8:
18651865
case NeonTypeFlags::Poly8:
1866-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1867-
TypeFlags.isUnsigned() ? CGF->UInt8Ty
1866+
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt8Ty
18681867
: CGF->SInt8Ty,
18691868
V1Ty ? 1 : (8 << IsQuad));
18701869
case NeonTypeFlags::Int16:
18711870
case NeonTypeFlags::Poly16:
1872-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1873-
TypeFlags.isUnsigned() ? CGF->UInt16Ty
1871+
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt16Ty
18741872
: CGF->SInt16Ty,
18751873
V1Ty ? 1 : (4 << IsQuad));
18761874
case NeonTypeFlags::BFloat16:
@@ -1884,14 +1882,12 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
18841882
else
18851883
llvm_unreachable("NeonTypeFlags::Float16 NYI");
18861884
case NeonTypeFlags::Int32:
1887-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1888-
TypeFlags.isUnsigned() ? CGF->UInt32Ty
1885+
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt32Ty
18891886
: CGF->SInt32Ty,
18901887
V1Ty ? 1 : (2 << IsQuad));
18911888
case NeonTypeFlags::Int64:
18921889
case NeonTypeFlags::Poly64:
1893-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1894-
TypeFlags.isUnsigned() ? CGF->UInt64Ty
1890+
return cir::VectorType::get(TypeFlags.isUnsigned() ? CGF->UInt64Ty
18951891
: CGF->SInt64Ty,
18961892
V1Ty ? 1 : (1 << IsQuad));
18971893
case NeonTypeFlags::Poly128:
@@ -1900,12 +1896,10 @@ static cir::VectorType GetNeonType(CIRGenFunction *CGF, NeonTypeFlags TypeFlags,
19001896
// so we use v16i8 to represent poly128 and get pattern matched.
19011897
llvm_unreachable("NeonTypeFlags::Poly128 NYI");
19021898
case NeonTypeFlags::Float32:
1903-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1904-
CGF->getCIRGenModule().FloatTy,
1899+
return cir::VectorType::get(CGF->getCIRGenModule().FloatTy,
19051900
V1Ty ? 1 : (2 << IsQuad));
19061901
case NeonTypeFlags::Float64:
1907-
return cir::VectorType::get(CGF->getBuilder().getContext(),
1908-
CGF->getCIRGenModule().DoubleTy,
1902+
return cir::VectorType::get(CGF->getCIRGenModule().DoubleTy,
19091903
V1Ty ? 1 : (1 << IsQuad));
19101904
}
19111905
llvm_unreachable("Unknown vector element type!");
@@ -2102,7 +2096,7 @@ static cir::VectorType getSignChangedVectorType(CIRGenBuilderTy &builder,
21022096
auto elemTy = mlir::cast<cir::IntType>(vecTy.getEltType());
21032097
elemTy = elemTy.isSigned() ? builder.getUIntNTy(elemTy.getWidth())
21042098
: builder.getSIntNTy(elemTy.getWidth());
2105-
return cir::VectorType::get(builder.getContext(), elemTy, vecTy.getSize());
2099+
return cir::VectorType::get(elemTy, vecTy.getSize());
21062100
}
21072101

21082102
static cir::VectorType
@@ -2111,19 +2105,16 @@ getHalfEltSizeTwiceNumElemsVecType(CIRGenBuilderTy &builder,
21112105
auto elemTy = mlir::cast<cir::IntType>(vecTy.getEltType());
21122106
elemTy = elemTy.isSigned() ? builder.getSIntNTy(elemTy.getWidth() / 2)
21132107
: builder.getUIntNTy(elemTy.getWidth() / 2);
2114-
return cir::VectorType::get(builder.getContext(), elemTy,
2115-
vecTy.getSize() * 2);
2108+
return cir::VectorType::get(elemTy, vecTy.getSize() * 2);
21162109
}
21172110

21182111
static cir::VectorType
21192112
castVecOfFPTypeToVecOfIntWithSameWidth(CIRGenBuilderTy &builder,
21202113
cir::VectorType vecTy) {
21212114
if (mlir::isa<cir::SingleType>(vecTy.getEltType()))
2122-
return cir::VectorType::get(builder.getContext(), builder.getSInt32Ty(),
2123-
vecTy.getSize());
2115+
return cir::VectorType::get(builder.getSInt32Ty(), vecTy.getSize());
21242116
if (mlir::isa<cir::DoubleType>(vecTy.getEltType()))
2125-
return cir::VectorType::get(builder.getContext(), builder.getSInt64Ty(),
2126-
vecTy.getSize());
2117+
return cir::VectorType::get(builder.getSInt64Ty(), vecTy.getSize());
21272118
llvm_unreachable(
21282119
"Unsupported element type in getVecOfIntTypeWithSameEltWidth");
21292120
}
@@ -2315,8 +2306,7 @@ static mlir::Value emitCommonNeonVecAcrossCall(CIRGenFunction &cgf,
23152306
const clang::CallExpr *e) {
23162307
CIRGenBuilderTy &builder = cgf.getBuilder();
23172308
mlir::Value op = cgf.emitScalarExpr(e->getArg(0));
2318-
cir::VectorType vTy =
2319-
cir::VectorType::get(&cgf.getMLIRContext(), eltTy, vecLen);
2309+
cir::VectorType vTy = cir::VectorType::get(eltTy, vecLen);
23202310
llvm::SmallVector<mlir::Value, 1> args{op};
23212311
return emitNeonCall(builder, {vTy}, args, intrincsName, eltTy,
23222312
cgf.getLoc(e->getExprLoc()));
@@ -2447,8 +2437,7 @@ mlir::Value CIRGenFunction::emitCommonNeonBuiltinExpr(
24472437
cir::VectorType resTy =
24482438
(builtinID == NEON::BI__builtin_neon_vqdmulhq_lane_v ||
24492439
builtinID == NEON::BI__builtin_neon_vqrdmulhq_lane_v)
2450-
? cir::VectorType::get(&getMLIRContext(), vTy.getEltType(),
2451-
vTy.getSize() * 2)
2440+
? cir::VectorType::get(vTy.getEltType(), vTy.getSize() * 2)
24522441
: vTy;
24532442
cir::VectorType mulVecT =
24542443
GetNeonType(this, NeonTypeFlags(neonType.getEltType(), false,
@@ -2888,10 +2877,8 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
28882877
llvm_unreachable(" neon_vqmovnh_u16 NYI ");
28892878
case NEON::BI__builtin_neon_vqmovns_s32: {
28902879
mlir::Location loc = cgf.getLoc(expr->getExprLoc());
2891-
cir::VectorType argVecTy =
2892-
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt32Ty, 4);
2893-
cir::VectorType resVecTy =
2894-
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt16Ty, 4);
2880+
cir::VectorType argVecTy = cir::VectorType::get(cgf.SInt32Ty, 4);
2881+
cir::VectorType resVecTy = cir::VectorType::get(cgf.SInt16Ty, 4);
28952882
vecExtendIntValue(cgf, argVecTy, ops[0], loc);
28962883
mlir::Value result = emitNeonCall(builder, {argVecTy}, ops,
28972884
"aarch64.neon.sqxtn", resVecTy, loc);
@@ -3706,88 +3693,74 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
37063693

37073694
case NEON::BI__builtin_neon_vset_lane_f64: {
37083695
Ops.push_back(emitScalarExpr(E->getArg(2)));
3709-
Ops[1] = builder.createBitcast(
3710-
Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1));
3696+
Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 1));
37113697
return builder.create<cir::VecInsertOp>(getLoc(E->getExprLoc()), Ops[1],
37123698
Ops[0], Ops[2]);
37133699
}
37143700
case NEON::BI__builtin_neon_vsetq_lane_f64: {
37153701
Ops.push_back(emitScalarExpr(E->getArg(2)));
3716-
Ops[1] = builder.createBitcast(
3717-
Ops[1], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2));
3702+
Ops[1] = builder.createBitcast(Ops[1], cir::VectorType::get(DoubleTy, 2));
37183703
return builder.create<cir::VecInsertOp>(getLoc(E->getExprLoc()), Ops[1],
37193704
Ops[0], Ops[2]);
37203705
}
37213706
case NEON::BI__builtin_neon_vget_lane_i8:
37223707
case NEON::BI__builtin_neon_vdupb_lane_i8:
3723-
Ops[0] = builder.createBitcast(
3724-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 8));
3708+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 8));
37253709
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37263710
emitScalarExpr(E->getArg(1)));
37273711
case NEON::BI__builtin_neon_vgetq_lane_i8:
37283712
case NEON::BI__builtin_neon_vdupb_laneq_i8:
3729-
Ops[0] = builder.createBitcast(
3730-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt8Ty, 16));
3713+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt8Ty, 16));
37313714
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37323715
emitScalarExpr(E->getArg(1)));
37333716
case NEON::BI__builtin_neon_vget_lane_i16:
37343717
case NEON::BI__builtin_neon_vduph_lane_i16:
3735-
Ops[0] = builder.createBitcast(
3736-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 4));
3718+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 4));
37373719
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37383720
emitScalarExpr(E->getArg(1)));
37393721
case NEON::BI__builtin_neon_vgetq_lane_i16:
37403722
case NEON::BI__builtin_neon_vduph_laneq_i16:
3741-
Ops[0] = builder.createBitcast(
3742-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt16Ty, 8));
3723+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt16Ty, 8));
37433724
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37443725
emitScalarExpr(E->getArg(1)));
37453726
case NEON::BI__builtin_neon_vget_lane_i32:
37463727
case NEON::BI__builtin_neon_vdups_lane_i32:
3747-
Ops[0] = builder.createBitcast(
3748-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 2));
3728+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 2));
37493729
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37503730
emitScalarExpr(E->getArg(1)));
37513731
case NEON::BI__builtin_neon_vget_lane_f32:
37523732
case NEON::BI__builtin_neon_vdups_lane_f32:
3753-
Ops[0] = builder.createBitcast(
3754-
Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 2));
3733+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 2));
37553734
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37563735
emitScalarExpr(E->getArg(1)));
37573736
case NEON::BI__builtin_neon_vgetq_lane_i32:
37583737
case NEON::BI__builtin_neon_vdups_laneq_i32:
3759-
Ops[0] = builder.createBitcast(
3760-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt32Ty, 4));
3738+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt32Ty, 4));
37613739
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37623740
emitScalarExpr(E->getArg(1)));
37633741
case NEON::BI__builtin_neon_vget_lane_i64:
37643742
case NEON::BI__builtin_neon_vdupd_lane_i64:
3765-
Ops[0] = builder.createBitcast(
3766-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 1));
3743+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 1));
37673744
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37683745
emitScalarExpr(E->getArg(1)));
37693746
case NEON::BI__builtin_neon_vdupd_lane_f64:
37703747
case NEON::BI__builtin_neon_vget_lane_f64:
3771-
Ops[0] = builder.createBitcast(
3772-
Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 1));
3748+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 1));
37733749
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37743750
emitScalarExpr(E->getArg(1)));
37753751
case NEON::BI__builtin_neon_vgetq_lane_i64:
37763752
case NEON::BI__builtin_neon_vdupd_laneq_i64:
3777-
Ops[0] = builder.createBitcast(
3778-
Ops[0], cir::VectorType::get(&getMLIRContext(), UInt64Ty, 2));
3753+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(UInt64Ty, 2));
37793754
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37803755
emitScalarExpr(E->getArg(1)));
37813756
case NEON::BI__builtin_neon_vgetq_lane_f32:
37823757
case NEON::BI__builtin_neon_vdups_laneq_f32:
3783-
Ops[0] = builder.createBitcast(
3784-
Ops[0], cir::VectorType::get(&getMLIRContext(), FloatTy, 4));
3758+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(FloatTy, 4));
37853759
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37863760
emitScalarExpr(E->getArg(1)));
37873761
case NEON::BI__builtin_neon_vgetq_lane_f64:
37883762
case NEON::BI__builtin_neon_vdupd_laneq_f64:
3789-
Ops[0] = builder.createBitcast(
3790-
Ops[0], cir::VectorType::get(&getMLIRContext(), DoubleTy, 2));
3763+
Ops[0] = builder.createBitcast(Ops[0], cir::VectorType::get(DoubleTy, 2));
37913764
return builder.create<cir::VecExtractOp>(getLoc(E->getExprLoc()), Ops[0],
37923765
emitScalarExpr(E->getArg(1)));
37933766
case NEON::BI__builtin_neon_vaddh_f16: {
@@ -4318,7 +4291,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
43184291
[[fallthrough]];
43194292
case NEON::BI__builtin_neon_vaddv_s16: {
43204293
cir::IntType eltTy = usgn ? UInt16Ty : SInt16Ty;
4321-
cir::VectorType vTy = cir::VectorType::get(builder.getContext(), eltTy, 4);
4294+
cir::VectorType vTy = cir::VectorType::get(eltTy, 4);
43224295
Ops.push_back(emitScalarExpr(E->getArg(0)));
43234296
// This is to add across the vector elements, so wider result type needed.
43244297
Ops[0] = emitNeonCall(builder, {vTy}, Ops,
@@ -4427,8 +4400,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
44274400
usgn = true;
44284401
[[fallthrough]];
44294402
case NEON::BI__builtin_neon_vaddlvq_s16: {
4430-
mlir::Type argTy = cir::VectorType::get(builder.getContext(),
4431-
usgn ? UInt16Ty : SInt16Ty, 8);
4403+
mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 8);
44324404
llvm::SmallVector<mlir::Value, 1> argOps = {emitScalarExpr(E->getArg(0))};
44334405
return emitNeonCall(builder, {argTy}, argOps,
44344406
usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv",
@@ -4441,8 +4413,7 @@ CIRGenFunction::emitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
44414413
usgn = true;
44424414
[[fallthrough]];
44434415
case NEON::BI__builtin_neon_vaddlv_s16: {
4444-
mlir::Type argTy = cir::VectorType::get(builder.getContext(),
4445-
usgn ? UInt16Ty : SInt16Ty, 4);
4416+
mlir::Type argTy = cir::VectorType::get(usgn ? UInt16Ty : SInt16Ty, 4);
44464417
llvm::SmallVector<mlir::Value, 1> argOps = {emitScalarExpr(E->getArg(0))};
44474418
return emitNeonCall(builder, {argTy}, argOps,
44484419
usgn ? "aarch64.neon.uaddlv" : "aarch64.neon.saddlv",

clang/lib/CIR/CodeGen/CIRGenCUDANV.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ void CIRGenNVCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
129129
// we need to pass it as `void *args[2] = { &a, &b }`.
130130

131131
auto loc = fn.getLoc();
132-
auto voidPtrArrayTy =
133-
cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
132+
auto voidPtrArrayTy = cir::ArrayType::get(cgm.VoidPtrTy, args.size());
134133
mlir::Value kernelArgs = builder.createAlloca(
135134
loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
136135
CharUnits::fromQuantity(16));

0 commit comments

Comments
 (0)