-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Repair ode API after refactor broke it #3684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
closes pymc-devs#3676 + order of sensitivity columns now beginning with y0 + tests were aligned + order of y0, theta now everywhere
Oh, my sides.
Intended actually, but not important to the implementation. Better that you have fixed it. Thanks for all your hard work. So does this fix the sampling issue? Can you verify that the API now samples the example notebook in a reasonable time? I can't recall how fast they were at the end of GSOC, but I think the first example sampled within 10 mins or so. |
Codecov Report
@@ Coverage Diff @@
## master #3684 +/- ##
==========================================
- Coverage 89.9% 89.88% -0.02%
==========================================
Files 134 134
Lines 20166 20176 +10
==========================================
+ Hits 18130 18135 +5
- Misses 2036 2041 +5
|
Yes, the performance is now restored. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing that requires immediate action, but one optional change.
This looks like it was tricky to find, and seems to generally improve how debuggable the codebase is. Thanks for keeping on it!
I can confirm that the example from the current master seems to be waay slower than it was in the run shown on docs.pymc.io. Unfortunately I won't be able to investigate before the weekend. But my next step would be to compare the merge commit of this PR against current master & look for changes in the Link to the compare tool: 5e9f349...master |
Problem
Fix #3634, which broke the ODE API. It became apparent because NUTS went nuts.
Cause
When I refactored, I changed the order of inputs to the
DifferentialEquation
such thaty0
comes beforetheta
. But because I neglected to account for this change also in the forward sensitivities, the resulting gradient was broken.Solution & Changes
In this PR, I fixed the problem by:
y0
comes beforetheta
y0
beforetheta
is done everywhereA few more improvements:
dtype=floatX
wherever possiblefloat64
in the forward integration (augment_system), becauseodeint
can't be set tofloat32
utils.py
line 51, only thetheta
part of the parameters is passed to the user-defined ODE function. In the original implementation it wasstack(theta, y0)
so, the function actually had access toy0
which was unintended (I suppose, @Dpananos ?)DtypeError
where applicable