diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
index 3412c7764e54..9f30541dc3f4 100644
--- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
+++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
@@ -620,6 +620,95 @@ class CIRYieldOpLowering
   }
 };
 
+class CIRGlobalOpLowering
+    : public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  mlir::LogicalResult
+  matchAndRewrite(mlir::cir::GlobalOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
+    if (!moduleOp)
+      return mlir::failure();
+
+    mlir::OpBuilder b(moduleOp.getContext());
+
+    const auto CIRSymType = op.getSymType();
+    auto convertedType = getTypeConverter()->convertType(CIRSymType);
+    if (!convertedType)
+      return mlir::failure();
+    auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
+    if (!memrefType)
+      memrefType = mlir::MemRefType::get({}, convertedType);
+    // Add an optional alignment to the global memref.
+    mlir::IntegerAttr memrefAlignment =
+        op.getAlignment()
+            ? mlir::IntegerAttr::get(b.getI64Type(), op.getAlignment().value())
+            : mlir::IntegerAttr();
+    // Add an optional initial value to the global memref.
+    mlir::Attribute initialValue = mlir::Attribute();
+    std::optional<mlir::Attribute> init = op.getInitialValue();
+    if (init.has_value()) {
+      if (auto constArr = init.value().dyn_cast<mlir::cir::ConstArrayAttr>()) {
+        if (memrefType.getShape().size()) {
+          auto rtt = mlir::RankedTensorType::get(memrefType.getShape(),
+                                                 memrefType.getElementType());
+          initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
+        } else {
+          auto rtt = mlir::RankedTensorType::get({}, convertedType);
+          initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
+        }
+      } else if (auto intAttr = init.value().dyn_cast<mlir::cir::IntAttr>()) {
+        auto rtt = mlir::RankedTensorType::get({}, convertedType);
+        initialValue = mlir::DenseIntElementsAttr::get(rtt, intAttr.getValue());
+      } else if (auto fltAttr = init.value().dyn_cast<mlir::cir::FPAttr>()) {
+        auto rtt = mlir::RankedTensorType::get({}, convertedType);
+        initialValue = mlir::DenseFPElementsAttr::get(rtt, fltAttr.getValue());
+      } else if (auto boolAttr = init.value().dyn_cast<mlir::cir::BoolAttr>()) {
+        auto rtt = mlir::RankedTensorType::get({}, convertedType);
+        initialValue =
+            mlir::DenseIntElementsAttr::get(rtt, (char)boolAttr.getValue());
+      } else
+        llvm_unreachable(
+            "GlobalOp lowering with initial value is not fully supported yet");
+    }
+
+    // Add symbol visibility
+    std::string sym_visibility = op.isPrivate() ? "private" : "public";
+
+    rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
+        op, b.getStringAttr(op.getSymName()),
+        /*sym_visibility=*/b.getStringAttr(sym_visibility),
+        /*type=*/memrefType, initialValue,
+        /*constant=*/op.getConstant(),
+        /*alignment=*/memrefAlignment);
+
+    return mlir::success();
+  }
+};
+
+class CIRGetGlobalOpLowering
+    : public mlir::OpConversionPattern<mlir::cir::GetGlobalOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::cir::GetGlobalOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    // FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
+    // CIRGen should mitigate this and not emit the get_global.
+    if (op->getUses().empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    auto type = getTypeConverter()->convertType(op.getType());
+    auto symbol = op.getName();
+    rewriter.replaceOpWithNewOp<mlir::memref::GetGlobalOp>(op, type, symbol);
+    return mlir::success();
+  }
+};
+
 void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                          mlir::TypeConverter &converter) {
   patterns.add<CIRReturnLowering>(patterns.getContext());
@@ -628,8 +717,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
                CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
                CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
-               CIRYieldOpLowering, CIRCosOpLowering>(converter,
-                                                     patterns.getContext());
+               CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
+               CIRGetGlobalOpLowering>(converter, patterns.getContext());
 }
 
 static mlir::TypeConverter prepareTypeConverter() {
@@ -639,6 +728,8 @@ static mlir::TypeConverter prepareTypeConverter() {
     // FIXME: The pointee type might not be converted (e.g. struct)
     if (!ty)
      return nullptr;
+    if (isa<mlir::cir::ArrayType>(type.getPointee()))
+      return ty;
     return mlir::MemRefType::get({}, ty);
   });
   converter.addConversion(
@@ -669,8 +760,17 @@ static mlir::TypeConverter prepareTypeConverter() {
     return converter.convertType(type.getUnderlying());
   });
   converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type {
-    auto elementType = converter.convertType(type.getEltType());
-    return mlir::MemRefType::get(type.getSize(), elementType);
+    SmallVector<int64_t> shape;
+    mlir::Type curType = type;
+    while (auto arrayType = dyn_cast<mlir::cir::ArrayType>(curType)) {
+      shape.push_back(arrayType.getSize());
+      curType = arrayType.getEltType();
+    }
+    auto elementType = converter.convertType(curType);
+    // FIXME: The element type might not be converted (e.g. struct)
+    if (!elementType)
+      return nullptr;
+    return mlir::MemRefType::get(shape, elementType);
   });
 
   return converter;
diff --git a/clang/test/CIR/Lowering/ThroughMLIR/global.cir b/clang/test/CIR/Lowering/ThroughMLIR/global.cir
new file mode 100644
index 000000000000..3b1ed83239c6
--- /dev/null
+++ b/clang/test/CIR/Lowering/ThroughMLIR/global.cir
@@ -0,0 +1,55 @@
+// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
+// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
+
+!u32i = !cir.int<u, 32>
+module {
+  cir.global external @i = #cir.int<2> : !u32i
+  cir.global external @f = #cir.fp<3.000000e+00> : !cir.float
+  cir.global external @b = #cir.bool<true> : !cir.bool
+  cir.global "private" external @a : !cir.array<!u32i x 100>
+  cir.global external @aa = #cir.zero : !cir.array<!cir.array<!u32i x 256> x 256>
+
+  cir.func @get_global_int_value() -> !u32i {
+    %0 = cir.get_global @i : cir.ptr <!u32i>
+    %1 = cir.load %0 : cir.ptr <!u32i>, !u32i
+    cir.return %1 : !u32i
+  }
+  cir.func @get_global_float_value() -> !cir.float {
+    %0 = cir.get_global @f : cir.ptr <!cir.float>
+    %1 = cir.load %0 : cir.ptr <!cir.float>, !cir.float
+    cir.return %1 : !cir.float
+  }
+  cir.func @get_global_bool_value() -> !cir.bool {
+    %0 = cir.get_global @b : cir.ptr <!cir.bool>
+    %1 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
+    cir.return %1 : !cir.bool
+  }
+  cir.func @get_global_array_pointer() -> !cir.ptr<!cir.array<!u32i x 100>> {
+    %0 = cir.get_global @a : cir.ptr <!cir.array<!u32i x 100>>
+    cir.return %0 : !cir.ptr<!cir.array<!u32i x 100>>
+  }
+  cir.func @get_global_multi_array_pointer() -> !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>> {
+    %0 = cir.get_global @aa : cir.ptr <!cir.array<!cir.array<!u32i x 256> x 256>>
+    cir.return %0 : !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>>
+  }
+}
+
+// MLIR: memref.global "public" @i : memref<i32> = dense<2>
+// MLIR: memref.global "public" @f : memref<f32> = dense<3.000000e+00>
+// MLIR: memref.global "public" @b : memref<i8> = dense<1>
+// MLIR: memref.global "private" @a : memref<100xi32>
+// MLIR: memref.global "public" @aa : memref<256x256xi32> = dense<0>
+// MLIR: memref.get_global @i : memref<i32>
+// MLIR: memref.get_global @f : memref<f32>
+// MLIR: memref.get_global @b : memref<i8>
+// MLIR: memref.get_global @a : memref<100xi32>
+// MLIR: memref.get_global @aa : memref<256x256xi32>
+
+// LLVM: @i = global i32 2
+// LLVM: @f = global float 3.000000e+00
+// LLVM: @b = global i8 1
+// LLVM: @a = private global [100 x i32] undef
+// LLVM: @aa = global [256 x [256 x i32]] zeroinitializer
+// LLVM: load i32, ptr @i
+// LLVM: load float, ptr @f
+// LLVM: load i8, ptr @b