@@ -1110,39 +1110,42 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
1110
1110
// %vjp' = convert_escape_to_noescape %vjp
1111
1111
// %y = differentiable_function(%orig', %jvp', %vjp')
1112
1112
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getOperand ())) {
1113
- auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1114
- if (!DFI->hasExtractee (extractee))
1115
- return SILValue ();
1113
+ if (DFI->hasOneUse ()) {
1114
+ auto createConvertEscapeToNoEscape =
1115
+ [&](NormalDifferentiableFunctionTypeComponent extractee) {
1116
+ if (!DFI->hasExtractee (extractee))
1117
+ return SILValue ();
1116
1118
1117
- auto operand = DFI->getExtractee (extractee);
1118
- auto fnType = operand->getType ().castTo <SILFunctionType>();
1119
- auto noEscapeFnType =
1120
- fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1121
- auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1122
- return Builder.createConvertEscapeToNoEscape (
1123
- operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1124
- };
1119
+ auto operand = DFI->getExtractee (extractee);
1120
+ auto fnType = operand->getType ().castTo <SILFunctionType>();
1121
+ auto noEscapeFnType =
1122
+ fnType->getWithExtInfo (fnType->getExtInfo ().withNoEscape ());
1123
+ auto noEscapeType = SILType::getPrimitiveObjectType (noEscapeFnType);
1124
+ return Builder.createConvertEscapeToNoEscape (
1125
+ operand.getLoc (), operand, noEscapeType, Cvt->isLifetimeGuaranteed ())->getResult (0 );
1126
+ };
1125
1127
1126
- SILValue originalNoEscape =
1127
- createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1128
- SILValue convertedJVP = createConvertEscapeToNoEscape (
1129
- NormalDifferentiableFunctionTypeComponent::JVP);
1130
- SILValue convertedVJP = createConvertEscapeToNoEscape (
1131
- NormalDifferentiableFunctionTypeComponent::VJP);
1132
-
1133
- llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1134
- if (convertedJVP && convertedVJP)
1135
- derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1136
-
1137
- auto *newDFI = Builder.createDifferentiableFunction (
1138
- DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1139
- originalNoEscape, derivativeFunctions);
1140
- assert (newDFI->getType () == Cvt->getType () &&
1141
- " New `@differentiable` function instruction should have same type "
1142
- " as the old `convert_escape_to_no_escape` instruction" );
1143
- return newDFI;
1144
- }
1128
+ SILValue originalNoEscape =
1129
+ createConvertEscapeToNoEscape (NormalDifferentiableFunctionTypeComponent::Original);
1130
+ SILValue convertedJVP = createConvertEscapeToNoEscape (
1131
+ NormalDifferentiableFunctionTypeComponent::JVP);
1132
+ SILValue convertedVJP = createConvertEscapeToNoEscape (
1133
+ NormalDifferentiableFunctionTypeComponent::VJP);
1134
+
1135
+ llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1136
+ if (convertedJVP && convertedVJP)
1137
+ derivativeFunctions = std::make_pair (convertedJVP, convertedVJP);
1145
1138
1139
+ auto *newDFI = Builder.createDifferentiableFunction (
1140
+ DFI->getLoc (), DFI->getParameterIndices (), DFI->getResultIndices (),
1141
+ originalNoEscape, derivativeFunctions);
1142
+ assert (newDFI->getType () == Cvt->getType () &&
1143
+ " New `@differentiable` function instruction should have same type "
1144
+ " as the old `convert_escape_to_no_escape` instruction" );
1145
+ return newDFI;
1146
+ }
1147
+ }
1148
+
1146
1149
return nullptr ;
1147
1150
}
1148
1151
0 commit comments