Skip to content

Commit b82fdb0

Browse files
Lancernlanza
authored andcommitted
[CIR] introduce CIR floating-point types (#385)
This PR adds a dedicated `cir.float` type for representing floating-point types. There are several issues linked to this PR: #5, #78, and #90.
1 parent 72eccae commit b82fdb0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+937
-617
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,33 @@ def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
215215
let hasCustomAssemblyFormat = 1;
216216
}
217217

218+
//===----------------------------------------------------------------------===//
219+
// FPAttr
220+
//===----------------------------------------------------------------------===//
221+
222+
def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
223+
let summary = "An attribute containing a floating-point value";
224+
let description = [{
225+
An fp attribute is a literal attribute that represents a floating-point
226+
value of the specified floating-point type.
227+
}];
228+
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
229+
let builders = [
230+
AttrBuilderWithInferredContext<(ins "Type":$type,
231+
"const APFloat &":$value), [{
232+
return $_get(type.getContext(), type, value);
233+
}]>,
234+
];
235+
let extraClassDeclaration = [{
236+
static FPAttr getZero(mlir::Type type);
237+
}];
238+
let genVerifyDecl = 1;
239+
240+
let assemblyFormat = [{
241+
`<` custom<FloatLiteral>($value, ref($type)) `>`
242+
}];
243+
}
244+
218245
//===----------------------------------------------------------------------===//
219246
// ConstPointerAttr
220247
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2615,8 +2615,8 @@ def IterEndOp : CIR_Op<"iterator_end"> {
26152615

26162616
class UnaryFPToFPBuiltinOp<string mnemonic>
26172617
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
2618-
let arguments = (ins AnyFloat:$src);
2619-
let results = (outs AnyFloat:$result);
2618+
let arguments = (ins CIR_AnyFloat:$src);
2619+
let results = (outs CIR_AnyFloat:$result);
26202620
let summary = "libc builtin equivalent ignoring "
26212621
"floating point exceptions and errno";
26222622
let assemblyFormat = "$src `:` type($src) attr-dict";

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/Types.h"
1818
#include "mlir/Interfaces/DataLayoutInterfaces.h"
19+
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"
1920

2021
#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
2122

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,18 @@
1515

1616
include "clang/CIR/Dialect/IR/CIRDialect.td"
1717
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
18+
include "clang/CIR/Interfaces/CIRFPTypeInterface.td"
1819
include "mlir/Interfaces/DataLayoutInterfaces.td"
1920
include "mlir/IR/AttrTypeBase.td"
21+
include "mlir/IR/EnumAttr.td"
2022

2123
//===----------------------------------------------------------------------===//
2224
// CIR Types
2325
//===----------------------------------------------------------------------===//
2426

25-
class CIR_Type<string name, string typeMnemonic, list<Trait> traits = []> :
26-
TypeDef<CIR_Dialect, name, traits> {
27+
class CIR_Type<string name, string typeMnemonic, list<Trait> traits = [],
28+
string baseCppClass = "::mlir::Type">
29+
: TypeDef<CIR_Dialect, name, traits, baseCppClass> {
2730
let mnemonic = typeMnemonic;
2831
}
2932

@@ -94,6 +97,37 @@ def SInt16 : SInt<16>;
9497
def SInt32 : SInt<32>;
9598
def SInt64 : SInt<64>;
9699

100+
//===----------------------------------------------------------------------===//
101+
// FloatType
102+
//===----------------------------------------------------------------------===//
103+
104+
class CIR_FloatType<string name, string mnemonic>
105+
: CIR_Type<name, mnemonic,
106+
[
107+
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
108+
DeclareTypeInterfaceMethods<CIRFPTypeInterface>,
109+
]> {}
110+
111+
def CIR_Single : CIR_FloatType<"Single", "float"> {
112+
let summary = "CIR single-precision float type";
113+
let description = [{
114+
Floating-point type that represents the `float` type in C/C++. Its
115+
underlying floating-point format is the IEEE-754 binary32 format.
116+
}];
117+
}
118+
119+
def CIR_Double : CIR_FloatType<"Double", "double"> {
120+
let summary = "CIR double-precision float type";
121+
let description = [{
122+
Floating-point type that represents the `double` type in C/C++. Its
123+
underlying floating-point format is the IEEE-754 binar64 format.
124+
}];
125+
}
126+
127+
// Constraints
128+
129+
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;
130+
97131
//===----------------------------------------------------------------------===//
98132
// PointerType
99133
//===----------------------------------------------------------------------===//
@@ -318,7 +352,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
318352

319353
def CIR_AnyType : AnyTypeOf<[
320354
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
321-
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, AnyFloat,
355+
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, CIR_AnyFloat,
322356
]>;
323357

324358
#endif // MLIR_CIR_DIALECT_CIR_TYPES
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- CIRFPTypeInterface.h - Interface for CIR FP types -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//
9+
// Defines the interface to generically handle CIR floating-point types.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
14+
#define CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
15+
16+
#include "mlir/IR/Types.h"
17+
#include "llvm/ADT/APFloat.h"
18+
19+
/// Include the tablegen'd interface declarations.
20+
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h.inc"
21+
22+
#endif // CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- CIRFPTypeInterface.td - CIR FP Interface Definitions -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
10+
#define MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def CIRFPTypeInterface : TypeInterface<"CIRFPTypeInterface"> {
15+
let description = [{
16+
Contains helper functions to query properties about a floating-point type.
17+
}];
18+
let cppNamespace = "::mlir::cir";
19+
20+
let methods = [
21+
InterfaceMethod<[{
22+
Returns the bit width of this floating-point type.
23+
}],
24+
/*retTy=*/"unsigned",
25+
/*methodName=*/"getWidth",
26+
/*args=*/(ins),
27+
/*methodBody=*/"",
28+
/*defaultImplementation=*/[{
29+
return llvm::APFloat::semanticsSizeInBits($_type.getFloatSemantics());
30+
}]
31+
>,
32+
InterfaceMethod<[{
33+
Return the mantissa width.
34+
}],
35+
/*retTy=*/"unsigned",
36+
/*methodName=*/"getFPMantissaWidth",
37+
/*args=*/(ins),
38+
/*methodBody=*/"",
39+
/*defaultImplementation=*/[{
40+
return llvm::APFloat::semanticsPrecision($_type.getFloatSemantics());
41+
}]
42+
>,
43+
InterfaceMethod<[{
44+
Return the float semantics of this floating-point type.
45+
}],
46+
/*retTy=*/"const llvm::fltSemantics &",
47+
/*methodName=*/"getFloatSemantics"
48+
>,
49+
];
50+
}
51+
52+
#endif // MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE

clang/include/clang/CIR/Interfaces/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ function(add_clang_mlir_op_interface interface)
2020
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
2121
endfunction()
2222

23+
function(add_clang_mlir_type_interface interface)
24+
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
25+
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
26+
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
27+
add_public_tablegen_target(MLIR${interface}IncGen)
28+
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
29+
endfunction()
30+
2331
add_clang_mlir_attr_interface(ASTAttrInterfaces)
2432
add_clang_mlir_op_interface(CIROpInterfaces)
2533
add_clang_mlir_op_interface(CIRLoopOpInterface)
34+
add_clang_mlir_type_interface(CIRFPTypeInterface)

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
224224
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
225225
if (ty.isa<mlir::cir::IntType>())
226226
return mlir::cir::IntAttr::get(ty, 0);
227-
if (ty.isa<mlir::FloatType>())
228-
return mlir::FloatAttr::get(ty, 0.0);
227+
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
228+
return mlir::cir::FPAttr::getZero(fltType);
229+
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
230+
return mlir::cir::FPAttr::getZero(fltType);
229231
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
230232
return getZeroAttr(arrTy);
231233
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
@@ -256,12 +258,13 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
256258
if (const auto boolVal = attr.dyn_cast<mlir::cir::BoolAttr>())
257259
return !boolVal.getValue();
258260

259-
if (const auto fpVal = attr.dyn_cast<mlir::FloatAttr>()) {
261+
if (auto fpAttr = attr.dyn_cast<mlir::cir::FPAttr>()) {
262+
auto fpVal = fpAttr.getValue();
260263
bool ignored;
261264
llvm::APFloat FV(+0.0);
262-
FV.convert(fpVal.getValue().getSemantics(),
263-
llvm::APFloat::rmNearestTiesToEven, &ignored);
264-
return FV.bitwiseIsEqual(fpVal.getValue());
265+
FV.convert(fpVal.getSemantics(), llvm::APFloat::rmNearestTiesToEven,
266+
&ignored);
267+
return FV.bitwiseIsEqual(fpVal);
265268
}
266269

267270
if (const auto structVal = attr.dyn_cast<mlir::cir::ConstStructAttr>()) {
@@ -348,23 +351,21 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
348351
}
349352
bool isInt(mlir::Type i) { return i.isa<mlir::cir::IntType>(); }
350353

351-
mlir::FloatType getLongDouble80BitsTy() const {
352-
return typeCache.LongDouble80BitsTy;
353-
}
354+
mlir::Type getLongDouble80BitsTy() const { llvm_unreachable("NYI"); }
354355

355356
/// Get the proper floating point type for the given semantics.
356-
mlir::FloatType getFloatTyForFormat(const llvm::fltSemantics &format,
357-
bool useNativeHalf) const {
357+
mlir::Type getFloatTyForFormat(const llvm::fltSemantics &format,
358+
bool useNativeHalf) const {
358359
if (&format == &llvm::APFloat::IEEEhalf()) {
359360
llvm_unreachable("IEEEhalf float format is NYI");
360361
}
361362

362363
if (&format == &llvm::APFloat::BFloat())
363364
llvm_unreachable("BFloat float format is NYI");
364365
if (&format == &llvm::APFloat::IEEEsingle())
365-
llvm_unreachable("IEEEsingle float format is NYI");
366+
return typeCache.FloatTy;
366367
if (&format == &llvm::APFloat::IEEEdouble())
367-
llvm_unreachable("IEEEdouble float format is NYI");
368+
return typeCache.DoubleTy;
368369
if (&format == &llvm::APFloat::IEEEquad())
369370
llvm_unreachable("IEEEquad float format is NYI");
370371
if (&format == &llvm::APFloat::PPCDoubleDouble())
@@ -491,9 +492,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
491492
}
492493

493494
bool isSized(mlir::Type ty) {
494-
if (ty.isIntOrFloat() ||
495-
ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
496-
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType>())
495+
if (ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
496+
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
497+
mlir::cir::CIRFPTypeInterface>())
497498
return true;
498499
assert(0 && "Unimplemented size for type");
499500
return false;

clang/lib/CIR/CodeGen/CIRGenExprConst.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,9 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
17081708
assert(0 && "not implemented");
17091709
else {
17101710
mlir::Type ty = CGM.getCIRType(DestType);
1711-
return builder.getFloatAttr(ty, Init);
1711+
assert(ty.isa<mlir::cir::CIRFPTypeInterface>() &&
1712+
"expected floating-point type");
1713+
return CGM.getBuilder().getAttr<mlir::cir::FPAttr>(ty, Init);
17121714
}
17131715
}
17141716
case APValue::Array: {

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
165165
}
166166
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
167167
mlir::Type Ty = CGF.getCIRType(E->getType());
168+
assert(Ty.isa<mlir::cir::CIRFPTypeInterface>() &&
169+
"expect floating-point type");
168170
return Builder.create<mlir::cir::ConstantOp>(
169171
CGF.getLoc(E->getExprLoc()), Ty,
170-
Builder.getFloatAttr(Ty, E->getValue()));
172+
Builder.getAttr<mlir::cir::FPAttr>(Ty, E->getValue()));
171173
}
172174
mlir::Value VisitCharacterLiteral(const CharacterLiteral *E) {
173175
mlir::Type Ty = CGF.getCIRType(E->getType());
@@ -1227,7 +1229,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
12271229
llvm_unreachable("NYI");
12281230

12291231
assert(!UnimplementedFeature::cirVectorType());
1230-
if (Ops.LHS.getType().isa<mlir::FloatType>()) {
1232+
if (Ops.LHS.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
12311233
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
12321234
return Builder.createFSub(Ops.LHS, Ops.RHS);
12331235
}
@@ -1701,20 +1703,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
17011703
llvm_unreachable("NYI: signed bool");
17021704
if (CGF.getBuilder().isInt(DstTy)) {
17031705
CastKind = mlir::cir::CastKind::bool_to_int;
1704-
} else if (DstTy.isa<mlir::FloatType>()) {
1706+
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
17051707
CastKind = mlir::cir::CastKind::bool_to_float;
17061708
} else {
17071709
llvm_unreachable("Internal error: Cast to unexpected type");
17081710
}
17091711
} else if (CGF.getBuilder().isInt(SrcTy)) {
17101712
if (CGF.getBuilder().isInt(DstTy)) {
17111713
CastKind = mlir::cir::CastKind::integral;
1712-
} else if (DstTy.isa<mlir::FloatType>()) {
1714+
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
17131715
CastKind = mlir::cir::CastKind::int_to_float;
17141716
} else {
17151717
llvm_unreachable("Internal error: Cast to unexpected type");
17161718
}
1717-
} else if (SrcTy.isa<mlir::FloatType>()) {
1719+
} else if (SrcTy.isa<mlir::cir::CIRFPTypeInterface>()) {
17181720
if (CGF.getBuilder().isInt(DstTy)) {
17191721
// If we can't recognize overflow as undefined behavior, assume that
17201722
// overflow saturates. This protects against normal optimizations if we
@@ -1724,7 +1726,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
17241726
if (Builder.getIsFPConstrained())
17251727
llvm_unreachable("NYI");
17261728
CastKind = mlir::cir::CastKind::float_to_int;
1727-
} else if (DstTy.isa<mlir::FloatType>()) {
1729+
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
17281730
// TODO: split this to createFPExt/createFPTrunc
17291731
return Builder.createFloatingCast(Src, DstTy);
17301732
} else {

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,10 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
133133

134134
// TODO: HalfTy
135135
// TODO: BFloatTy
136-
FloatTy = builder.getF32Type();
137-
DoubleTy = builder.getF64Type();
136+
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
137+
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
138138
// TODO(cir): perhaps we should abstract long double variations into a custom
139139
// cir.long_double type. Said type would also hold the semantics for lowering.
140-
LongDouble80BitsTy = builder.getF80Type();
141140

142141
// TODO: PointerWidthInBits
143142
PointerAlignInBytes =

clang/lib/CIR/CodeGen/CIRGenTypeCache.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ struct CIRGenTypeCache {
3737
// mlir::Type HalfTy, BFloatTy;
3838
// TODO(cir): perhaps we should abstract long double variations into a custom
3939
// cir.long_double type. Said type would also hold the semantics for lowering.
40-
mlir::FloatType FloatTy, DoubleTy, LongDouble80BitsTy;
40+
mlir::cir::SingleType FloatTy;
41+
mlir::cir::DoubleType DoubleTy;
4142

4243
/// int
4344
mlir::Type UIntTy;

0 commit comments

Comments
 (0)