Skip to content

Commit 99c3011

Browse files
Lancernxlauko
authored andcommitted
[CIR][CIRGen] Add CIRGen support for float16 and bfloat (llvm#571)
This PR adds two new CIR floating-point types, namely `!cir.f16` and `!cir.bf16`, to represent the float16 format and bfloat format, respectively. This PR converts the clang extension type `_Float16` to `!cir.f16`, and converts the clang extension type `__bf16` type to `!cir.bf16`. The type conversion for clang extension type `__fp16` is not included in this PR since it requires additional work during CIRGen. Only CIRGen is implemented here, LLVMIR lowering / MLIR lowering should come next.
1 parent 80012e1 commit 99c3011

File tree

8 files changed

+86
-14
lines changed

8 files changed

+86
-14
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+16-2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,20 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
150150
}];
151151
}
152152

153+
def CIR_FP16 : CIR_FloatType<"FP16", "f16"> {
154+
let summary = "CIR type that represents IEEE-754 binary16 format";
155+
let description = [{
156+
Floating-point type that represents the IEEE-754 binary16 format.
157+
}];
158+
}
159+
160+
def CIR_BFloat16 : CIR_FloatType<"BF16", "bf16"> {
161+
let summary = "CIR type that represents";
162+
let description = [{
163+
Floating-point type that represents the bfloat16 format.
164+
}];
165+
}
166+
153167
def CIR_FP80 : CIR_FloatType<"FP80", "f80"> {
154168
let summary = "CIR type that represents x87 80-bit floating-point format";
155169
let description = [{
@@ -179,7 +193,7 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
179193

180194
// Constraints
181195

182-
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_LongDouble]>;
196+
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
183197
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
184198

185199
//===----------------------------------------------------------------------===//
@@ -475,7 +489,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
475489
def CIR_AnyType : AnyTypeOf<[
476490
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
477491
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
478-
CIR_AnyFloat,
492+
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
479493
]>;
480494

481495
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+4
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
250250
return mlir::cir::FPAttr::getZero(fltType);
251251
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
252252
return mlir::cir::FPAttr::getZero(fltType);
253+
if (auto fltType = ty.dyn_cast<mlir::cir::FP16Type>())
254+
return mlir::cir::FPAttr::getZero(fltType);
255+
if (auto fltType = ty.dyn_cast<mlir::cir::BF16Type>())
256+
return mlir::cir::FPAttr::getZero(fltType);
253257
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
254258
return getZeroAttr(arrTy);
255259
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
132132
/// Emit a value that corresponds to null for the given type.
133133
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);
134134

135+
mlir::Value buildPromotedValue(mlir::Value result, QualType PromotionType) {
136+
return Builder.createFloatingCast(result, ConvertType(PromotionType));
137+
}
138+
139+
mlir::Value buildUnPromotedValue(mlir::Value result, QualType ExprType) {
140+
return Builder.createFloatingCast(result, ConvertType(ExprType));
141+
}
142+
143+
mlir::Value buildPromoted(const Expr *E, QualType PromotionType);
144+
135145
//===--------------------------------------------------------------------===//
136146
// Visitor Methods
137147
//===--------------------------------------------------------------------===//
@@ -896,8 +906,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
896906
if (auto *CT = Ty->getAs<ComplexType>()) {
897907
llvm_unreachable("NYI");
898908
}
899-
if (Ty.UseExcessPrecision(CGF.getContext()))
900-
llvm_unreachable("NYI");
909+
if (Ty.UseExcessPrecision(CGF.getContext())) {
910+
if (auto *VT = Ty->getAs<VectorType>())
911+
llvm_unreachable("NYI");
912+
return CGF.getContext().FloatTy;
913+
}
901914
return QualType();
902915
}
903916

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,11 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
131131
// Initialize CIR pointer types cache.
132132
VoidPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), VoidTy);
133133

134-
// TODO: HalfTy
135-
// TODO: BFloatTy
134+
FP16Ty = ::mlir::cir::FP16Type::get(builder.getContext());
135+
BFloat16Ty = ::mlir::cir::BF16Type::get(builder.getContext());
136136
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
137137
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
138138
FP80Ty = ::mlir::cir::FP80Type::get(builder.getContext());
139-
// TODO(cir): perhaps we should abstract long double variations into a custom
140-
// cir.long_double type. Said type would also hold the semantics for lowering.
141139

142140
// TODO: PointerWidthInBits
143141
PointerAlignInBytes =

clang/lib/CIR/CodeGen/CIRGenTypeCache.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ struct CIRGenTypeCache {
3434
mlir::cir::IntType SInt8Ty, SInt16Ty, SInt32Ty, SInt64Ty;
3535
// usigned char, unsigned, unsigned short, unsigned long
3636
mlir::cir::IntType UInt8Ty, UInt16Ty, UInt32Ty, UInt64Ty;
37-
/// half, bfloat, float, double
38-
// mlir::Type HalfTy, BFloatTy;
39-
// TODO(cir): perhaps we should abstract long double variations into a custom
40-
// cir.long_double type. Said type would also hold the semantics for lowering.
37+
/// half, bfloat, float, double, fp80
38+
mlir::cir::FP16Type FP16Ty;
39+
mlir::cir::BF16Type BFloat16Ty;
4140
mlir::cir::SingleType FloatTy;
4241
mlir::cir::DoubleType DoubleTy;
4342
mlir::cir::FP80Type FP80Ty;

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,14 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
464464
break;
465465

466466
case BuiltinType::Float16:
467-
ResultType = Builder.getF16Type();
467+
ResultType = CGM.FP16Ty;
468468
break;
469469
case BuiltinType::Half:
470470
// Should be the same as above?
471471
assert(0 && "not implemented");
472472
break;
473473
case BuiltinType::BFloat16:
474-
ResultType = Builder.getBF16Type();
474+
ResultType = CGM.BFloat16Ty;
475475
break;
476476
case BuiltinType::Float:
477477
ResultType = CGM.FloatTy;

clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h

+2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ struct UnimplementedFeature {
130130
static bool shouldEmitLifetimeMarkers() { return false; }
131131
static bool peepholeProtection() { return false; }
132132
static bool CGCapturedStmtInfo() { return false; }
133+
static bool CGFPOptionsRAII() { return false; }
134+
static bool getFPFeaturesInEffect() { return false; }
133135
static bool cxxABI() { return false; }
134136
static bool openCL() { return false; }
135137
static bool CUDA() { return false; }

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,48 @@ DoubleType::getPreferredAlignment(const ::mlir::DataLayout &dataLayout,
691691
return (uint64_t)(getWidth() / 8);
692692
}
693693

694+
const llvm::fltSemantics &FP16Type::getFloatSemantics() const {
695+
return llvm::APFloat::IEEEhalf();
696+
}
697+
698+
llvm::TypeSize
699+
FP16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
700+
mlir::DataLayoutEntryListRef params) const {
701+
return llvm::TypeSize::getFixed(getWidth());
702+
}
703+
704+
uint64_t FP16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
705+
mlir::DataLayoutEntryListRef params) const {
706+
return (uint64_t)(getWidth() / 8);
707+
}
708+
709+
uint64_t
710+
FP16Type::getPreferredAlignment(const ::mlir::DataLayout &dataLayout,
711+
::mlir::DataLayoutEntryListRef params) const {
712+
return (uint64_t)(getWidth() / 8);
713+
}
714+
715+
const llvm::fltSemantics &BF16Type::getFloatSemantics() const {
716+
return llvm::APFloat::BFloat();
717+
}
718+
719+
llvm::TypeSize
720+
BF16Type::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
721+
mlir::DataLayoutEntryListRef params) const {
722+
return llvm::TypeSize::getFixed(getWidth());
723+
}
724+
725+
uint64_t BF16Type::getABIAlignment(const mlir::DataLayout &dataLayout,
726+
mlir::DataLayoutEntryListRef params) const {
727+
return (uint64_t)(getWidth() / 8);
728+
}
729+
730+
uint64_t
731+
BF16Type::getPreferredAlignment(const ::mlir::DataLayout &dataLayout,
732+
::mlir::DataLayoutEntryListRef params) const {
733+
return (uint64_t)(getWidth() / 8);
734+
}
735+
694736
const llvm::fltSemantics &FP80Type::getFloatSemantics() const {
695737
return llvm::APFloat::x87DoubleExtended();
696738
}

0 commit comments

Comments
 (0)