Skip to content

Commit 118fd86

Browse files
authored
Ensure we can fold apply of a differentiable_function_inst. Also fixes one small potential issue while there. (#65605)
Fixes #65489 #67992
1 parent 8a2aeb9 commit 118fd86

File tree

3 files changed

+88
-30
lines changed

3 files changed

+88
-30
lines changed

lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,13 @@ SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) {
15001500
}
15011501
}
15021502

1503+
// (apply (differentiable_function f)) to (apply f)
1504+
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(AI->getCallee())) {
1505+
return cloneFullApplySiteReplacingCallee(AI, DFI->getOperand(0),
1506+
Builder.getBuilderContext())
1507+
.getInstruction();
1508+
}
1509+
15031510
// (apply (thin_to_thick_function f)) to (apply f)
15041511
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(AI->getCallee())) {
15051512
// We currently don't remove any possible retain associated with the thick

lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp

+33-30
Original file line numberDiff line numberDiff line change
@@ -1110,39 +1110,42 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
11101110
// %vjp' = convert_escape_to_noescape %vjp
11111111
// %y = differentiable_function(%orig', %jvp', %vjp')
11121112
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getOperand())) {
1113-
auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) {
1114-
if (!DFI->hasExtractee(extractee))
1115-
return SILValue();
1113+
if (DFI->hasOneUse()) {
1114+
auto createConvertEscapeToNoEscape =
1115+
[&](NormalDifferentiableFunctionTypeComponent extractee) {
1116+
if (!DFI->hasExtractee(extractee))
1117+
return SILValue();
11161118

1117-
auto operand = DFI->getExtractee(extractee);
1118-
auto fnType = operand->getType().castTo<SILFunctionType>();
1119-
auto noEscapeFnType =
1120-
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1121-
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1122-
return Builder.createConvertEscapeToNoEscape(
1123-
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1124-
};
1119+
auto operand = DFI->getExtractee(extractee);
1120+
auto fnType = operand->getType().castTo<SILFunctionType>();
1121+
auto noEscapeFnType =
1122+
fnType->getWithExtInfo(fnType->getExtInfo().withNoEscape());
1123+
auto noEscapeType = SILType::getPrimitiveObjectType(noEscapeFnType);
1124+
return Builder.createConvertEscapeToNoEscape(
1125+
operand.getLoc(), operand, noEscapeType, Cvt->isLifetimeGuaranteed())->getResult(0);
1126+
};
11251127

1126-
SILValue originalNoEscape =
1127-
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1128-
SILValue convertedJVP = createConvertEscapeToNoEscape(
1129-
NormalDifferentiableFunctionTypeComponent::JVP);
1130-
SILValue convertedVJP = createConvertEscapeToNoEscape(
1131-
NormalDifferentiableFunctionTypeComponent::VJP);
1132-
1133-
llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1134-
if (convertedJVP && convertedVJP)
1135-
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
1136-
1137-
auto *newDFI = Builder.createDifferentiableFunction(
1138-
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1139-
originalNoEscape, derivativeFunctions);
1140-
assert(newDFI->getType() == Cvt->getType() &&
1141-
"New `@differentiable` function instruction should have same type "
1142-
"as the old `convert_escape_to_no_escape` instruction");
1143-
return newDFI;
1144-
}
1128+
SILValue originalNoEscape =
1129+
createConvertEscapeToNoEscape(NormalDifferentiableFunctionTypeComponent::Original);
1130+
SILValue convertedJVP = createConvertEscapeToNoEscape(
1131+
NormalDifferentiableFunctionTypeComponent::JVP);
1132+
SILValue convertedVJP = createConvertEscapeToNoEscape(
1133+
NormalDifferentiableFunctionTypeComponent::VJP);
1134+
1135+
llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions;
1136+
if (convertedJVP && convertedVJP)
1137+
derivativeFunctions = std::make_pair(convertedJVP, convertedVJP);
11451138

1139+
auto *newDFI = Builder.createDifferentiableFunction(
1140+
DFI->getLoc(), DFI->getParameterIndices(), DFI->getResultIndices(),
1141+
originalNoEscape, derivativeFunctions);
1142+
assert(newDFI->getType() == Cvt->getType() &&
1143+
"New `@differentiable` function instruction should have same type "
1144+
"as the old `convert_escape_to_no_escape` instruction");
1145+
return newDFI;
1146+
}
1147+
}
1148+
11461149
return nullptr;
11471150
}
11481151

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %target-swift-frontend -emit-sil -O %s | %FileCheck %s
2+
// REQUIRES: swift_in_compiler
3+
4+
import _Differentiation
5+
6+
@differentiable(reverse)
7+
@_silgen_name("test_f")
8+
// Check that (differentiable) closure apply is optimized out
9+
// CHECK-LABEL: test_f : $@convention(thin) (@guaranteed Array<Double>) -> Double
10+
// CHECK-NOT: differentiable_function [parameters 0] [results 0]
11+
func f(array: [Double]) -> Double {
12+
var array = array
13+
array.update(at: 1,
14+
byCalling: {
15+
(element: inout Double) in
16+
let initialElement = element;
17+
element *= initialElement
18+
}
19+
)
20+
21+
return 0.0
22+
}
23+
24+
public func valueWithPullback<T>(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (value: Void, pullback: (inout T.TangentVector) -> Void) {fatalError()}
25+
public func pullback<T>(at x: T, of f: @differentiable(reverse) (inout T) -> Void) -> (inout T.TangentVector) -> Void {return valueWithPullback(at: x, of: f).pullback}
26+
27+
public extension Array {
28+
@differentiable(reverse)
29+
mutating func update(at index: Int,
30+
byCalling closure: @differentiable(reverse) (inout Element) -> Void) where Element: Differentiable {
31+
closure(&self[index])
32+
}
33+
}
34+
35+
public extension Array where Element: Differentiable {
36+
@derivative(of: update(at:byCalling:))
37+
mutating func vjpUpdate(at index: Int, byCalling closure: @differentiable(reverse) (inout Element) -> Void) -> (value: Void, pullback: (inout Self.TangentVector) -> Void) {
38+
let closurePullback = pullback(at: self[index], of: closure)
39+
return (value: (), pullback: { closurePullback(&$0.base[index]) })
40+
}
41+
}
42+
43+
public struct D<I: Equatable, D> {
44+
public subscript(_ index: I) -> D? {
45+
get {fatalError()}
46+
set {fatalError()}
47+
}
48+
}

0 commit comments

Comments
 (0)