Skip to content

Commit 26bb94a

Browse files
Mingsheng HongMarc Rasi
Mingsheng Hong
authored and
Marc Rasi
committed
Fixed retain/release related test failures, caused by the change of swift (#12)
calling convention, where arguments are passed in as @guaranteed (+0) instead of @owned (+1). As such, the changes to the TFPartition pass are: 1. When sinking a special instruction foo(%x) (tf_send, tf_receive, tf_get_scalar_or_die) below the tensor end point, where %x is a tensor handle, make sure we keep a strong_retain in its original inst position, and sink the instruction along with a strong_release below the tensor end point. Example code snippet: %x = ... foo(%x) ... <tensor end point> strong_release %x The transformed code after sinking foo(%x) is: %x = ... strong_retain %x ... <tensor end point> foo(%x) strong_release %x strong_release %x 2. When removing a copy marker inst (tf_send, tf_receive) from the host code, add a strong_retain to balance the refcount. Example code snippet: %x = ... %y = tf_send(%x) strong_release %y strong_release %x The transformed code after removing tf_send(%x) is: %x = ... strong_retain %x strong_release %x strong_release %x 3. Also addressed Richard's code formatting suggestions in a previous PR.
1 parent 7435eb8 commit 26bb94a

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

Diff for: lib/SILOptimizer/Mandatory/TFPartition.cpp

+40-5
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,17 @@ void TFFunctionPartition::markArgument(SILArgument *arg, SILInstruction *user) {
13291329
/// Determine whether we are able to move the specified instruction across
13301330
/// arbitrary other instructions. This is basically "side effect free" in the
13311331
/// most liberal sense.
1332-
static bool canMoveInstruction(SILInstruction *inst) {
1332+
///
1333+
/// If `plusZeroTensorOperand` and `*plusZeroTensorOperand` are both non-NULL,
1334+
/// when the call returns true, the latter is a TensorHandle-typed operand `o`
1335+
/// where `inst` takes that operand at +0. In that case, simply sinking `inst`
1336+
/// can lead to `o` being deallocated before `inst`. As such, in order for
1337+
/// caller to sink `inst`, caller should generate a pair of retain/release on
1338+
/// `o` to wrap `inst`, and sink `inst` along with the release.
1339+
static bool canMoveInstruction(SILInstruction *inst,
1340+
SILValue *plusZeroTensorOperand) {
1341+
if (plusZeroTensorOperand) *plusZeroTensorOperand = SILValue();
1342+
13331343
// Instructions that SIL knows are always side-effect-free can generally be
13341344
// moved.
13351345
if (inst->getMemoryBehavior() == SILInstruction::MemoryBehavior::None) {
@@ -1346,11 +1356,18 @@ static bool canMoveInstruction(SILInstruction *inst) {
13461356
if (isa<PartialApplyInst>(inst))
13471357
return true;
13481358

1359+
SILValue tensorOperand;
13491360
switch (classifyInst(inst)) {
13501361
case PartitioningClass::GetScalarOrDie:
1351-
case PartitioningClass::Hoistable:
13521362
case PartitioningClass::ExplicitSend:
13531363
case PartitioningClass::ExplicitReceive:
1364+
// For these functions, we know the first parameter is a TensorHandle,
1365+
// with @guaranteed calling convention.
1366+
tensorOperand = inst->getOperand(1);
1367+
if (plusZeroTensorOperand)
1368+
*plusZeroTensorOperand = tensorOperand;
1369+
LLVM_FALLTHROUGH;
1370+
case PartitioningClass::Hoistable:
13541371
return true;
13551372
default:
13561373
return false;
@@ -1372,7 +1389,7 @@ static bool hoistValueAboveStartPoint(SILInstruction *inst,
13721389

13731390
// In general, we need to check to see if we have a chain of side-effect free
13741391
// instructions whose ultimate inputs dominate the start point.
1375-
if (canMoveInstruction(inst)) {
1392+
if (canMoveInstruction(inst, /*plusZeroTensorOperand*/ nullptr)) {
13761393
// We can hoist one of these instructions if all of their operands are
13771394
// hoistable.
13781395
for (auto &op : inst->getAllOperands()) {
@@ -1408,7 +1425,8 @@ static bool sinkValueAfterEndPoint(SILInstruction *inst,
14081425

14091426
// In general, we need to check to see if we have a chain of side-effect free
14101427
// instructions whose ultimate results can all be sunk after the endpoint.
1411-
if (canMoveInstruction(inst)) {
1428+
SILValue plusZeroTensorOperand;
1429+
if (canMoveInstruction(inst, &plusZeroTensorOperand)) {
14121430
// Make sure the end point is dominated by any operands.
14131431
//
14141432
// TODO: We could make this more aggressive through various techniques, e.g.
@@ -1425,12 +1443,25 @@ static bool sinkValueAfterEndPoint(SILInstruction *inst,
14251443

14261444
// If all of the uses are sunk after the end point, then this
14271445
// instruction can be too.
1446+
if (plusZeroTensorOperand) {
1447+
SILBuilder B(inst);
1448+
B.createStrongRetain(inst->getLoc(), plusZeroTensorOperand,
1449+
Atomicity::Atomic);
1450+
}
14281451

14291452
// The tensorEndPoint is the first non-tensor instruction in the program.
14301453
// Insert our sunk instruction immediately before it, and this instruction
14311454
// becomes the new end point.
14321455
inst->moveBefore(tensorEndPoint);
14331456
tensorEndPoint = inst;
1457+
if (plusZeroTensorOperand) {
1458+
// Create strong_release right after `inst`.
1459+
SILBuilder B(inst);
1460+
auto *releaseInst =
1461+
B.createStrongRelease(tensorEndPoint->getLoc(), plusZeroTensorOperand,
1462+
Atomicity::Atomic);
1463+
releaseInst->moveAfter(inst);
1464+
}
14341465
return true;
14351466
}
14361467

@@ -3024,7 +3055,11 @@ bool PartitionCloner::finalizeOriginal() {
30243055
assert(isa<ApplyInst>(ecm) && ecm->getResults().size() == 1 &&
30253056
ecm->getNumOperands() == 2 && "unknown copy in/out instruction");
30263057
auto callee = ecm->getOperand(1);
3027-
ecm->getResults()[0]->replaceAllUsesWith(ecm->getOperand(1));
3058+
// `ecm` takes `callee` at +0 and returns a result at +1, so we need to
3059+
// issue a retain to balance the removal of `ecm`.
3060+
SILBuilder B(ecm);
3061+
B.createStrongRetain(ecm->getLoc(), callee, Atomicity::Atomic);
3062+
ecm->getResults()[0]->replaceAllUsesWith(callee);
30283063
ecm->eraseFromParent();
30293064

30303065
if (callee->use_empty()) // Remove the function_ref too.

Diff for: lib/SILOptimizer/Mandatory/TFUtilities.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -617,14 +617,13 @@ SingleValueInstruction *SILTensorOpInfo::getAttrOperand(SILValue v) {
617617
// %0 = string_literal utf8 "foo"
618618
// // function_ref specialized String.init(
619619
// _builtinStringLiteral:utf8CodeUnitCount:isASCII:)
620-
// function_ref @$SSS21_builtinStringLiteral... : $@convention(thin) (
620+
// function_ref @$SSS21_builtinStringLiteral... : $@convention(thin) (
621621
// Builtin.RawPointer...) -> @owned String
622-
// %4 = apply %3(%0, ...
622+
// %4 = apply %3(%0, ...
623623
// So we want to follow the first func arg of the ApplyInst (%0 above).
624624
if (auto *ai = dyn_cast<ApplyInst>(str)) {
625625
// If the ApplyInst does not have such an operand, we bail with failure.
626626
if (ai->getNumOperands() < 2) return nullptr;
627-
628627
str = ai->getOperand(1);
629628
continue;
630629
}

0 commit comments

Comments
 (0)