Skip to content

IRDL tests #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/IRDL/IRDLRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#ifndef MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
#define MLIR_DIALECT_IRDL_IRDLREGISTRATION_H

#include "Dyn/Dialect/IRDL/IR/IRDL.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LogicalResult.h"

Expand All @@ -26,4 +25,4 @@ LogicalResult registerDialects(ModuleOp op);
} // namespace irdl
} // namespace mlir

#endif // MLIR_DIALECT_IRDL__IRDLREGISTRATION_H
#endif // MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
#define MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H

#include "mlir/Dialect/IRDL/TypeWrapper.h"
#include "mlir/IR/ExtensibleDialect.h"
Expand Down Expand Up @@ -136,4 +136,4 @@ class AnyTypeConstraint : public TypeConstraint {
} // namespace irdl
} // namespace mlir

#endif // MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
#endif // MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
Expand Down Expand Up @@ -92,6 +93,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
func::FuncDialect,
gpu::GPUDialect,
index::IndexDialect,
irdl::IRDLDialect,
LLVM::LLVMDialect,
linalg::LinalgDialect,
math::MathDialect,
Expand Down
31 changes: 18 additions & 13 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class MemoryBuffer;

namespace mlir {
class DialectRegistry;
class MLIRContext;
class PassPipelineCLParser;
class PassManager;

Expand Down Expand Up @@ -54,34 +55,38 @@ using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
/// - implicitModule will enable implicit addition of a top-level
/// 'builtin.module' if one doesn't already exist.
/// - dumpPassPipeline will dump the pipeline being run to stderr
/// - context is provided if the caller wants to provide the context.
LogicalResult MlirOptMain(
llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
bool emitBytecode = false, bool implicitModule = false,
bool dumpPassPipeline = false, MLIRContext *context = nullptr);

/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
/// apply on the loaded IR.
LogicalResult
MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext = false, bool emitBytecode = false,
bool implicitModule = false, bool dumpPassPipeline = false);

/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
/// apply on the loaded IR.
LogicalResult MlirOptMain(
llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
bool emitBytecode = false, bool implicitModule = false);
bool implicitModule = false, MLIRContext *context = nullptr);

/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
/// - registry should contain all the dialects that can be parsed in the source.
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
/// - context should be given if the caller wants to provide the context.
LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry &registry,
bool preloadDialectsInContext = false);
bool preloadDialectsInContext = false,
MLIRContext *context = nullptr);

/// Helper wrapper to return the result of MlirOptMain directly from main.
///
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/IRDL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRIRDL
IRDLRegistration.cpp
IRDLVerifiers.cpp
IRDLContext.cpp
TypeWrapper.cpp

DEPENDS
MLIRIRDLIncGen
Expand Down
39 changes: 39 additions & 0 deletions mlir/lib/Dialect/IRDL/TypeWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- TypeWrapper.cpp - IRDL type wrapper definition -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/TypeWrapper.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"

namespace mlir {
namespace irdl {

DynamicTypeDefinition *findDynamicType(MLIRContext &ctx, StringRef type) {
auto splitted = type.split('.');
auto dialectName = splitted.first;
auto typeName = splitted.second;

auto dialect = ctx.getOrLoadDialect(dialectName);
if (!dialect)
return nullptr;

auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect);
if (!extensibleDialect)
return nullptr;

return extensibleDialect->lookupTypeDefinition(typeName);
}

TypeWrapper *findTypeWrapper(MLIRContext &ctx, StringRef type) {
IRDLDialect *irdl = ctx.getLoadedDialect<IRDLDialect>();
assert(irdl && "irdl is not registered");

return irdl->getTypeWrapper(type);
}

} // namespace irdl
} // namespace mlir
114 changes: 85 additions & 29 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLRegistration.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -113,55 +115,61 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
llvm::ThreadPool *threadPool) {
llvm::ThreadPool *threadPool, MLIRContext *context) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

// Check that if a context was provided, it is single threaded.
if (context)
assert(!context->isMultithreadingEnabled() &&
"provided context must be single threaded");

// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
MLIRContext bufferContext(registry, MLIRContext::Threading::DISABLED);
if (!context)
context = &bufferContext;

if (threadPool)
context.setThreadPool(*threadPool);
context->setThreadPool(*threadPool);

// Parse the input file.
if (preloadDialectsInContext)
context.loadAllAvailableDialects();
context.allowUnregisteredDialects(allowUnregisteredDialects);
context->loadAllAvailableDialects();
context->allowUnregisteredDialects(allowUnregisteredDialects);
if (verifyDiagnostics)
context.printOpOnDiagnostic(false);
context.getDebugActionManager().registerActionHandler<DebugCounter>();
context->printOpOnDiagnostic(false);
context->getDebugActionManager().registerActionHandler<DebugCounter>();

// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
&context, passManagerSetupFn, emitBytecode,
context, passManagerSetupFn, emitBytecode,
implicitModule);
}

SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, context);

// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, context,
passManagerSetupFn, emitBytecode, implicitModule);

// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
return sourceMgrHandler.verify();
}

LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
PassPipelineFn passManagerSetupFn,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule) {
LogicalResult mlir::MlirOptMain(
raw_ostream &outputStream, std::unique_ptr<MemoryBuffer> buffer,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule, MLIRContext *context) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
Expand All @@ -181,18 +189,21 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, emitBytecode, implicitModule,
passManagerSetupFn, registry, threadPool);
passManagerSetupFn, registry, threadPool, context);
};
return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
splitInputFile, /*insertMarkerInOutput=*/true);
}

LogicalResult mlir::MlirOptMain(
raw_ostream &outputStream, std::unique_ptr<MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule, bool dumpPassPipeline) {
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext,
bool emitBytecode, bool implicitModule,
bool dumpPassPipeline, MLIRContext *context) {
auto passManagerSetupFn = [&](PassManager &pm) {
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(pm.getContext())) << msg;
Expand All @@ -209,12 +220,48 @@ LogicalResult mlir::MlirOptMain(
return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
registry, splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects, preloadDialectsInContext,
emitBytecode, implicitModule);
emitBytecode, implicitModule, context);
}

LogicalResult registerIRDL(StringRef irdlFile, MLIRContext *ctx) {
DialectRegistry registry;
registry.insert<irdl::IRDLDialect>();
ctx->appendDialectRegistry(registry);

// Set up the input file.
std::string errorMessage;
auto file = openInputFile(irdlFile, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return failure();
}

// Give the buffer to the source manager.
// This will be picked up by the parser.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());

SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, ctx);

// Disable multi-threading when parsing the input file. This removes the
// unnecessary/costly context synchronization when parsing.
bool wasThreadingEnabled = ctx->isMultithreadingEnabled();
ctx->disableMultithreading();

// Parse the input file.
auto module(parseSourceFile<ModuleOp>(sourceMgr, ctx));

// Register IRDL dialects.
irdl::registerDialects(module.get());
ctx->enableMultithreading(wasThreadingEnabled);

return failure(!module);
}

LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry &registry,
bool preloadDialectsInContext) {
bool preloadDialectsInContext,
MLIRContext *context) {
static cl::opt<std::string> inputFilename(
cl::Positional, cl::desc("<input file>"), cl::init("-"));

Expand Down Expand Up @@ -261,6 +308,9 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
"dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
cl::init(false)};

static cl::opt<std::string> irdlFile("irdl-file", cl::desc("IRDL file"),
cl::value_desc("filename"));

InitLLVM y(argc, argv);

// Register any command line options.
Expand All @@ -281,6 +331,12 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
// Parse pass names in main to ensure static initialization completed.
cl::ParseCommandLineOptions(argc, argv, helpHeader);

if (irdlFile != "") {
assert(context && "context should be initialized");
if (failed(registerIRDL(irdlFile, context)))
return failure();
}

if (showDialects) {
llvm::outs() << "Available Dialects:\n";
interleave(
Expand All @@ -307,7 +363,7 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects, preloadDialectsInContext,
emitBytecode, /*implicitModule=*/!noImplicitModule,
dumpPassPipeline)))
dumpPassPipeline, context)))
return failure();

// Keep the output file if the invocation of MlirOptMain was successful.
Expand Down
Loading