Skip to content

Commit f7b9151

Browse files
authored
[CIR] Add attribute visitor for lowering globals (#1318)
This adds a new mlir-tablegen option to generate a .inc file with the complete set of attrdefs defined in a .td file and uses the file generated for CIR attrdefs to create an attr visitor. This visitor is used in the lowering of global variables directly to LLVM IR. The purpose of this change is to align the incubator lowering implementation with the recent upstream changes to make future upstreaming easier, while also fulfilling the upstream request to have the visitor be based on a tablegen created file. The new mlir-tablegen feature will be upstreamed after it is established here. No observable change is intended in the CIR code.
1 parent a07dbdf commit f7b9151

File tree

6 files changed

+141
-98
lines changed

6 files changed

+141
-98
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines the CirAttrVisitor interface.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
14+
#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
15+
16+
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
17+
18+
namespace cir {
19+
20+
#define DISPATCH(NAME) return getImpl()->visitCir##NAME(cirAttr);
21+
22+
template <typename ImplClass, typename RetTy> class CirAttrVisitor {
23+
public:
24+
RetTy visit(mlir::Attribute attr) {
25+
#define ATTRDEF(NAME) \
26+
if (const auto cirAttr = mlir::dyn_cast<cir::NAME>(attr)) \
27+
DISPATCH(NAME);
28+
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"
29+
llvm_unreachable("unhandled attribute type");
30+
}
31+
32+
// If the implementation chooses not to implement a certain visit
33+
// method, fall back to the parent.
34+
#define ATTRDEF(NAME) \
35+
RetTy visitCir##NAME(NAME cirAttr) { DISPATCH(Attr); }
36+
#include "clang/CIR/Dialect/IR/CIRAttrDefsList.inc"
37+
38+
RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }
39+
40+
ImplClass *getImpl() { return static_cast<ImplClass *>(this); }
41+
};
42+
43+
#undef DISPATCH
44+
45+
} // namespace cir
46+
47+
#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H

clang/include/clang/CIR/Dialect/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mlir_tablegen(CIROpsStructs.h.inc -gen-attrdef-decls)
2626
mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
2727
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
2828
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
29+
mlir_tablegen(CIRAttrDefsList.inc -gen-attrdef-list)
2930
add_public_tablegen_target(MLIRCIREnumsGen)
3031

3132
clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 59 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
4242
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4343
#include "mlir/Target/LLVMIR/Export.h"
44+
#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
4445
#include "clang/CIR/Dialect/Passes.h"
4546
#include "clang/CIR/LoweringHelpers.h"
4647
#include "clang/CIR/MissingFeatures.h"
@@ -425,32 +426,52 @@ emitCirAttrToMemory(mlir::Operation *parentOp, mlir::Attribute attr,
425426
}
426427

427428
/// Switches on the type of attribute and calls the appropriate conversion.
429+
class CirAttrToValue : public CirAttrVisitor<CirAttrToValue, mlir::Value> {
430+
public:
431+
CirAttrToValue(mlir::Operation *parentOp,
432+
mlir::ConversionPatternRewriter &rewriter,
433+
const mlir::TypeConverter *converter,
434+
mlir::DataLayout const &dataLayout)
435+
: parentOp(parentOp), rewriter(rewriter), converter(converter),
436+
dataLayout(dataLayout) {}
437+
438+
mlir::Value visitCirIntAttr(cir::IntAttr attr);
439+
mlir::Value visitCirFPAttr(cir::FPAttr attr);
440+
mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr attr);
441+
mlir::Value visitCirConstStructAttr(cir::ConstStructAttr attr);
442+
mlir::Value visitCirConstArrayAttr(cir::ConstArrayAttr attr);
443+
mlir::Value visitCirConstVectorAttr(cir::ConstVectorAttr attr);
444+
mlir::Value visitCirBoolAttr(cir::BoolAttr attr);
445+
mlir::Value visitCirZeroAttr(cir::ZeroAttr attr);
446+
mlir::Value visitCirUndefAttr(cir::UndefAttr attr);
447+
mlir::Value visitCirPoisonAttr(cir::PoisonAttr attr);
448+
mlir::Value visitCirGlobalViewAttr(cir::GlobalViewAttr attr);
449+
mlir::Value visitCirVTableAttr(cir::VTableAttr attr);
450+
mlir::Value visitCirTypeInfoAttr(cir::TypeInfoAttr attr);
451+
452+
private:
453+
mlir::Operation *parentOp;
454+
mlir::ConversionPatternRewriter &rewriter;
455+
const mlir::TypeConverter *converter;
456+
mlir::DataLayout const &dataLayout;
457+
};
428458

429459
/// IntAttr visitor.
430-
static mlir::Value
431-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::IntAttr intAttr,
432-
mlir::ConversionPatternRewriter &rewriter,
433-
const mlir::TypeConverter *converter) {
460+
mlir::Value CirAttrToValue::visitCirIntAttr(cir::IntAttr intAttr) {
434461
auto loc = parentOp->getLoc();
435462
return rewriter.create<mlir::LLVM::ConstantOp>(
436463
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
437464
}
438465

439466
/// BoolAttr visitor.
440-
static mlir::Value
441-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::BoolAttr boolAttr,
442-
mlir::ConversionPatternRewriter &rewriter,
443-
const mlir::TypeConverter *converter) {
467+
mlir::Value CirAttrToValue::visitCirBoolAttr(cir::BoolAttr boolAttr) {
444468
auto loc = parentOp->getLoc();
445469
return rewriter.create<mlir::LLVM::ConstantOp>(
446470
loc, converter->convertType(boolAttr.getType()), boolAttr.getValue());
447471
}
448472

449473
/// ConstPtrAttr visitor.
450-
static mlir::Value
451-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
452-
mlir::ConversionPatternRewriter &rewriter,
453-
const mlir::TypeConverter *converter) {
474+
mlir::Value CirAttrToValue::visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
454475
auto loc = parentOp->getLoc();
455476
if (ptrAttr.isNullValue()) {
456477
return rewriter.create<mlir::LLVM::ZeroOp>(
@@ -465,51 +486,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstPtrAttr ptrAttr,
465486
}
466487

467488
/// FPAttr visitor.
468-
static mlir::Value
469-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::FPAttr fltAttr,
470-
mlir::ConversionPatternRewriter &rewriter,
471-
const mlir::TypeConverter *converter) {
489+
mlir::Value CirAttrToValue::visitCirFPAttr(cir::FPAttr fltAttr) {
472490
auto loc = parentOp->getLoc();
473491
return rewriter.create<mlir::LLVM::ConstantOp>(
474492
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
475493
}
476494

477495
/// ZeroAttr visitor.
478-
static mlir::Value
479-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ZeroAttr zeroAttr,
480-
mlir::ConversionPatternRewriter &rewriter,
481-
const mlir::TypeConverter *converter) {
496+
mlir::Value CirAttrToValue::visitCirZeroAttr(cir::ZeroAttr zeroAttr) {
482497
auto loc = parentOp->getLoc();
483498
return rewriter.create<mlir::LLVM::ZeroOp>(
484499
loc, converter->convertType(zeroAttr.getType()));
485500
}
486501

487502
/// UndefAttr visitor.
488-
static mlir::Value
489-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
490-
mlir::ConversionPatternRewriter &rewriter,
491-
const mlir::TypeConverter *converter) {
503+
mlir::Value CirAttrToValue::visitCirUndefAttr(cir::UndefAttr undefAttr) {
492504
auto loc = parentOp->getLoc();
493505
return rewriter.create<mlir::LLVM::UndefOp>(
494506
loc, converter->convertType(undefAttr.getType()));
495507
}
496508

497509
/// PoisonAttr visitor.
498-
static mlir::Value
499-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
500-
mlir::ConversionPatternRewriter &rewriter,
501-
const mlir::TypeConverter *converter) {
510+
mlir::Value CirAttrToValue::visitCirPoisonAttr(cir::PoisonAttr poisonAttr) {
502511
auto loc = parentOp->getLoc();
503512
return rewriter.create<mlir::LLVM::PoisonOp>(
504513
loc, converter->convertType(poisonAttr.getType()));
505514
}
506515

507516
/// ConstStruct visitor.
508-
static mlir::Value
509-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
510-
mlir::ConversionPatternRewriter &rewriter,
511-
const mlir::TypeConverter *converter,
512-
mlir::DataLayout const &dataLayout) {
517+
mlir::Value
518+
CirAttrToValue::visitCirConstStructAttr(cir::ConstStructAttr constStruct) {
513519
auto llvmTy = converter->convertType(constStruct.getType());
514520
auto loc = parentOp->getLoc();
515521
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
@@ -525,49 +531,37 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
525531
}
526532

527533
// VTableAttr visitor.
528-
static mlir::Value
529-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::VTableAttr vtableArr,
530-
mlir::ConversionPatternRewriter &rewriter,
531-
const mlir::TypeConverter *converter,
532-
mlir::DataLayout const &dataLayout) {
534+
mlir::Value CirAttrToValue::visitCirVTableAttr(cir::VTableAttr vtableArr) {
533535
auto llvmTy = converter->convertType(vtableArr.getType());
534536
auto loc = parentOp->getLoc();
535537
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
536538

537539
for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) {
538-
mlir::Value init =
539-
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
540+
mlir::Value init = visit(elt);
540541
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
541542
}
542543

543544
return result;
544545
}
545546

546547
// TypeInfoAttr visitor.
547-
static mlir::Value
548-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::TypeInfoAttr typeinfoArr,
549-
mlir::ConversionPatternRewriter &rewriter,
550-
const mlir::TypeConverter *converter,
551-
mlir::DataLayout const &dataLayout) {
548+
mlir::Value
549+
CirAttrToValue::visitCirTypeInfoAttr(cir::TypeInfoAttr typeinfoArr) {
552550
auto llvmTy = converter->convertType(typeinfoArr.getType());
553551
auto loc = parentOp->getLoc();
554552
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
555553

556554
for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) {
557-
mlir::Value init =
558-
lowerCirAttrAsValue(parentOp, elt, rewriter, converter, dataLayout);
555+
mlir::Value init = visit(elt);
559556
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
560557
}
561558

562559
return result;
563560
}
564561

565562
// ConstArrayAttr visitor
566-
static mlir::Value
567-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
568-
mlir::ConversionPatternRewriter &rewriter,
569-
const mlir::TypeConverter *converter,
570-
mlir::DataLayout const &dataLayout) {
563+
mlir::Value
564+
CirAttrToValue::visitCirConstArrayAttr(cir::ConstArrayAttr constArr) {
571565
auto llvmTy = converter->convertType(constArr.getType());
572566
auto loc = parentOp->getLoc();
573567
mlir::Value result;
@@ -610,10 +604,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstArrayAttr constArr,
610604
}
611605

612606
// ConstVectorAttr visitor.
613-
static mlir::Value
614-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
615-
mlir::ConversionPatternRewriter &rewriter,
616-
const mlir::TypeConverter *converter) {
607+
mlir::Value
608+
CirAttrToValue::visitCirConstVectorAttr(cir::ConstVectorAttr constVec) {
617609
auto llvmTy = converter->convertType(constVec.getType());
618610
auto loc = parentOp->getLoc();
619611
SmallVector<mlir::Attribute> mlirValues;
@@ -638,11 +630,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
638630
}
639631

640632
// GlobalViewAttr visitor.
641-
static mlir::Value
642-
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
643-
mlir::ConversionPatternRewriter &rewriter,
644-
const mlir::TypeConverter *converter,
645-
mlir::DataLayout const &dataLayout) {
633+
mlir::Value
634+
CirAttrToValue::visitCirGlobalViewAttr(cir::GlobalViewAttr globalAttr) {
646635
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
647636
mlir::Type sourceType;
648637
unsigned sourceAddrSpace = 0;
@@ -716,43 +705,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
716705
}
717706

718707
/// Switches on the type of attribute and calls the appropriate conversion.
719-
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
708+
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
709+
const mlir::Attribute attr,
720710
mlir::ConversionPatternRewriter &rewriter,
721711
const mlir::TypeConverter *converter,
722712
mlir::DataLayout const &dataLayout) {
723-
if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
724-
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
725-
if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
726-
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
727-
if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
728-
return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter);
729-
if (const auto constStruct = mlir::dyn_cast<cir::ConstStructAttr>(attr))
730-
return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter,
731-
dataLayout);
732-
if (const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(attr))
733-
return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter,
734-
dataLayout);
735-
if (const auto constVec = mlir::dyn_cast<cir::ConstVectorAttr>(attr))
736-
return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter);
737-
if (const auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
738-
return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter);
739-
if (const auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(attr))
740-
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
741-
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
742-
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
743-
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
744-
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
745-
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
746-
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter,
747-
dataLayout);
748-
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
749-
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter,
750-
dataLayout);
751-
if (const auto typeinfoAttr = mlir::dyn_cast<cir::TypeInfoAttr>(attr))
752-
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter,
753-
dataLayout);
754-
755-
llvm_unreachable("unhandled attribute type");
713+
CirAttrToValue valueConverter(parentOp, rewriter, converter, dataLayout);
714+
auto value = valueConverter.visit(attr);
715+
if (!value)
716+
llvm_unreachable("unhandled attribute type");
717+
return value;
756718
}
757719

758720
//===----------------------------------------------------------------------===//
@@ -1734,8 +1696,8 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
17341696
// Regardless of the type, we should lower the constant of poison value
17351697
// into PoisonOp.
17361698
if (auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr)) {
1737-
rewriter.replaceOp(
1738-
op, lowerCirAttrAsValue(op, poisonAttr, rewriter, getTypeConverter()));
1699+
rewriter.replaceOp(op, lowerCirAttrAsValue(op, poisonAttr, rewriter,
1700+
getTypeConverter(), dataLayout));
17391701
return mlir::success();
17401702
}
17411703

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace direct {
2222

2323
/// Convert a CIR attribute to an LLVM attribute. May use the datalayout for
2424
/// lowering attributes to-be-stored in memory.
25-
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
25+
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
26+
const mlir::Attribute attr,
2627
mlir::ConversionPatternRewriter &rewriter,
2728
const mlir::TypeConverter *converter,
2829
mlir::DataLayout const &dataLayout);

mlir/test/mlir-tblgen/attrdefs.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
22
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
3+
// RUN: mlir-tblgen -gen-attrdef-list -I %S/../../include %s | FileCheck %s --check-prefix=LIST
34

45
include "mlir/IR/AttrTypeBase.td"
56
include "mlir/IR/OpBase.td"
@@ -19,6 +20,13 @@ include "mlir/IR/OpBase.td"
1920
// DEF: ::test::CompoundAAttr,
2021
// DEF: ::test::SingleParameterAttr
2122

23+
// LIST: ATTRDEF(IndexAttr)
24+
// LIST: ATTRDEF(SimpleAAttr)
25+
// LIST: ATTRDEF(CompoundAAttr)
26+
// LIST: ATTRDEF(SingleParameterAttr)
27+
28+
// LIST: #undef ATTRDEF
29+
2230
// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
2331
// DEF-SAME: ::mlir::AsmParser &parser,
2432
// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,

0 commit comments

Comments
 (0)