Skip to content

Commit d0c6a88

Browse files
authored
Fix ret (rust-lang#469)
* fix handling of const return activity * zero differetval if ret is const * fix Constant case * fix formating
1 parent 8830204 commit d0c6a88

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,19 +3432,21 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
34323432
replacedReturns[orig] = si;
34333433
}
34343434

3435-
if (dretAlloca && !gutils->isConstantValue(orig->getReturnValue())) {
3436-
rb.CreateStore(gutils->invertPointerM(orig->getReturnValue(), rb),
3437-
dretAlloca);
3438-
}
3439-
if (key.retType == DIFFE_TYPE::OUT_DIFF) {
3435+
if (key.retType == DIFFE_TYPE::DUP_ARG ||
3436+
key.retType == DIFFE_TYPE::DUP_NONEED) {
3437+
if (dretAlloca) {
3438+
rb.CreateStore(gutils->invertPointerM(orig->getReturnValue(), rb),
3439+
dretAlloca);
3440+
}
3441+
} else if (key.retType == DIFFE_TYPE::OUT_DIFF) {
34403442
assert(orig->getReturnValue());
34413443
assert(differetval);
34423444
if (!gutils->isConstantValue(orig->getReturnValue())) {
34433445
IRBuilder<> reverseB(gutils->reverseBlocks[BB].back());
34443446
gutils->setDiffe(orig->getReturnValue(), differetval, reverseB);
34453447
}
34463448
} else {
3447-
assert(retAlloca == nullptr);
3449+
assert(dretAlloca == nullptr);
34483450
}
34493451

34503452
rb.CreateBr(gutils->reverseBlocks[BB].front());

0 commit comments

Comments
 (0)