[AutoDiff] Fix quite subtle but nasty bug in linear map tuple types computation #68413

Merged: asl merged 1 commit into main from 68392-fix on Sep 19, 2023


@asl asl commented Sep 9, 2023

We need the lowered type of the branch trace enum in order to compute the linear map tuple type. However, the lowering of the branch trace enum type depends on the types of its elements (the payloads are the linear map tuples of the predecessor BBs).

As lowered types are cached, we cannot populate the branch trace enum entries at the end as we did before: by then we have already used the wrong lowered types for the linear map tuples.

Traverse basic blocks in reverse post-order, building linear map tuples and branch tracing enums in one go, ensuring that we are done with the predecessor BBs before processing the BB itself.

Fixes #68392

@asl asl added the AutoDiff label Sep 9, 2023
@asl asl requested a review from rxwei September 9, 2023 02:29

asl commented Sep 9, 2023

Tagging @BradLarson @jkshtj


asl commented Sep 9, 2023

@swift-ci please test

@asl asl changed the title Fix quite subtle but nasty bug in linear map tuple types computation [AutoDiff] Fix quite subtle but nasty bug in linear map tuple types computation Sep 9, 2023

asl commented Sep 9, 2023

To give more context: previously we calculated the lowered type of the enum from an empty EnumDecl. As a result, all these SIL types were deemed trivial (unless, e.g., generic) despite us adding non-trivial enum cases later on.

@@ -331,10 +333,24 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}

// Add linear map fields to the linear map tuples.
for (auto &origBB : *original) {
// Now we need to be very careful as we're having a very subtle
Contributor

Could you please explain why the reverse post-order traversal (RPOT) is needed and how it helps?

I understand that we need the branch trace enum for a BB to be fully constructed before using it to derive the type for the linear map tuple of the same BB. But I don't understand how going in RPOT order is going to ensure that we have the right types for the branch trace enums?

I think so because while fully constructing a branch trace enum declaration in populateBranchingTraceDecl we look up the linear map tuple types of the predecessor BBs, but these types haven't been set yet.

Contributor Author

When traversing BBs in RPOT we ensure that each BB is processed after its predecessors. As a result, we know that all linear map tuples for predecessor BBs are already finalized and therefore the enum type that we created will also be "complete" – we will not need to add any entries later.

The problem was not linear map tuples, but that we passed "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to be always trivial (despite non-trivial payloads added afterwards, the lowered type will be cached and not re-calculated).
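The caching effect described above can be reproduced with a toy model. This is only a sketch: `EnumDecl`, `TypeLowering`, and `is_trivial` here are hypothetical Python stand-ins, not the compiler's classes, and "trivial" is reduced to a single boolean per payload.

```python
# Toy model (hypothetical names, not the Swift compiler's API): a mutable enum
# declaration and a lowering cache that assumes its inputs never change.

class EnumDecl:
    def __init__(self, name):
        self.name = name
        self.cases = []  # list of (case_name, payload_is_trivial)

    def add_case(self, case_name, payload_is_trivial):
        self.cases.append((case_name, payload_is_trivial))


class TypeLowering:
    def __init__(self):
        self._cache = {}

    def is_trivial(self, decl):
        # Computed once per declaration, then served from the cache.
        if decl.name not in self._cache:
            self._cache[decl.name] = all(trivial for _, trivial in decl.cases)
        return self._cache[decl.name]


def lower_then_populate():
    """Old order: lower the still-empty enum, add cases afterwards."""
    lowering = TypeLowering()
    decl = EnumDecl("trace_enum")
    early = lowering.is_trivial(decl)   # no cases yet, so it looks trivial
    decl.add_case("pred", payload_is_trivial=False)
    late = lowering.is_trivial(decl)    # stale answer from the cache
    return early, late


def populate_then_lower():
    """New order: the declaration is complete before it is lowered."""
    lowering = TypeLowering()
    decl = EnumDecl("trace_enum")
    decl.add_case("pred", payload_is_trivial=False)
    return lowering.is_trivial(decl)
```

In this model `lower_then_populate()` reports the enum as trivial both before and after the non-trivial case is added, while `populate_then_lower()` reports it as non-trivial.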

Contributor

> When traversing BBs in RPOT we ensure that each BB is processed after its predecessors.

Ah gotcha. I was mistaking RPOT for a regular post-order traversal that just visits (Right->Left->Node) instead of (Left->Right->Node). But it seems like it's more or less a "pre-order" traversal.

Contributor Author

@asl asl Sep 11, 2023

No, it's neither a pre-order nor a post-order, so it is not a plain DFS visiting order. A plain DFS order is not enough because we have a DAG, not a tree. Consider the CFG A -> B -> C -> D; B -> D. We need to visit D after both B and C. RPOT is quite a standard technique for various data-flow problems (and can be used to compute a topological ordering of the graph).
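To make the A/B/C/D example concrete, here is a small sketch. One common way to obtain a reverse post-order is to record a depth-first post-order and reverse it; the resulting visiting order differs from a depth-first pre-order. The function names and the dictionary encoding of the CFG are assumptions made for illustration, and the predecessor-first guarantee only holds for acyclic graphs (back edges of loops are not modeled here).

```python
def reverse_post_order(successors, entry):
    """Return the blocks reachable from `entry` in reverse post-order."""
    visited = set()
    post_order = []

    def visit(block):
        visited.add(block)
        for succ in successors[block]:
            if succ not in visited:
                visit(succ)
        # A block is emitted only after everything reachable from it.
        post_order.append(block)

    visit(entry)
    return post_order[::-1]


def dfs_pre_order(successors, entry):
    """Return the blocks in plain depth-first pre-order, for comparison."""
    visited = set()
    order = []

    def visit(block):
        visited.add(block)
        order.append(block)
        for succ in successors[block]:
            if succ not in visited:
                visit(succ)

    visit(entry)
    return order


# The CFG from the comment above: A -> B -> C -> D plus the edge B -> D.
# B lists D first so that a naive pre-order walk reaches D too early.
CFG = {"A": ["B"], "B": ["D", "C"], "C": ["D"], "D": []}
```

On this graph `dfs_pre_order(CFG, "A")` visits D before C, whereas `reverse_post_order(CFG, "A")` visits D only after both B and C.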

Contributor

That makes sense, thanks for the explanation!

Contributor

> The problem was not linear map tuples, but that we passed "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to be always trivial (despite non-trivial payloads added afterwards, the lowered type will be cached and not re-calculated).

You're not supposed to mutate EnumDecls like that at all and changing the order of iteration is only papering over the issue. Type lowering caches the results because it assumes the inputs are immutable. You need to build your AST before you get to SIL.

asl commented Sep 11, 2023

@swift-ci please test

asl commented Sep 11, 2023

@swift-ci please test macos


asl commented Sep 11, 2023

@slavapestov

> You're not supposed to mutate EnumDecls like that at all and changing the order of iteration is only papering over the issue. Type lowering caches the results because it assumes the inputs are immutable. You need to build your AST before you get to SIL.

Exactly! And this is what the bug fix does. Now we build the enum AST types before the corresponding SIL types, and the iteration order enforces this.

It is not really possible to build all the AST types before SILGen because of the way autodiff works. In particular, the element types of these enums correspond to the SIL BBs. Do you see any sane way to do this before SILGen?

Tagging @rxwei


slavapestov commented Sep 11, 2023

> It is not really possible to build all the AST types before SILGen because of the way autodiff works. In particular, the element types of these enums correspond to the SIL BBs. Do you see any sane way to do this before SILGen?

Can you enforce this by collecting all the constituents that would go into building the enum somewhere during SILGen, then create the EnumDecl and populate it in one shot, before handing it off? That way you can never access the incomplete declaration.

(That is, pretend EnumDecl was immutable after construction, and write your code like that.)
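As a rough illustration of the "collect first, then build in one shot" idea, the sketch below gathers the cases in a plain list and only then constructs a frozen declaration. `EnumCase`, `FrozenEnumDecl`, and `build_enum` are hypothetical names for this toy model and do not correspond to any compiler API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EnumCase:
    name: str
    payload: tuple  # field types of the payload, as plain strings here


@dataclass(frozen=True)
class FrozenEnumDecl:
    name: str
    cases: tuple  # fixed at construction time


def build_enum(name, collected_cases):
    """Create the declaration only once every case is known.

    `collected_cases` is an iterable of (case_name, payload_fields) pairs
    gathered beforehand; nothing can be appended afterwards.
    """
    cases = tuple(EnumCase(case_name, tuple(fields))
                  for case_name, fields in collected_cases)
    return FrozenEnumDecl(name, cases)
```

Because both dataclasses are frozen, any later attempt to reassign `cases` raises an error, so consumers of the declaration can never observe a half-populated enum.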


asl commented Sep 11, 2023

@slavapestov

> Can you enforce this by collecting all the constituents that would go into building the enum somewhere during SILGen, then create the EnumDecl and populate it in one shot, before handing it off? That way you can never access the incomplete declaration.

OK, this is what I tried to explain in the comment (the chicken-and-egg problem), but it seems the explanation was confusing. Let me explain what is going on in detail.

Suppose we have the following SIL function:

entry:
...
 apply %f
 apply %g
 cond_br %cond, bb1, bb2
bb1:
 br bb3
bb2:
 apply %h
 br bb3
bb3:
 apply %j

In order to generate the derivative of this function we need to:

  • Capture active values (pullbacks of active calls)
  • Trace the execution flow

This is done via so-called linear map tuples. We create a tuple for each BB that captures the pullbacks of its active calls. Also, for every BB except the entry we create a branch tracing enum whose cases correspond to the predecessor BBs and whose payloads are the linear map tuples of those predecessors.

So, for the function above we will have:

  • For entry, the linear map tuple would be (pullback of %f, pullback of %g) (assuming these calls are active)
  • For bb1, the linear map tuple would be (enum1)
  • For bb2, the linear map tuple would be (enum2, pullback of %h)
  • For bb3, the linear map tuple would be (enum3, pullback of %j)

The branch tracing enums would be:

enum enum1 {
  case entry((pullback of %f, pullback of %g)) // linear map tuple of entry
}

enum enum2 {
  case entry((pullback of %f, pullback of %g)) // linear map tuple of entry
}

enum enum3 {
  case bb1((enum1)) // linear map tuple of bb1
  case bb2((enum2, pullback of %h)) // linear map tuple of bb2
}

Note that the linear map tuple of bb3 is built in a "telescopic" fashion, recursively capturing the linear map tuples (with pullbacks and branch tracing enums) of its predecessors. We fill these tuples (and enums) in the forward derivative pass and consume them in reverse order in the reverse pass.
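The forward-fill and backward-use pattern can be mimicked at run time with a small numeric toy. Everything concrete here is an assumption chosen for illustration: the calls are taken to be f(x) = x*x, g(a) = 3*a, h(b) = 5*b, j(c) = 2*c, the branch condition is x > 0, and enum values are modeled as (case name, payload) pairs. It is not what the compiler generates, only a sketch of the same nesting.

```python
def forward(x):
    """Run the example CFG and return (result, linear map tuple of bb3)."""
    # entry: two active calls; their pullbacks form the entry tuple.
    a, pb_f = x * x, (lambda dy: 2 * x * dy)
    b, pb_g = 3 * a, (lambda dy: 3 * dy)
    entry_tuple = (pb_f, pb_g)

    if x > 0:
        # bb1: no active calls, only the trace of where we came from.
        bb1_tuple = (("entry", entry_tuple),)
        c = b
        trace = ("bb1", bb1_tuple)
    else:
        # bb2: one active call in addition to the trace.
        c, pb_h = 5 * b, (lambda dy: 5 * dy)
        bb2_tuple = (("entry", entry_tuple), pb_h)
        trace = ("bb2", bb2_tuple)

    # bb3: its tuple nests the predecessor's tuple inside the enum payload.
    result, pb_j = 2 * c, (lambda dy: 2 * dy)
    return result, (trace, pb_j)


def reverse(bb3_tuple, seed=1.0):
    """Walk the nested tuples backwards, applying pullbacks in reverse."""
    trace, pb_j = bb3_tuple
    grad = pb_j(seed)
    pred, pred_tuple = trace
    if pred == "bb2":
        enum2, pb_h = pred_tuple
        grad = pb_h(grad)
        _, entry_tuple = enum2
    else:
        (enum1,) = pred_tuple
        _, entry_tuple = enum1
    pb_f, pb_g = entry_tuple
    return pb_f(pb_g(grad))
```

For example, `forward(2.0)` takes the bb1 path and `reverse` of its tuple yields the derivative of 6*x*x at 2.0, while `forward(-1.0)` takes the bb2 path and yields the derivative of 30*x*x at -1.0.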

So, in order to construct a linear map tuple for a BB we need to know:

  • All pullbacks of this BB
  • The linear map tuple types of all predecessor BBs

And this is essentially what this PR does now:

  • We iterate over all BBs in reverse post-order
  • When we see a BB, we know that all its predecessors have already been processed
  • Since all its predecessor BBs are processed, we already have the linear map tuple types of the predecessors
  • We create the EnumDecl, add all cases for the predecessor BBs, and lower the AST type to a SIL type
  • We create the linear map tuple type for the current BB using the newly created EnumDecl plus all the applies of the BB. The corresponding EnumDecl is immutable now; we never touch it again (there is an assert to check that all linear map tuple types of predecessor BBs are known by this moment).
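The steps above can be sketched on the example function as follows. This is a simplified model under stated assumptions: types are plain strings and tuples, the enum naming scheme `enum_<bb>` is made up (the explanation above calls them enum1, enum2, enum3), the CFG is assumed to be acyclic, and none of the function names exist in the compiler.

```python
def reverse_post_order(successors, entry):
    """Blocks reachable from `entry`, each after all of its predecessors
    (valid for acyclic CFGs only)."""
    visited, post_order = set(), []

    def visit(block):
        visited.add(block)
        for succ in successors[block]:
            if succ not in visited:
                visit(succ)
        post_order.append(block)

    visit(entry)
    return post_order[::-1]


def build_linear_map_types(successors, entry, active_calls):
    """Build tuple types and branch tracing enums in a single RPO walk."""
    preds = {bb: [] for bb in successors}
    for bb, succs in successors.items():
        for succ in succs:
            preds[succ].append(bb)

    tuple_types, enum_decls = {}, {}
    for bb in reverse_post_order(successors, entry):
        fields = []
        if preds[bb]:
            # Every predecessor tuple type must already be final.
            assert all(p in tuple_types for p in preds[bb])
            enum_name = f"enum_{bb}"
            # The enum is created complete and never modified afterwards.
            enum_decls[enum_name] = tuple((p, tuple_types[p]) for p in preds[bb])
            fields.append(enum_name)
        fields.extend(f"pullback of {callee}" for callee in active_calls.get(bb, []))
        tuple_types[bb] = tuple(fields)
    return tuple_types, enum_decls


# The example function: entry branches to bb1/bb2, which both reach bb3.
SUCCESSORS = {"entry": ["bb1", "bb2"], "bb1": ["bb3"], "bb2": ["bb3"], "bb3": []}
ACTIVE_CALLS = {"entry": ["%f", "%g"], "bb2": ["%h"], "bb3": ["%j"]}
```

Calling `build_linear_map_types(SUCCESSORS, "entry", ACTIVE_CALLS)` reproduces the shapes listed earlier: the bb3 tuple holds its enum plus the pullback of %j, and that enum has one case per predecessor carrying the predecessor's tuple.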

Previously we populated EnumDecls at the very end, and this obviously caused some subtle ownership issues :)

I hope this is a bit clearer and explains what is going on. If you have other suggestions for how things might be organized differently (ideally without going via AST types), we would like to hear them! Unfortunately, we have to create new types inside the SIL autodiff passes, and this causes some churn here and there. It would be perfect if we did not need to go this way.


asl commented Sep 11, 2023

@swift-ci please test macos

@asl asl requested a review from slavapestov September 13, 2023 04:10

asl commented Sep 13, 2023

@slavapestov Could you please let me know if you have further comments on the explanation above? Thanks!


asl commented Sep 18, 2023

@slavapestov ping :)

@asl asl merged commit d66cfa9 into main Sep 19, 2023
@asl asl deleted the 68392-fix branch September 19, 2023 20:24

asl commented Sep 19, 2023

@slavapestov Thanks!

Successfully merging this pull request may close these issues.

[AutoDiff] Pullbacks w/ loops can segfault
3 participants