Skip to content

[IRDL2CPP] How about this templater? #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ constexpr char opDefTemplateText[] =
#include "Templates/OperationDef.txt"
;

constexpr auto typeDefTempl =
#include "Templates/TypeDefTest.cpp"
;

namespace {

struct DialectStrings {
Expand Down Expand Up @@ -341,9 +337,14 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
}),
"\n"));

const auto typeIdDefinitions = llvm::join(llvm::map_range(typeNames, [&](StringRef name) -> std::string {
return llvm::formatv("MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})", name, dialectStrings.namespacePath);
}), "\n");
const auto typeIdDefinitions =
llvm::join(llvm::map_range(typeNames,
[&](StringRef name) -> std::string {
return llvm::formatv(
"MLIR_DEFINE_EXPLICIT_TYPE_ID({1}::{0})",
name, dialectStrings.namespacePath);
}),
"\n");

output << llvm::formatv(
typeDefTemplateText, commaSeparatedTypeList, generatedTypeParser,
Expand Down Expand Up @@ -404,12 +405,12 @@ void {0}::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState,
"::mlir::Value {0}, ", attr);
}),
""),
llvm::join(llvm::map_range(
opStrings.opOperandNames,
[](StringRef attr) -> std::string {
return llvm::formatv(
" odsState.addOperands({0});", attr);
}),
llvm::join(llvm::map_range(opStrings.opOperandNames,
[](StringRef attr) -> std::string {
return llvm::formatv(
" odsState.addOperands({0});",
attr);
}),
"\n"),
llvm::join(llvm::map_range(opStrings.opResultNames,
[](StringRef attr) -> std::string {
Expand All @@ -419,10 +420,9 @@ void {0}::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState,
}),
"\n"));
return llvm::formatv(
perOpDefTemplateText, opStrings.opCppName, operandCount,
resultCount, buildDefinition,
dialectStrings.namespaceOpen, dialectStrings.namespaceClose,
dialectStrings.namespacePath);
perOpDefTemplateText, opStrings.opCppName, operandCount,
resultCount, buildDefinition, dialectStrings.namespaceOpen,
dialectStrings.namespaceClose, dialectStrings.namespacePath);
}),
"\n");

Expand All @@ -432,17 +432,19 @@ void {0}::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState,
output << llvm::formatv(dialectDefTemplateText, dialectStrings.namespaceOpen,
dialectStrings.namespaceClose,
dialectStrings.dialectCppName,
dialectStrings.namespacePath,
commaSeparatedOpList,
commaSeparatedTypeList
);
dialectStrings.namespacePath, commaSeparatedOpList,
commaSeparatedTypeList);

output << "#endif // " << definitionMacroFlag << "\n";
return success();
}

LogicalResult irdl::translateIRDLDialectToCpp(irdl::DialectOp dialect,
raw_ostream &output) {
const auto typeDefTempl = detail::Template(
#include "Templates/TypeDefTest.cpp"
);

StringRef dialectName = dialect.getSymName();

// TODO: deal with no more constraints than the verifier allows.
Expand Down Expand Up @@ -497,7 +499,8 @@ LogicalResult irdl::translateIRDLDialectToCpp(irdl::DialectOp dialect,
dict["DIALECT_CPP_NAME"] = "Test";
dict["DIALECT_NAME"] = "test";

llvm::errs() << detail::formatTemplate(typeDefTempl, dict) << "\n";
typeDefTempl.render(llvm::errs(), dict);
llvm::errs() << "\n";

// if (failed(generateInclude(dialect, output, dialectStrings)))
// return failure();
Expand Down
108 changes: 65 additions & 43 deletions mlir/lib/Target/IRDLToCpp/Templates/TemplatingUtils.h
Original file line number Diff line number Diff line change
@@ -1,52 +1,74 @@
#ifndef IRDLTOCPP_TEMPLATE_UTILS_H
#define IRDLTOCPP_TEMPLATE_UTILS_H

#include "llvm/ADT/SmallVector.h"
#include <string>
#include <array>
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <variant>

namespace mlir::irdl::detail {
using dictionary = llvm::StringMap<std::string>;

inline std::string formatTemplate(std::string_view value, const dictionary& dict) {
std::string workingString;
workingString.reserve(value.length() * 2);

char prevToken = '\0';
bool isProcessingVarName = false;
size_t tokenStart{};

for (size_t r = 0; r < value.length(); ++r)
{
const auto currToken = value[r];
if (currToken == '_' && prevToken == '_') {
if (!isProcessingVarName) {
tokenStart = r-1;
isProcessingVarName = true;
}
else {
llvm::StringRef varName {value.begin() + tokenStart + 2, r - tokenStart - 3};
if (const auto itr = dict.find(varName); itr != dict.end()) {
const auto& value = itr->second;
workingString.resize(workingString.length() - 1);
workingString.insert(workingString.end(), value.begin(), value.end());
} else {
llvm::errs() << "undefined variable: " << varName << "\n";
assert(false && "undefined variable");
}
isProcessingVarName = false;
continue;
}
}

if (!isProcessingVarName)
workingString.push_back(currToken);

prevToken = currToken;
using dictionary = llvm::StringMap<llvm::SmallString<8>>;

class Template {
public:
Template(llvm::StringRef str) {
bool processingReplacementToken = false;
while (!str.empty()) {
auto [token, remainder] = str.split("__");

if (processingReplacementToken) {
assert(!token.empty() && "replacement name cannot be empty");
bytecode.emplace_back(ReplacementToken{token});
} else {
if (!token.empty())
bytecode.emplace_back(LiteralToken{token});
}

processingReplacementToken = !processingReplacementToken;
str = remainder;
}
}

void render(llvm::raw_ostream &out, const dictionary &replacements) const {
for (auto instruction : bytecode) {
std::visit(
[&](auto &&inst) {
using T = std::decay_t<decltype(inst)>;
if constexpr (std::is_same_v<T, LiteralToken>) {
out << inst.text;
} else if constexpr (std::is_same_v<T, ReplacementToken>) {
auto replacement = replacements.find(inst.keyName);
#ifndef NDEBUG
if (replacement == replacements.end()) {
llvm::errs()
<< "Missing template key: " << inst.keyName << "\n";
llvm_unreachable("Missing template key");
}
#endif
out << replacement->second;
} else {
static_assert(false, "non-exhaustive visitor!");
}
},
instruction);
}
}

private:
struct LiteralToken {
llvm::StringRef text;
};

struct ReplacementToken {
llvm::StringRef keyName;
};

std::vector<std::variant<LiteralToken, ReplacementToken>> bytecode;
};

return workingString;
}
} // namespace mlir::irdl
} // namespace mlir::irdl::detail

#endif // #ifndef IRDLTOCPP_TEMPLATE_UTILS_H
#endif // #ifndef IRDLTOCPP_TEMPLATE_UTILS_H