Skip to content

Commit c8c5fbd

Browse files
committed
[CIR] Add support for __fp16 type
1 parent 2a11e98 commit c8c5fbd

File tree

8 files changed

+145
-10
lines changed

8 files changed

+145
-10
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+35
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,41 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
462462
return createCast(mlir::cir::CastKind::int_to_ptr, src, newTy);
463463
}
464464

465+
mlir::Value createCastFromFP16(mlir::Location loc, mlir::Value src,
466+
mlir::Type destTy) {
467+
assert((mlir::isa<mlir::cir::SingleType, mlir::cir::DoubleType>(destTy)) &&
468+
"dest type must be either !cir.float or !cir.double");
469+
assert(mlir::isa<mlir::cir::StorageOnlyFP16Type>(src.getType()) &&
470+
"src must be of !cir.fp16.storage type");
471+
472+
std::string intrinName = "llvm.convert.from.fp16";
473+
if (mlir::isa<mlir::cir::SingleType>(destTy))
474+
intrinName.append(".f32");
475+
else
476+
intrinName.append(".f64");
477+
478+
auto intrinsicCallOp = create<mlir::cir::IntrinsicCallOp>(
479+
loc, getStringAttr(intrinName), destTy, src);
480+
return intrinsicCallOp.getResult();
481+
}
482+
483+
mlir::Value createCastToFP16(mlir::Location loc, mlir::Value src) {
484+
auto srcTy = src.getType();
485+
assert((mlir::isa<mlir::cir::SingleType, mlir::cir::DoubleType>(srcTy)) &&
486+
"src type must be either float or double");
487+
488+
std::string intrinName = "llvm.convert.to.fp16";
489+
if (mlir::isa<mlir::cir::SingleType>(srcTy))
490+
intrinName.append(".f32");
491+
else
492+
intrinName.append(".f64");
493+
494+
auto destTy = mlir::cir::StorageOnlyFP16Type::get(src.getContext());
495+
auto intrinsicCallOp = create<mlir::cir::IntrinsicCallOp>(
496+
loc, getStringAttr(intrinName), destTy, src);
497+
return intrinsicCallOp.getResult();
498+
}
499+
465500
mlir::Value createGetMemberOp(mlir::Location &loc, mlir::Value structPtr,
466501
const char *fldName, unsigned idx) {
467502

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+17
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,23 @@ def CIR_FP16 : CIR_FloatType<"FP16", "f16"> {
157157
}];
158158
}
159159

160+
def CIR_StorageOnlyFP16 : CIR_FloatType<"StorageOnlyFP16", "f16.storage"> {
161+
let summary = "CIR type that represents a storage-only fp16 type";
162+
let description = [{
163+
Floating-point type that represents a storage-only fp16 type.
164+
165+
Unlike `!cir.f16`, all the usual arithmetic operations are not defined for
166+
`!cir.f16.storage`. Values of `!cir.f16.storage` type must be promoted to
167+
a single- or double-precision floating point value before performing any
168+
arithmetic operations.
169+
170+
Additionally, `!cir.f16.storage` is lowered to the `i16` LLVM type. The
171+
promotion and un-promotion of `!cir.f16.storage` values are lowered to
172+
calls to `llvm.convert.from.f16` and `llvm.convert.to.f16` LLVM intrinsic
173+
functions.
174+
}];
175+
}
176+
160177
def CIR_BFloat16 : CIR_FloatType<"BF16", "bf16"> {
161178
let summary = "CIR type that represents";
162179
let description = [{

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+59-8
Original file line numberDiff line numberDiff line change
@@ -502,16 +502,26 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
502502
// TODO(cir): CGFPOptionsRAII
503503
assert(!MissingFeatures::CGFPOptionsRAII());
504504

505-
if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
506-
llvm_unreachable("__fp16 type NYI");
505+
if (type->isHalfType() &&
506+
!CGF.getContext().getLangOpts().NativeHalfType) {
507+
// Another special case: half FP increment should be done via float
508+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics()) {
509+
value = Builder.createCastFromFP16(CGF.getLoc(E->getExprLoc()), input,
510+
CGF.CGM.FloatTy);
511+
} else {
512+
value = Builder.createCast(CGF.getLoc(E->getExprLoc()),
513+
mlir::cir::CastKind::floating, input,
514+
CGF.CGM.FloatTy);
515+
}
516+
}
507517

508518
if (mlir::isa<mlir::cir::SingleType, mlir::cir::DoubleType>(
509519
value.getType())) {
510520
// Create the inc/dec operation.
511521
// NOTE(CIR): clang calls CreateAdd but folds this to a unary op
512522
auto kind =
513523
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
514-
value = buildUnaryOp(E, kind, input);
524+
value = buildUnaryOp(E, kind, value);
515525
} else {
516526
// Remaining types are Half, Bfloat16, LongDouble, __ibm128 or
517527
// __float128. Convert from float.
@@ -537,8 +547,16 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
537547
value = Builder.createBinop(value, mlir::cir::BinOpKind::Add, amt);
538548
}
539549

540-
if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
541-
llvm_unreachable("NYI");
550+
if (type->isHalfType() &&
551+
!CGF.getContext().getLangOpts().NativeHalfType) {
552+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics()) {
553+
value = Builder.createCastToFP16(CGF.getLoc(E->getExprLoc()), value);
554+
} else {
555+
value = Builder.createCast(CGF.getLoc(E->getExprLoc()),
556+
mlir::cir::CastKind::floating, value,
557+
input.getType());
558+
}
559+
}
542560

543561
} else if (type->isFixedPointType()) {
544562
llvm_unreachable("no fixed point inc/dec yet");
@@ -1043,7 +1061,23 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
10431061
// Cast from half through float if half isn't a native type.
10441062
if (SrcType->isHalfType() &&
10451063
!CGF.getContext().getLangOpts().NativeHalfType) {
1046-
llvm_unreachable("not implemented");
1064+
// Cast to FP using the intrinsic if the half type itself isn't supported.
1065+
if (mlir::isa<mlir::cir::CIRFPTypeInterface>(DstTy)) {
1066+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics())
1067+
return Builder.createCastFromFP16(CGF.getLoc(Loc), Src, DstTy);
1068+
} else {
1069+
// Cast to other types through float, using either the intrinsic or
1070+
// FPExt, depending on whether the half type itself is supported (as
1071+
// opposed to operations on half, available with NativeHalfType).
1072+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics()) {
1073+
Src = Builder.createCastFromFP16(CGF.getLoc(Loc), Src, DstTy);
1074+
} else {
1075+
Src = Builder.createCast(
1076+
CGF.getLoc(Loc), mlir::cir::CastKind::floating, Src, CGF.FloatTy);
1077+
}
1078+
SrcType = CGF.getContext().FloatTy;
1079+
SrcTy = CGF.FloatTy;
1080+
}
10471081
}
10481082

10491083
// TODO(cir): LLVM codegen ignore conversions like int -> uint,
@@ -1098,13 +1132,30 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
10981132
// Cast to half through float if half isn't a native type.
10991133
if (DstType->isHalfType() &&
11001134
!CGF.getContext().getLangOpts().NativeHalfType) {
1101-
llvm_unreachable("NYI");
1135+
// Make sure we cast in a single step if from another FP type.
1136+
if (mlir::isa<mlir::cir::CIRFPTypeInterface>(SrcTy)) {
1137+
// Use the intrinsic if the half type itself isn't supported
1138+
// (as opposed to operations on half, available with NativeHalfType).
1139+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics())
1140+
return Builder.createCastToFP16(CGF.getLoc(Loc), Src);
1141+
// If the half type is supported, just use an fptrunc.
1142+
return Builder.createCast(CGF.getLoc(Loc),
1143+
mlir::cir::CastKind::floating, Src, DstTy);
1144+
}
1145+
DstTy = CGF.FloatTy;
11021146
}
11031147

11041148
Res = buildScalarCast(Src, SrcType, DstType, SrcTy, DstTy, Opts);
11051149

11061150
if (DstTy != ResTy) {
1107-
llvm_unreachable("NYI");
1151+
if (CGF.getContext().getTargetInfo().useFP16ConversionIntrinsics()) {
1152+
assert(mlir::cast<mlir::cir::StorageOnlyFP16Type>(ResTy) &&
1153+
"only storage-only fp16 requires extra conversion");
1154+
Res = Builder.createCastToFP16(CGF.getLoc(Loc), Res);
1155+
} else {
1156+
Res = Builder.createCast(CGF.getLoc(Loc), mlir::cir::CastKind::floating,
1157+
Res, ResTy);
1158+
}
11081159
}
11091160

11101161
if (Opts.EmitImplicitIntegerTruncationChecks)

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
137137
VoidPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), VoidTy);
138138

139139
FP16Ty = ::mlir::cir::FP16Type::get(builder.getContext());
140+
StorageOnlyFP16Ty =
141+
::mlir::cir::StorageOnlyFP16Type::get(builder.getContext());
140142
BFloat16Ty = ::mlir::cir::BF16Type::get(builder.getContext());
141143
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
142144
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());

clang/lib/CIR/CodeGen/CIRGenTypeCache.h

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct CIRGenTypeCache {
3737
mlir::cir::IntType UInt8Ty, UInt16Ty, UInt32Ty, UInt64Ty;
3838
/// half, bfloat, float, double, fp80
3939
mlir::cir::FP16Type FP16Ty;
40+
mlir::cir::StorageOnlyFP16Type StorageOnlyFP16Ty;
4041
mlir::cir::BF16Type BFloat16Ty;
4142
mlir::cir::SingleType FloatTy;
4243
mlir::cir::DoubleType DoubleTy;

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,11 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
469469
ResultType = CGM.FP16Ty;
470470
break;
471471
case BuiltinType::Half:
472-
// Should be the same as above?
473-
assert(0 && "not implemented");
472+
if (Context.getLangOpts().NativeHalfType ||
473+
!Context.getTargetInfo().useFP16ConversionIntrinsics())
474+
ResultType = CGM.FP16Ty;
475+
else
476+
ResultType = CGM.StorageOnlyFP16Ty;
474477
break;
475478
case BuiltinType::BFloat16:
476479
ResultType = CGM.BFloat16Ty;

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,28 @@ FP16Type::getPreferredAlignment(const ::mlir::DataLayout &dataLayout,
719719
return (uint64_t)(getWidth() / 8);
720720
}
721721

722+
const llvm::fltSemantics &StorageOnlyFP16Type::getFloatSemantics() const {
723+
return llvm::APFloat::IEEEhalf();
724+
}
725+
726+
llvm::TypeSize StorageOnlyFP16Type::getTypeSizeInBits(
727+
const mlir::DataLayout &dataLayout,
728+
mlir::DataLayoutEntryListRef params) const {
729+
return llvm::TypeSize::getFixed(getWidth());
730+
}
731+
732+
uint64_t StorageOnlyFP16Type::getABIAlignment(
733+
const mlir::DataLayout &dataLayout,
734+
mlir::DataLayoutEntryListRef params) const {
735+
return (uint64_t)(getWidth() / 8);
736+
}
737+
738+
uint64_t StorageOnlyFP16Type::getPreferredAlignment(
739+
const ::mlir::DataLayout &dataLayout,
740+
::mlir::DataLayoutEntryListRef params) const {
741+
return (uint64_t)(getWidth() / 8);
742+
}
743+
722744
const llvm::fltSemantics &BF16Type::getFloatSemantics() const {
723745
return llvm::APFloat::BFloat();
724746
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -3951,6 +3951,10 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
39513951
converter.addConversion([&](mlir::cir::FP16Type type) -> mlir::Type {
39523952
return mlir::FloatType::getF16(type.getContext());
39533953
});
3954+
converter.addConversion(
3955+
[&](mlir::cir::StorageOnlyFP16Type type) -> mlir::Type {
3956+
return mlir::IntegerType::get(type.getContext(), 16);
3957+
});
39543958
converter.addConversion([&](mlir::cir::BF16Type type) -> mlir::Type {
39553959
return mlir::FloatType::getBF16(type.getContext());
39563960
});

0 commit comments

Comments
 (0)