Skip to content

Commit bab23d8

Browse files
authored
SILOptimizer: fix partial_apply optimization. (swiftlang#31552)
`partial_apply` can be rewritten to `thin_to_thick_function` only if the specialized callee is `@convention(thin)`. This condition is newly exercised by the differentiation transform: `{JVP,VJP}Emitter::visitApplyInst` generates argument-less `partial_apply` with `@convention(method)` callees. Resolves SR-12732.
1 parent 5d8af8c commit bab23d8

File tree

5 files changed

+35
-14
lines changed

5 files changed

+35
-14
lines changed

lib/SILOptimizer/IPO/CapturePropagation.cpp

+13-8
Original file line numberDiff line numberDiff line change
@@ -460,15 +460,20 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
460460
// arguments are dead?
461461
std::pair<SILFunction *, SILFunction *> GenericSpecialized;
462462
SILOptFunctionBuilder FuncBuilder(*this);
463-
if (auto *NewFunc = getSpecializedWithDeadParams(FuncBuilder,
464-
PAI, SubstF, PAI->getNumArguments(), GenericSpecialized)) {
465-
rewritePartialApply(PAI, NewFunc);
466-
if (GenericSpecialized.first) {
467-
// Notify the pass manager about the new function.
468-
addFunctionToPassManagerWorklist(GenericSpecialized.first,
469-
GenericSpecialized.second);
463+
if (auto *NewFunc = getSpecializedWithDeadParams(FuncBuilder, PAI, SubstF,
464+
PAI->getNumArguments(),
465+
GenericSpecialized)) {
466+
// `partial_apply` can be rewritten to `thin_to_thick_function` only if the
467+
// specialized callee is `@convention(thin)`.
468+
if (NewFunc->getRepresentation() == SILFunctionTypeRepresentation::Thin) {
469+
rewritePartialApply(PAI, NewFunc);
470+
if (GenericSpecialized.first) {
471+
// Notify the pass manager about the new function.
472+
addFunctionToPassManagerWorklist(GenericSpecialized.first,
473+
GenericSpecialized.second);
474+
}
475+
return true;
470476
}
471-
return true;
472477
}
473478

474479
// Second possibility: Are all partially applied arguments constant?
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %target-build-swift -Osize %s
2+
// REQUIRES: asserts
3+
4+
// SR-12732: Fix `partial_apply` optimization.
5+
6+
// Do not rewrite `partial_apply` to `thin_to_thick_function` if the specialized
7+
// callee is not `@convention(thin)`.
8+
9+
import DifferentiationUnittest
10+
11+
func callback(_ x: inout Tracked<Float>.TangentVector) {}
12+
13+
@differentiable
14+
func caller(_ x: Tracked<Float>) -> Tracked<Float> {
15+
return x.withDerivative(callback)
16+
}
17+
18+
// SIL verification failed: operand of thin_to_thick_function must be thin: opFTy->getRepresentation() == SILFunctionType::Representation::Thin
19+
// Verifying instruction:
20+
// // function_ref specialized Differentiable._vjpWithDerivative(_:)
21+
// %10 = function_ref @$s16_Differentiation14DifferentiablePAAE18_vjpWithDerivativeyx5value_13TangentVectorQzAGc8pullbacktyAGzcF0A8Unittest7TrackedVySfG_Tg5 : $@convention(method) (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) // user: %11
22+
// -> %11 = thin_to_thick_function %10 : $@convention(method) (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) to $@callee_guaranteed (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) // user: %12

test/AutoDiff/stdlib/derivative_customization.swift

-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

4-
// REQUIRES: SR12732
5-
64
import DifferentiationUnittest
75
import StdlibUnittest
86

test/AutoDiff/validation-test/custom_derivatives.swift

-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

4-
// REQUIRES: SR12732
5-
64
import StdlibUnittest
75
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
86
import Darwin.C

test/AutoDiff/validation-test/derivative_registration.swift

-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

4-
// REQUIRES: SR12732
5-
64
import StdlibUnittest
75
import DifferentiationUnittest
86

0 commit comments

Comments
 (0)