Skip to content

[cherry-pick][mlir][llvm] Add support for memset.inline (#115711) #1135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,32 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
];
}

def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
[DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
/*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
/*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
// Append the alias attributes defined by LLVM_IntrOpBase.
let arguments = !con(args, aliasAttrs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
"bool":$isVolatile), [{
build($_builder, $_state, dst, val, len,
$_builder.getBoolAttr(isVolatile));
}]>,
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, val, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}]>
];
}

def LLVM_NoAliasScopeDeclOp
: LLVM_ZeroResultIntrOp<"experimental.noalias.scope.decl"> {
let arguments = (ins LLVM_AliasScopeAttr:$scope);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ SmallVector<Value> mlir::LLVM::MemsetOp::getAccessedOperands() {
return {getDst()};
}

SmallVector<Value> mlir::LLVM::MemsetInlineOp::getAccessedOperands() {
return {getDst()};
}

SmallVector<Value> mlir::LLVM::CallOp::getAccessedOperands() {
return llvm::to_vector(
llvm::make_filter_range(getArgOperands(), [](Value arg) {
Expand Down
281 changes: 211 additions & 70 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,76 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
APInt memIntrLen = op.getLen();
if (memIntrLen.getBitWidth() > 64)
return {};
return memIntrLen.getZExtValue();
}

/// Returns an integer attribute representing the length of a memset intrinsic
template <class MemsetIntr>
IntegerAttr createMemsetLenAttr(MemsetIntr op) {
IntegerAttr memsetLenAttr;
bool successfulMatch =
matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
(void)successfulMatch;
assert(successfulMatch);
return memsetLenAttr;
}

/// Returns an integer attribute representing the length of a memset intrinsic
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) {
return op.getLenAttr();
}

/// Creates a memset intrinsic of that matches the `toReplace` intrinsic
/// using the provided parameters. There are template specializations for
/// MemsetOp and MemsetInlineOp.
template <class MemsetIntr>
void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace,
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
DenseMap<Attribute, MemorySlot> &subslots,
Attribute index);

template <>
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace,
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
DenseMap<Attribute, MemorySlot> &subslots,
Attribute index) {
Value newMemsetSizeValue =
builder
.create<LLVM::ConstantOp>(
toReplace.getLen().getLoc(),
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
.getResult();

builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
toReplace.getVal(), newMemsetSizeValue,
toReplace.getIsVolatile());
}

template <>
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace,
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
DenseMap<Attribute, MemorySlot> &subslots,
Attribute index) {
auto newMemsetSizeValue =
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);

builder.create<LLVM::MemsetInlineOp>(
toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
newMemsetSizeValue, toReplace.getIsVolatile());
}

} // namespace

/// Returns whether one can be sure the memory intrinsic does not write outside
Expand Down Expand Up @@ -931,38 +1001,52 @@ static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
}

//===----------------------------------------------------------------------===//
// Interfaces for memset
// Interfaces for memset and memset.inline
//===----------------------------------------------------------------------===//

bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
template <class MemsetIntr>
static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
return false;

bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
if (op.getIsVolatile())
return false;

if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;

if (!areAllIndicesI32(slot))
return false;

return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
Value reachingDef,
const DataLayout &dataLayout) {
template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
OpBuilder &builder) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
if (intType.getWidth() == 8)
return getVal();
return op.getVal();

assert(intType.getWidth() % 8 == 0);

// Build the memset integer by repeatedly shifting the value and
// or-ing it with the previous value.
uint64_t coveredBits = 8;
Value currentValue =
builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
while (coveredBits < intType.getWidth()) {
Value shiftBy =
builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
coveredBits);
Value shifted =
builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
currentValue =
builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
coveredBits *= 2;
}

Expand All @@ -974,10 +1058,12 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
});
}

bool LLVM::MemsetOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
template <class MemsetIntr>
static bool
memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
Expand All @@ -988,62 +1074,28 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
if (!canConvertType)
return false;

if (getIsVolatile())
if (op.getIsVolatile())
return false;

return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
}

DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
OpBuilder &builder, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}

bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (&slot.elemType.getDialect() != getOperation()->getDialect())
return false;

if (getIsVolatile())
return false;

if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;

if (!areAllIndicesI32(slot))
return false;

return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
}
template <class MemsetIntr>
static DeletionKind
memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
const DataLayout &dataLayout) {

DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
OpBuilder &builder,
const DataLayout &dataLayout) {
std::optional<DenseMap<Attribute, Type>> types =
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

IntegerAttr memsetLenAttr;
bool successfulMatch =
matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
(void)successfulMatch;
assert(successfulMatch);
IntegerAttr memsetLenAttr = createMemsetLenAttr(op);

bool packed = false;
if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
packed = structType.isPacked();

Type i32 = IntegerType::get(getContext(), 32);
Type i32 = IntegerType::get(op.getContext(), 32);
uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
uint64_t covered = 0;
for (size_t i = 0; i < types->size(); i++) {
Expand All @@ -1063,16 +1115,8 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
// Otherwise, only compute its offset within the original memset.
if (subslots.contains(index)) {
uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);

Value newMemsetSizeValue =
builder
.create<LLVM::ConstantOp>(
getLen().getLoc(),
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
.getResult();

builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
newMemsetSizeValue, getIsVolatile());
createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
index);
}

covered += typeSize;
Expand All @@ -1081,6 +1125,103 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
return DeletionKind::Delete;
}

bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
Value reachingDef,
const DataLayout &dataLayout) {
return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
dataLayout);
}

DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
OpBuilder &builder, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
dataLayout);
}

DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
OpBuilder &builder,
const DataLayout &dataLayout) {
return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}

Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
OpBuilder &builder, Value reachingDef,
const DataLayout &dataLayout) {
return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetInlineOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
dataLayout);
}

DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
OpBuilder &builder, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetInlineOp::canRewire(
const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
dataLayout);
}

DeletionKind
LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
OpBuilder &builder, const DataLayout &dataLayout) {
return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

//===----------------------------------------------------------------------===//
// Interfaces for memcpy/memmove
//===----------------------------------------------------------------------===//
Expand Down
Loading