[AutoDiff] Fix quite subtle but nasty bug in linear map tuple types computation #68413
Tagging @BradLarson @jkshtj
@swift-ci please test
To give more context: previously we calculated lowered type of Enum out of empty EnumDecl. As a result, all these SIL types were deemed as |
@@ -331,10 +333,24 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}

// Add linear map fields to the linear map tuples.
for (auto &origBB : *original) {
// Now we need to be very careful as we're having a very subtle
Could you please explain why the reverse post-order traversal (RPOT) is needed and how it helps?
I understand that we need the branch trace enum for a BB to be fully constructed before using it to derive the type for the linear map tuple of the same BB. But I don't understand how traversing in RPOT ensures that we have the right types for the branch trace enums.
I ask because, while fully constructing a branch trace enum declaration in populateBranchingTraceDecl, we look up the linear map tuple types of the predecessor BBs, but these types haven't been set yet.
When traversing BBs in RPOT we ensure that each BB is processed after its predecessors. As a result, we know that all linear map tuples for predecessor BBs are already finalized, and therefore the enum type that we create will also be "complete": we will not need to add any entries later.
The problem was not the linear map tuples, but that we passed an "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to always be trivial (despite non-trivial payloads being added afterwards, the lowered type is cached and not recalculated).
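The caching hazard described above can be illustrated with a small self-contained sketch. Everything in it (`ToyPayload`, `ToyEnumDecl`, `ToyLoweringCache`, the enum name) is a hypothetical stand-in invented for illustration and is not the Swift compiler's type lowering; it only assumes that a per-declaration result is memoized on first query.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins; none of these are real Swift compiler types.
struct ToyPayload {
  bool trivial; // true if the payload needs no ownership handling
};

struct ToyEnumDecl {
  std::string name;
  std::vector<ToyPayload> cases; // mutable on purpose, to show the hazard
};

// Memoizes "is this enum trivial?" per declaration. Like any such cache,
// it silently assumes the declaration never changes after the first query.
class ToyLoweringCache {
  std::map<const ToyEnumDecl *, bool> cache;

public:
  bool isTrivial(const ToyEnumDecl &decl) {
    auto it = cache.find(&decl);
    if (it != cache.end())
      return it->second; // cached answer, possibly stale
    bool trivial = true;
    for (const ToyPayload &c : decl.cases)
      trivial = trivial && c.trivial;
    cache.emplace(&decl, trivial);
    return trivial;
  }
};

// Query first, populate later: the empty enum is cached as trivial and the
// non-trivial payload added afterwards is never seen.
bool queryThenPopulate() {
  ToyLoweringCache cache;
  ToyEnumDecl decl{"ToyBranchTrace", {}};
  cache.isTrivial(decl);
  decl.cases.push_back(ToyPayload{false});
  return cache.isTrivial(decl);
}

// Populate first, query later: the cache sees the complete declaration.
bool populateThenQuery() {
  ToyLoweringCache cache;
  ToyEnumDecl decl{"ToyBranchTrace", {}};
  decl.cases.push_back(ToyPayload{false});
  return cache.isTrivial(decl);
}
```

In this toy model `queryThenPopulate()` reports the enum as trivial even though it ends up with a non-trivial payload, while `populateThenQuery()` reports it as non-trivial.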
When traversing BBs in RPOT we ensure that each BB is processed after its predecessors.
Ah gotcha. I was mistaking RPOT for regular post order traversal but (Right->Left->Node) instead of (Left->Right->Node). But it seems like it's more or less a "pre-order" traversal.
No, it's neither a pre-order nor a post-order, it's not a DFS. We cannot use DFS because we have a DAG, not a tree. Think about the CFG A -> B -> C -> D; B -> D. We need to visit D after both B and C. RPOT is quite a standard technique for various data-flow problems (and could be used to compute a topological sorting of the graph).
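For readers unfamiliar with the ordering, here is a minimal sketch on the A -> B -> C -> D; B -> D example. One common way to obtain a reverse post-order is to record each block when a depth-first walk finishes it and then reverse that list; the `ToyCFG` alias and the helper names are hypothetical and are not the code used in this PR (LLVM ships its own `llvm::ReversePostOrderTraversal` utility for real CFGs).

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy CFG: block name -> successor names. Not a SIL structure.
using ToyCFG = std::map<std::string, std::vector<std::string>>;

// Returns the blocks reachable from `entry` in reverse post-order.
// For an acyclic CFG every block appears after all of its predecessors.
std::vector<std::string> reversePostOrder(const ToyCFG &cfg,
                                          const std::string &entry) {
  std::vector<std::string> postOrder;
  std::set<std::string> visited;
  std::function<void(const std::string &)> visit =
      [&](const std::string &bb) {
        if (!visited.insert(bb).second)
          return; // already visited via another predecessor
        auto it = cfg.find(bb);
        if (it != cfg.end())
          for (const std::string &succ : it->second)
            visit(succ);
        // A block is recorded only after all of its successors are done.
        postOrder.push_back(bb);
      };
  visit(entry);
  std::reverse(postOrder.begin(), postOrder.end());
  return postOrder;
}

// The CFG from the comment above: A -> B -> C -> D plus the edge B -> D.
ToyCFG exampleCFG() {
  return {{"A", {"B"}}, {"B", {"C", "D"}}, {"C", {"D"}}, {"D", {}}};
}
```

Calling `reversePostOrder(exampleCFG(), "A")` yields A, B, C, D, so D is visited after both B and C regardless of the order in which B's successors are listed.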
That makes sense, thanks for the explanation!
The problem was not linear map tuples, but that we passed "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to be always trivial (despite non-trivial payloads added afterwards, the lowered type will be cached and not re-calculated).
You're not supposed to mutate EnumDecls like that at all and changing the order of iteration is only papering over the issue. Type lowering caches the results because it assumes the inputs are immutable. You need to build your AST before you get to SIL.
We need the lowered type of the branch trace enum in order to compute the linear map tuple type. However, the lowering of the branch trace enum type depends on the types of its elements (the payloads are linear map tuples of predecessor BBs). As lowered types are cached, we cannot populate branch trace enum entries at the end as we did before: by then we have already used wrong lowered types for linear map tuples. Traverse basic blocks in reverse post-order, building linear map tuples and branch tracing enums in one go, ensuring that we are done with predecessor BBs before processing the BB itself.
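The single-pass idea in this description can be sketched as follows. This is a toy model under stated assumptions, not the actual `LinearMapInfo` code: `ToyBlock`, `ToyTypeInfo` and `buildLinearMapTuples` are hypothetical names, "lowering" is reduced to one triviality flag, and the CFG is assumed to be acyclic and already listed in reverse post-order.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy model of a basic block for this sketch only.
struct ToyBlock {
  std::string name;
  std::vector<std::string> preds; // predecessor block names
  bool hasNonTrivialPullback;     // block stores a non-trivial pullback
};

// Stand-in for a cached lowered type: only the triviality flag is modelled.
struct ToyTypeInfo {
  bool trivial;
};

// `blocksInRPO` must list every block after all of its predecessors.
// Returns the (toy) lowered info of each block's linear map tuple.
std::map<std::string, ToyTypeInfo>
buildLinearMapTuples(const std::vector<ToyBlock> &blocksInRPO) {
  std::map<std::string, ToyTypeInfo> tupleInfo;
  for (const ToyBlock &bb : blocksInRPO) {
    // Step 1: build the branch trace enum. Its payloads are the linear map
    // tuples of the predecessors, which are final because of the ordering.
    bool enumTrivial = true;
    for (const std::string &pred : bb.preds) {
      auto it = tupleInfo.find(pred);
      assert(it != tupleInfo.end() && "predecessor must be processed first");
      enumTrivial = enumTrivial && it->second.trivial;
    }
    // Step 2: only now derive the tuple, which holds the complete enum plus
    // the block's own pullbacks.
    tupleInfo[bb.name] = ToyTypeInfo{enumTrivial && !bb.hasNonTrivialPullback};
  }
  return tupleInfo;
}
```

With blocks A, B, C, D in that order, where only A stores a non-trivial pullback, every later tuple is reported as non-trivial because A's tuple reaches it through the branch trace enums; passing the blocks in an order that puts a block before one of its predecessors trips the assertion.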
@swift-ci please test |
@swift-ci please test macos |
Exactly! And this is what the fix does. Now we build enum AST types before the corresponding SIL types, and the iteration order enforces this. It is not quite possible to build all AST types before SILGen due to the way autodiff is done. In particular, the element types of these enums correspond to the SIL BBs. Do you see any sane way to do this before SILGen? Tagging @rxwei
Can you enforce this by collecting all the constituents that would go into building the enum somewhere during SILGen, then create the EnumDecl and populate it in one shot, before handing it off? That way you can never access the incomplete declaration. (That is, pretend EnumDecl was immutable after construction, and write your code like that.) |
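The "collect first, then create in one shot" pattern suggested above might look like the following sketch. `ToyCase`, `ToyEnum` and `ToyEnumBuilder` are hypothetical illustration types, not the compiler's `EnumDecl` API; the point is only that the finished value exposes no way to add cases, so no consumer can ever observe an incomplete declaration.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical illustration types; not the Swift AST.
struct ToyCase {
  std::string name;
  std::string payloadType;
};

// Immutable once constructed: only const accessors are exposed.
class ToyEnum {
  std::string enumName;
  std::vector<ToyCase> enumCases;
  ToyEnum(std::string n, std::vector<ToyCase> c)
      : enumName(std::move(n)), enumCases(std::move(c)) {}
  friend class ToyEnumBuilder; // the builder is the only way to create one

public:
  const std::string &name() const { return enumName; }
  const std::vector<ToyCase> &cases() const { return enumCases; }
};

// Collects constituents; nothing can query the enum until build() is called.
class ToyEnumBuilder {
  std::string enumName;
  std::vector<ToyCase> pending;

public:
  explicit ToyEnumBuilder(std::string n) : enumName(std::move(n)) {}

  ToyEnumBuilder &addCase(std::string caseName, std::string payloadType) {
    pending.push_back(ToyCase{std::move(caseName), std::move(payloadType)});
    return *this;
  }

  // Creates the complete enum in one shot from everything collected so far.
  ToyEnum build() const { return ToyEnum(enumName, pending); }
};
```

A caller would write something like `ToyEnumBuilder("Pred").addCase("bb0", "T0").addCase("bb1", "T1").build()` and hand only the resulting `ToyEnum` to code that caches derived information.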
Ok, this is what I tried to explain in the comment (chicken-and-egg problem), but it seems the explanation was confusing. Let me explain what is going on in detail. Consider the following SIL function:
In order to generate a derivative for this function we need to:
This is done via so-called linear map tuples. We create a tuple for each BB capturing the pullbacks of the active calls. Also, for all but the entry BB we create branch tracing enums with entries corresponding to predecessor BBs, the payloads being the linear map tuples of the predecessors. So, for the function above we will have:
The branch tracing enums would be:
Note that linear map tuple of So, in order to construct a linear map tuple for a BB we need to know:
And this is essentially what this PR does now:
Previously we populated EnumDecls at the very end, and this obviously caused some subtle ownership issues :) I hope this is a bit clearer and explains what is going on. If you have other suggestions for how things might be organized in another way (ideally without going via AST types), we would like to hear them! Unfortunately, we have to create new types inside SIL autodiff passes and this causes some churn here and there. It would be perfect if we did not need to go this way.
@swift-ci please test macos |
@slavapestov Will you please let me know if you have further comments on the explanation above? Thanks!
@slavapestov ping :) |
@slavapestov Thanks! |
Fixes #68392