Skip to content

Commit d66cfa9

Browse files
authored
Fix a quite subtle but nasty bug in linear map tuple type computation (#68413)
We need a lowered type for the branch trace enum in order to compute the linear map tuple type. However, the lowering of the branch trace enum type depends on the types of its elements (the payloads are linear map tuples of the predecessor BBs). As lowered types are cached, we cannot populate the branch trace enum entries at the end as we did before: by then we have already used wrong lowered types for the linear map tuples. Instead, traverse basic blocks in reverse post-order, building linear map tuples and branch tracing enums in one go, which ensures that we are done with the predecessor BBs before processing the BB itself.
1 parent 053c276 commit d66cfa9

File tree

2 files changed

+64
-14
lines changed

2 files changed

+64
-14
lines changed

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
142142
heapAllocatedContext = true;
143143
decl->setInterfaceType(astCtx.TheRawPointerType);
144144
} else { // Otherwise the payload is the linear map tuple.
145-
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
145+
auto *linearMapStructTy = getLinearMapTupleType(predBB);
146+
assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
147+
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
146148
decl->setInterfaceType(
147-
linearMapStructTy->hasArchetype()
148-
? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy);
149+
canLinearMapStructTy->hasArchetype()
150+
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
149151
}
150152
// Create enum element and enum case declarations.
151153
auto *paramList = ParameterList::create(astCtx, {decl});
@@ -331,10 +333,28 @@ void LinearMapInfo::generateDifferentiationDataStructures(
331333
}
332334

333335
// Add linear map fields to the linear map tuples.
334-
for (auto &origBB : *original) {
336+
//
337+
// Now we need to be very careful, as we have a very subtle
338+
// chicken-and-egg problem. We need the lowered branch trace enum type for the
339+
// linear map tuple type. However, branch trace enum type lowering depends on
340+
// the lowering of its elements (at the very least, the type classification of
341+
// being trivial / non-trivial). As the lowering is cached we need to ensure
342+
// we compute lowered type for the branch trace enum when the corresponding
343+
// EnumDecl is fully complete: we cannot add more entries without causing some
344+
// very subtle issues later on. However, the elements of the enum are linear
345+
// map tuples of predecessors, that correspondingly may contain branch trace
346+
// enums of corresponding predecessor BBs.
347+
//
348+
// Traverse all BBs in reverse post-order to ensure we process
349+
// each BB after its predecessors.
350+
llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
351+
for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
352+
auto *origBB = *Iter;
335353
SmallVector<TupleTypeElt, 4> linearTupleTypes;
336-
if (!origBB.isEntry()) {
337-
CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType();
354+
if (!origBB->isEntry()) {
355+
populateBranchingTraceDecl(origBB, loopInfo);
356+
357+
CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
338358
linearTupleTypes.emplace_back(traceEnumType,
339359
astCtx.getIdentifier(traceEnumFieldName));
340360
}
@@ -343,7 +363,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
343363
// Do not add linear map fields for semantic member accessors, which have
344364
// special-case pullback generation. Linear map tuples should be empty.
345365
} else {
346-
for (auto &inst : origBB) {
366+
for (auto &inst : *origBB) {
347367
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
348368
// Add linear map field to struct for active `apply` instructions.
349369
// Skip array literal intrinsic applications since array literal
@@ -363,12 +383,9 @@ void LinearMapInfo::generateDifferentiationDataStructures(
363383
}
364384
}
365385

366-
linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)});
386+
linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
367387
}
368388

369-
for (auto &origBB : *original)
370-
populateBranchingTraceDecl(&origBB, loopInfo);
371-
372389
// Print generated linear map structs and branching trace enums.
373390
// These declarations do not show up with `-emit-sil` because they are
374391
// implicit. Instead, use `-Xllvm -debug-only=differentiation` to test

test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift

+36-3
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ func cond(_ x: Float) -> Float {
5656
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]]
5757
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
5858

59-
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
59+
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
6060
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr
6161
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
6262
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
6363
// CHECK-SIL: return [[VJP_RESULT]]
6464

6565

6666
// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float {
67-
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0):
67+
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0):
6868
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3
6969

7070
// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))):
@@ -132,6 +132,39 @@ func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
132132
return result
133133
}
134134

135+
@differentiable(reverse)
136+
@_silgen_name("loop_context")
137+
func loop_context(x: Float) -> Float {
138+
let y = x + 1
139+
for _ in 0 ..< 1 {}
140+
return y
141+
}
142+
143+
// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context:
144+
// CHECK-DATA-STRUCTURES: (_: (Float) -> Float)
145+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
146+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0)
147+
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0)
148+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 {
149+
// CHECK-DATA-STRUCTURES: }
150+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 {
151+
// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer)
152+
// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float))
153+
// CHECK-DATA-STRUCTURES: }
154+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 {
155+
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
156+
// CHECK-DATA-STRUCTURES: }
157+
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 {
158+
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
159+
// CHECK-DATA-STRUCTURES: }
160+
161+
// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float {
162+
// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer):
163+
// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
164+
// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
165+
// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0))
166+
// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)):
167+
135168
// Test `switch_enum`.
136169

137170
enum Enum {
@@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float {
164197
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float)
165198
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
166199

167-
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
200+
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
168201
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr
169202
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
170203
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)

0 commit comments

Comments
 (0)