Skip to content

[CIR][CIRGen] Add complex type and its CIRGen support #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,35 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return getPointerTo(::mlir::cir::VoidType::get(getContext()), langAS);
}

mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
}

mlir::TypedAttr getZeroAttr(mlir::Type t) {
return mlir::cir::ZeroAttr::get(getContext(), t);
}

mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (mlir::isa<mlir::cir::IntType>(ty))
return mlir::cir::IntAttr::get(ty, 0);
if (auto fltType = mlir::dyn_cast<mlir::cir::SingleType>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
return getZeroAttr(complexType);
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
return getZeroAttr(arrTy);
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
return getConstPtrAttr(ptrTy, 0);
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
return getZeroAttr(structTy);
if (mlir::isa<mlir::cir::BoolType>(ty)) {
return getCIRBoolAttr(false);
}
llvm_unreachable("Zero initializer for given type is NYI");
}

mlir::Value createLoad(mlir::Location loc, mlir::Value ptr,
bool isVolatile = false, uint64_t alignment = 0) {
mlir::IntegerAttr intAttr;
Expand Down
84 changes: 84 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,90 @@ def BinOpOverflowOp : CIR_Op<"binop.overflow", [Pure, SameTypeOperands]> {
];
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//

def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let summary = "Create a complex value from its real and imaginary parts";
let description = [{
`cir.complex.create` operation takes two operands that represent the real
and imaginary part of a complex number, and yields the complex number.

Example:

```mlir
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
%2 = cir.complex.create %0, %1 : !cir.complex<!cir.double>
```
}];

let results = (outs CIR_ComplexType:$result);
let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag);

let assemblyFormat = [{
$real `,` $imag
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ComplexRealPtrOp and ComplexImagPtrOp
//===----------------------------------------------------------------------===//

def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
let summary = "Extract the real part of a complex value";
let description = [{
`cir.complex.real_ptr` operation takes a pointer operand that points to a
complex value of type `!cir.complex` and yields a pointer to the real part
of the operand.

Example:

```mlir
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
```
}];

let results = (outs PrimitiveIntOrFPPtr:$result);
let arguments = (ins ComplexPtr:$operand);

let assemblyFormat = [{
$operand `:`
qualified(type($operand)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

def ComplexImagPtrOp : CIR_Op<"complex.imag_ptr", [Pure]> {
let summary = "Extract the imaginary part of a complex value";
let description = [{
`cir.complex.imag_ptr` operation takes a pointer operand that points to a
complex value of type `!cir.complex` and yields a pointer to the imaginary
part of the operand.

Example:

```mlir
%1 = cir.complex.imag_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
```
}];

let results = (outs PrimitiveIntOrFPPtr:$result);
let arguments = (ins ComplexPtr:$operand);

let assemblyFormat = [{
$operand `:`
qualified(type($operand)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// BitsOp
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 35 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,32 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

def CIR_ComplexType : CIR_Type<"Complex", "complex",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR complex type";
let description = [{
CIR type that represents a C complex number. `cir.complex` models the C type
`T _Complex`.

The parameter `elementTy` gives the type of the real and imaginary part of
the complex number. `elementTy` must be either a CIR integer type or a CIR
floating-point type.
}];

let parameters = (ins "mlir::Type":$elementTy);

let assemblyFormat = [{
`<` $elementTy `>`
}];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -441,6 +467,14 @@ def PrimitiveIntOrFPPtr : Type<
]>, "{int,void}*"> {
}

def ComplexPtr : Type<
And<[
CPred<"::mlir::isa<::mlir::cir::PointerType>($_self)">,
CPred<"::mlir::isa<::mlir::cir::ComplexType>("
"::mlir::cast<::mlir::cir::PointerType>($_self).getPointee())">,
]>, "!cir.complex*"> {
}

// Pointer to struct
def StructPtr : Type<
And<[
Expand Down Expand Up @@ -516,7 +550,7 @@ def CIR_StructType : Type<CPred<"::mlir::isa<::mlir::cir::StructType>($_self)">,
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
CIR_AnyFloat, CIR_FP16, CIR_BFloat16, CIR_ComplexType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
55 changes: 44 additions & 11 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
}

mlir::TypedAttr getZeroAttr(mlir::Type t) {
return mlir::cir::ZeroAttr::get(getContext(), t);
}

mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
}

mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
return getConstPtrAttr(t, 0);
Expand Down Expand Up @@ -265,6 +257,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
return mlir::cir::FPAttr::getZero(fltType);
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
return getZeroAttr(complexType);
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
return getZeroAttr(arrTy);
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
Expand Down Expand Up @@ -764,6 +758,42 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::GetMemberOp>(loc, result, base, name, index);
}

mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
mlir::Value imag) {
auto resultComplexTy =
mlir::cir::ComplexType::get(getContext(), real.getType());
return create<mlir::cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}

/// Create a cir.complex.real_ptr operation that derives a pointer to the real
/// part of the complex value pointed to by the specified pointer value.
mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
auto srcPtrTy = mlir::cast<mlir::cir::PointerType>(value.getType());
auto srcComplexTy =
mlir::cast<mlir::cir::ComplexType>(srcPtrTy.getPointee());
return create<mlir::cir::ComplexRealPtrOp>(
loc, getPointerTo(srcComplexTy.getElementTy()), value);
}

Address createRealPtr(mlir::Location loc, Address addr) {
return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
}

/// Create a cir.complex.imag_ptr operation that derives a pointer to the
/// imaginary part of the complex value pointed to by the specified pointer
/// value.
mlir::Value createImagPtr(mlir::Location loc, mlir::Value value) {
auto srcPtrTy = mlir::cast<mlir::cir::PointerType>(value.getType());
auto srcComplexTy =
mlir::cast<mlir::cir::ComplexType>(srcPtrTy.getPointee());
return create<mlir::cir::ComplexImagPtrOp>(
loc, getPointerTo(srcComplexTy.getElementTy()), value);
}

Address createImagPtr(mlir::Location loc, Address addr) {
return Address{createImagPtr(loc, addr.getPointer()), addr.getAlignment()};
}

/// Cast the element type of the given address to a different type,
/// preserving information like the alignment.
cir::Address createElementBitCast(mlir::Location loc, cir::Address addr,
Expand All @@ -776,15 +806,18 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
addr.getAlignment());
}

mlir::Value createLoad(mlir::Location loc, Address addr) {
mlir::Value createLoad(mlir::Location loc, Address addr,
bool isVolatile = false) {
auto ptrTy =
mlir::dyn_cast<mlir::cir::PointerType>(addr.getPointer().getType());
if (addr.getElementType() != ptrTy.getPointee())
addr = addr.withPointer(
createPtrBitcast(addr.getPointer(), addr.getElementType()));

return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
addr.getPointer());
return create<mlir::cir::LoadOp>(
loc, addr.getElementType(), addr.getPointer(), /*isDeref=*/false,
/*is_volatile=*/isVolatile, /*alignment=*/mlir::IntegerAttr{},
/*mem_order=*/mlir::cir::MemOrderAttr{});
}

mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,
Expand Down
6 changes: 5 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,11 @@ void CIRGenFunction::buildExprAsInit(const Expr *init, const ValueDecl *D,
buildScalarInit(init, getLoc(D->getSourceRange()), lvalue);
return;
case TEK_Complex: {
assert(0 && "not implemented");
mlir::Value complex = buildComplexExpr(init);
if (capturedByInit)
llvm_unreachable("NYI");
buildStoreOfComplex(getLoc(init->getExprLoc()), complex, lvalue,
/*init*/ true);
return;
}
case TEK_Aggregate:
Expand Down
29 changes: 26 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
}

case TEK_Complex:
assert(0 && "not implemented");
return buildComplexAssignmentLValue(E);
case TEK_Aggregate:
assert(0 && "not implemented");
}
Expand Down Expand Up @@ -1260,6 +1260,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
if (E->getOpcode() == UO_Extension)
return buildLValue(E->getSubExpr());

QualType ExprTy = getContext().getCanonicalType(E->getSubExpr()->getType());
switch (E->getOpcode()) {
default:
llvm_unreachable("Unknown unary operator lvalue!");
Expand All @@ -1284,7 +1285,29 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
}
case UO_Real:
case UO_Imag: {
assert(0 && "not implemented");
LValue LV = buildLValue(E->getSubExpr());
assert(LV.isSimple() && "real/imag on non-ordinary l-value");

// __real is valid on scalars. This is a faster way of testing that.
// __imag can only produce an rvalue on scalars.
if (E->getOpcode() == UO_Real &&
!mlir::isa<mlir::cir::ComplexType>(LV.getAddress().getElementType())) {
assert(E->getSubExpr()->getType()->isArithmeticType());
return LV;
}

QualType T = ExprTy->castAs<clang::ComplexType>()->getElementType();

auto Loc = getLoc(E->getExprLoc());
Address Component =
(E->getOpcode() == UO_Real
? buildAddrOfRealComponent(Loc, LV.getAddress(), LV.getType())
: buildAddrOfImagComponent(Loc, LV.getAddress(), LV.getType()));
// TODO(cir): TBAA info.
assert(!MissingFeatures::tbaa());
LValue ElemLV = makeAddrLValue(Component, T, LV.getBaseInfo());
ElemLV.getQuals().addQualifiers(LV.getQuals());
return ElemLV;
}
case UO_PreInc:
case UO_PreDec: {
Expand All @@ -1311,7 +1334,7 @@ RValue CIRGenFunction::buildAnyExpr(const Expr *E, AggValueSlot aggSlot,
case TEK_Scalar:
return RValue::get(buildScalarExpr(E));
case TEK_Complex:
assert(0 && "not implemented");
return RValue::getComplex(buildComplexExpr(E));
case TEK_Aggregate: {
if (!ignoreResult && aggSlot.isIgnored())
aggSlot = CreateAggTemp(E->getType(), getLoc(E->getSourceRange()),
Expand Down
Loading
Loading