Skip to content

Commit da601b3

Browse files
authored
[cherry-pick][mlir][llvm] Add support for memset.inline (#115711) (#1135)
Support `llvm.intr.memset.inline` in the llvm-project repo before we add support for `__builtin_memset_inline` in clangir. cc @bcardosolopes (cherry picked from commit 30753af)
1 parent 6427019 commit da601b3

File tree

7 files changed

+598
-70
lines changed

7 files changed

+598
-70
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

+26
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,32 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
256256
];
257257
}
258258

259+
def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
260+
[DeclareOpInterfaceMethods<PromotableMemOpInterface>,
261+
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
262+
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
263+
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
264+
/*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
265+
/*immArgAttrNames=*/["len", "isVolatile"]> {
266+
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
267+
I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
268+
// Append the alias attributes defined by LLVM_IntrOpBase.
269+
let arguments = !con(args, aliasAttrs);
270+
let builders = [
271+
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
272+
"bool":$isVolatile), [{
273+
build($_builder, $_state, dst, val, len,
274+
$_builder.getBoolAttr(isVolatile));
275+
}]>,
276+
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
277+
"IntegerAttr":$isVolatile), [{
278+
build($_builder, $_state, dst, val, len, isVolatile,
279+
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
280+
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
281+
}]>
282+
];
283+
}
284+
259285
def LLVM_NoAliasScopeDeclOp
260286
: LLVM_ZeroResultIntrOp<"experimental.noalias.scope.decl"> {
261287
let arguments = (ins LLVM_AliasScopeAttr:$scope);

mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ SmallVector<Value> mlir::LLVM::MemsetOp::getAccessedOperands() {
9494
return {getDst()};
9595
}
9696

97+
SmallVector<Value> mlir::LLVM::MemsetInlineOp::getAccessedOperands() {
98+
return {getDst()};
99+
}
100+
97101
SmallVector<Value> mlir::LLVM::CallOp::getAccessedOperands() {
98102
return llvm::to_vector(
99103
llvm::make_filter_range(getArgOperands(), [](Value arg) {

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

+211-70
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,76 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
904904
return memIntrLen.getZExtValue();
905905
}
906906

907+
/// Returns the length of the given memory intrinsic in bytes if it can be known
908+
/// at compile-time on a best-effort basis, nothing otherwise.
909+
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
910+
/// specialized handling.
911+
template <>
912+
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
913+
APInt memIntrLen = op.getLen();
914+
if (memIntrLen.getBitWidth() > 64)
915+
return {};
916+
return memIntrLen.getZExtValue();
917+
}
918+
919+
/// Returns the length operand of a memset intrinsic as an integer attribute; asserts that the operand is a constant.
920+
template <class MemsetIntr>
921+
IntegerAttr createMemsetLenAttr(MemsetIntr op) {
922+
IntegerAttr memsetLenAttr;
923+
bool successfulMatch =
924+
matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
925+
(void)successfulMatch;
926+
assert(successfulMatch);
927+
return memsetLenAttr;
928+
}
929+
930+
/// Returns an integer attribute representing the length of a memset intrinsic.
931+
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
932+
/// specialized handling.
933+
template <>
934+
IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) {
935+
return op.getLenAttr();
936+
}
937+
938+
/// Creates a memset intrinsic that matches the `toReplace` intrinsic
939+
/// using the provided parameters. There are template specializations for
940+
/// MemsetOp and MemsetInlineOp.
941+
template <class MemsetIntr>
942+
void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace,
943+
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
944+
DenseMap<Attribute, MemorySlot> &subslots,
945+
Attribute index);
946+
947+
template <>
948+
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace,
949+
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
950+
DenseMap<Attribute, MemorySlot> &subslots,
951+
Attribute index) {
952+
Value newMemsetSizeValue =
953+
builder
954+
.create<LLVM::ConstantOp>(
955+
toReplace.getLen().getLoc(),
956+
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
957+
.getResult();
958+
959+
builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
960+
toReplace.getVal(), newMemsetSizeValue,
961+
toReplace.getIsVolatile());
962+
}
963+
964+
template <>
965+
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace,
966+
IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
967+
DenseMap<Attribute, MemorySlot> &subslots,
968+
Attribute index) {
969+
auto newMemsetSizeValue =
970+
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);
971+
972+
builder.create<LLVM::MemsetInlineOp>(
973+
toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
974+
newMemsetSizeValue, toReplace.getIsVolatile());
975+
}
976+
907977
} // namespace
908978

909979
/// Returns whether one can be sure the memory intrinsic does not write outside
@@ -931,38 +1001,52 @@ static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
9311001
}
9321002

9331003
//===----------------------------------------------------------------------===//
934-
// Interfaces for memset
1004+
// Interfaces for memset and memset.inline
9351005
//===----------------------------------------------------------------------===//
9361006

937-
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
1007+
template <class MemsetIntr>
1008+
static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
1009+
SmallPtrSetImpl<Attribute> &usedIndices,
1010+
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1011+
const DataLayout &dataLayout) {
1012+
if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
1013+
return false;
9381014

939-
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
940-
return getDst() == slot.ptr;
1015+
if (op.getIsVolatile())
1016+
return false;
1017+
1018+
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1019+
return false;
1020+
1021+
if (!areAllIndicesI32(slot))
1022+
return false;
1023+
1024+
return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
9411025
}
9421026

943-
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
944-
Value reachingDef,
945-
const DataLayout &dataLayout) {
1027+
template <class MemsetIntr>
1028+
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
1029+
OpBuilder &builder) {
9461030
// TODO: Support non-integer types.
9471031
return TypeSwitch<Type, Value>(slot.elemType)
9481032
.Case([&](IntegerType intType) -> Value {
9491033
if (intType.getWidth() == 8)
950-
return getVal();
1034+
return op.getVal();
9511035

9521036
assert(intType.getWidth() % 8 == 0);
9531037

9541038
// Build the memset integer by repeatedly shifting the value and
9551039
// or-ing it with the previous value.
9561040
uint64_t coveredBits = 8;
9571041
Value currentValue =
958-
builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
1042+
builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
9591043
while (coveredBits < intType.getWidth()) {
960-
Value shiftBy =
961-
builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
1044+
Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
1045+
coveredBits);
9621046
Value shifted =
963-
builder.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
1047+
builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
9641048
currentValue =
965-
builder.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
1049+
builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
9661050
coveredBits *= 2;
9671051
}
9681052

@@ -974,10 +1058,12 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
9741058
});
9751059
}
9761060

977-
bool LLVM::MemsetOp::canUsesBeRemoved(
978-
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
979-
SmallVectorImpl<OpOperand *> &newBlockingUses,
980-
const DataLayout &dataLayout) {
1061+
template <class MemsetIntr>
1062+
static bool
1063+
memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
1064+
const SmallPtrSetImpl<OpOperand *> &blockingUses,
1065+
SmallVectorImpl<OpOperand *> &newBlockingUses,
1066+
const DataLayout &dataLayout) {
9811067
// TODO: Support non-integer types.
9821068
bool canConvertType =
9831069
TypeSwitch<Type, bool>(slot.elemType)
@@ -988,62 +1074,28 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
9881074
if (!canConvertType)
9891075
return false;
9901076

991-
if (getIsVolatile())
1077+
if (op.getIsVolatile())
9921078
return false;
9931079

994-
return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
995-
}
996-
997-
DeletionKind LLVM::MemsetOp::removeBlockingUses(
998-
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
999-
OpBuilder &builder, Value reachingDefinition,
1000-
const DataLayout &dataLayout) {
1001-
return DeletionKind::Delete;
1002-
}
1003-
1004-
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
1005-
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1006-
const DataLayout &dataLayout) {
1007-
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
1080+
return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
10081081
}
10091082

1010-
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
1011-
SmallPtrSetImpl<Attribute> &usedIndices,
1012-
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1013-
const DataLayout &dataLayout) {
1014-
if (&slot.elemType.getDialect() != getOperation()->getDialect())
1015-
return false;
1016-
1017-
if (getIsVolatile())
1018-
return false;
1019-
1020-
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1021-
return false;
1022-
1023-
if (!areAllIndicesI32(slot))
1024-
return false;
1025-
1026-
return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
1027-
}
1083+
template <class MemsetIntr>
1084+
static DeletionKind
1085+
memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
1086+
DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
1087+
const DataLayout &dataLayout) {
10281088

1029-
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
1030-
DenseMap<Attribute, MemorySlot> &subslots,
1031-
OpBuilder &builder,
1032-
const DataLayout &dataLayout) {
10331089
std::optional<DenseMap<Attribute, Type>> types =
10341090
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
10351091

1036-
IntegerAttr memsetLenAttr;
1037-
bool successfulMatch =
1038-
matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
1039-
(void)successfulMatch;
1040-
assert(successfulMatch);
1092+
IntegerAttr memsetLenAttr = createMemsetLenAttr(op);
10411093

10421094
bool packed = false;
10431095
if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
10441096
packed = structType.isPacked();
10451097

1046-
Type i32 = IntegerType::get(getContext(), 32);
1098+
Type i32 = IntegerType::get(op.getContext(), 32);
10471099
uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
10481100
uint64_t covered = 0;
10491101
for (size_t i = 0; i < types->size(); i++) {
@@ -1063,16 +1115,8 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
10631115
// Otherwise, only compute its offset within the original memset.
10641116
if (subslots.contains(index)) {
10651117
uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
1066-
1067-
Value newMemsetSizeValue =
1068-
builder
1069-
.create<LLVM::ConstantOp>(
1070-
getLen().getLoc(),
1071-
IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
1072-
.getResult();
1073-
1074-
builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr, getVal(),
1075-
newMemsetSizeValue, getIsVolatile());
1118+
createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
1119+
index);
10761120
}
10771121

10781122
covered += typeSize;
@@ -1081,6 +1125,103 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
10811125
return DeletionKind::Delete;
10821126
}
10831127

1128+
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
1129+
1130+
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
1131+
return getDst() == slot.ptr;
1132+
}
1133+
1134+
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
1135+
Value reachingDef,
1136+
const DataLayout &dataLayout) {
1137+
return memsetGetStored(*this, slot, builder);
1138+
}
1139+
1140+
bool LLVM::MemsetOp::canUsesBeRemoved(
1141+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1142+
SmallVectorImpl<OpOperand *> &newBlockingUses,
1143+
const DataLayout &dataLayout) {
1144+
return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1145+
dataLayout);
1146+
}
1147+
1148+
DeletionKind LLVM::MemsetOp::removeBlockingUses(
1149+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1150+
OpBuilder &builder, Value reachingDefinition,
1151+
const DataLayout &dataLayout) {
1152+
return DeletionKind::Delete;
1153+
}
1154+
1155+
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
1156+
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1157+
const DataLayout &dataLayout) {
1158+
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
1159+
}
1160+
1161+
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
1162+
SmallPtrSetImpl<Attribute> &usedIndices,
1163+
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1164+
const DataLayout &dataLayout) {
1165+
return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1166+
dataLayout);
1167+
}
1168+
1169+
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
1170+
DenseMap<Attribute, MemorySlot> &subslots,
1171+
OpBuilder &builder,
1172+
const DataLayout &dataLayout) {
1173+
return memsetRewire(*this, slot, subslots, builder, dataLayout);
1174+
}
1175+
1176+
bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }
1177+
1178+
bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
1179+
return getDst() == slot.ptr;
1180+
}
1181+
1182+
Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
1183+
OpBuilder &builder, Value reachingDef,
1184+
const DataLayout &dataLayout) {
1185+
return memsetGetStored(*this, slot, builder);
1186+
}
1187+
1188+
bool LLVM::MemsetInlineOp::canUsesBeRemoved(
1189+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1190+
SmallVectorImpl<OpOperand *> &newBlockingUses,
1191+
const DataLayout &dataLayout) {
1192+
return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
1193+
dataLayout);
1194+
}
1195+
1196+
DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
1197+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1198+
OpBuilder &builder, Value reachingDefinition,
1199+
const DataLayout &dataLayout) {
1200+
return DeletionKind::Delete;
1201+
}
1202+
1203+
LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
1204+
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1205+
const DataLayout &dataLayout) {
1206+
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
1207+
}
1208+
1209+
bool LLVM::MemsetInlineOp::canRewire(
1210+
const DestructurableMemorySlot &slot,
1211+
SmallPtrSetImpl<Attribute> &usedIndices,
1212+
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1213+
const DataLayout &dataLayout) {
1214+
return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
1215+
dataLayout);
1216+
}
1217+
1218+
DeletionKind
1219+
LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
1220+
DenseMap<Attribute, MemorySlot> &subslots,
1221+
OpBuilder &builder, const DataLayout &dataLayout) {
1222+
return memsetRewire(*this, slot, subslots, builder, dataLayout);
1223+
}
1224+
10841225
//===----------------------------------------------------------------------===//
10851226
// Interfaces for memcpy/memmove
10861227
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)