Improve model debugging #4205


Closed · twiecki opened this issue Oct 31, 2020 · 21 comments · Fixed by #6634

Comments

@twiecki (Member) commented Oct 31, 2020

From the Priesemann group:
"I don't know whether Viola answered this question in the session, I wasn't there, but this is something I can answer as I did a large part of the development. We hadn't had many difficulties with PyMC3; if anything, we were surprised how simple the development was. However, I missed good debugging tools. It is quite easy to build a model where for some reason NaNs occur, or the gradient can't be computed, and I haven't really found a good way to deal with that. There is a monitor_mode in Theano, but this involves looking through a txt file where the output of each node is printed. It works, but is somewhat cumbersome. I wrote up the different techniques we used here: https://covid19-inference.readthedocs.io/en/latest/doc/debugging.html"

@ricardoV94 (Member)

It would be nice to have a helper function to just debug bad initial energy (with or without jitter, as you mention in those docs). Every time I get this issue I have to google for junpenglao's code snippet in the Discourse FAQ. The snippet just doesn't stick in my head :)

Even better would be if the helper function were called automatically (or at least mentioned in the error message) whenever bad initial energy occurs.
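For context, the kind of snippet I mean is roughly the following (reproduced from memory, so treat it as a sketch rather than the exact FAQ code): it evaluates each RV's logp at the test point so you can see which term is -inf or NaN.

import numpy as np
import pymc3 as pm

# Toy model whose sampling would fail with bad initial energy:
# the Normal's testval of -1 violates the HalfNormal's sd > 0 constraint
with pm.Model() as model:
    s = pm.Normal('s', -1, 1)
    obs = pm.HalfNormal('obs', sd=s, observed=np.array([1.0, 2.0]))

# Print each RV's log-probability at the test point; the offending term shows -inf
for RV in model.basic_RVs:
    print(RV.name, RV.logp(model.test_point))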

@twiecki (Member, Author) commented Nov 26, 2020

@ricardoV94 Completely agree, that would be a great contribution. Want to give that a try?

@ricardoV94 (Member)

On the other hand, would such a helper function become irrelevant after PR #4211? (I just noticed it now.)

@twiecki (Member, Author) commented Nov 26, 2020

I think it could perhaps build on that PR to provide more informative output. In general, anything that makes model problems easier to debug, or that produces more informative error messages around them, will be huge.

@ricardoV94 (Member) commented Nov 26, 2020

Yeah, maybe the new function check_start_vals in that PR can be tweaked to be called outside of the sample function.

I imagine something like making the start parameter optional (if it's missing we simply retrieve the model.test_point) and adding another argument that switches between the current behavior (raising an exception if something is nan or inf, and being silent otherwise) and a more informative behavior (printing the starting log value of each RV without raising exceptions). The second behavior would be the default, since that is the one users would want when calling the function directly.

The helper function would then be accessed as pm.utils.check_start_vals(model), and it would print basically the same output as the snippet we mentioned above. In addition, users could pass a dictionary to the start argument to test values other than the model.test_point.

Or does it make more sense to write a separate function?
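For illustration, a rough sketch of what I have in mind (the raise_on_error argument and the output format are just placeholders, not the actual check_start_vals from PR #4211):

import numpy as np
import pymc3 as pm
from pymc3.exceptions import SamplingError

def check_start_vals(model, start=None, raise_on_error=False):
    # Default to the model's test point when no start values are supplied
    start = start if start is not None else model.test_point
    for RV in model.basic_RVs:
        logp = RV.logp(start)
        print(f"{RV.name:<20} {logp}")
        # Optional strict mode, mimicking the current behavior inside pm.sample
        if raise_on_error and not np.all(np.isfinite(logp)):
            raise SamplingError(f"Bad initial log-probability for {RV.name}")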

@twiecki (Member, Author) commented Nov 26, 2020

Yes, exactly what I had in mind.

@ricardoV94 (Member)

I don't mind giving it a go, but I should wait until that PR is completed, no?

@twiecki (Member, Author) commented Nov 26, 2020

@ricardoV94 That PR will be merged very soon so you could already start from that branch.

@ricardoV94 (Member)

Noob question: how can I work on someone else's fork/PR via git?

@twiecki (Member, Author) commented Nov 26, 2020

@ricardoV94 This might not work completely but something along the lines of:

git remote add stephenhogg https://github.com/StephenHogg/pymc3
git pull stephenhogg
git checkout 4116
git checkout -b improve_model_debugging

@ricardoV94 (Member)

Thanks, that helped me figure it out :)

@ricardoV94 (Member)

After some digging, I found that once again I might just be trying to reinvent the wheel: the model method check_test_point seems to do exactly what I was looking for.

@MarcoGorelli (Contributor) commented Nov 26, 2020

@ricardoV94 This might not work completely but something along the lines of:

git remote add stephenhogg https://github.com/StephenHogg/pymc3
git pull stephenhogg
git checkout 4116
git checkout -b improve_model_debugging

little "hack" I use to avoid adding a million remotes:

git checkout -b pr/StephenHogg/4116
git fetch https://github.com/StephenHogg/pymc3 4116
git reset --hard FETCH_HEAD

then, to push new commits:

git push https://github.com/StephenHogg/pymc3 HEAD:4116

@twiecki (Member, Author) commented Nov 27, 2020

Neat, didn't know about this one.

@ricardoV94 (Member) commented Jan 24, 2021

Here is an idea:

One of the most common issues users face is specifying a model that has bad/infinite initial energy, usually because some observed values or prior-distribution testvals violate one or more bounds defined in a distribution's logp method. Right now we can show which variable's logp is causing the bad energy with check_test_point(), but it would be much more helpful if we could also show which specific bound in that logp is being violated, and with what parameter / observed value(s).

For instance, recently on the Discourse a user was having issues because they had a zero observation with a Gamma likelihood:
https://discourse.pymc.io/t/why-am-i-getting-inf-or-nan-likelihood/6587/10

It would be nice to have, say, a debug_bad_energy() (or some other memorable name) that would output something like:

variable      logp    issue        problematic_values
alpha_log__   -1.06   None         None
beta_log__    -1.06   None         None
g             -inf    value <= 0   observed=[0, -1.23...]

The issue column would be found (somehow) by checking which bound in Gamma.logp is being violated and printing its inverse. This could also apply to the case where a bound is violated because of the testval of a prior distribution.
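For reference, the bound checks I'm referring to look schematically like this (a simplified sketch of the pattern used in pymc3 distributions; the expression below is an unnormalized stand-in, not the actual Gamma.logp source):

import theano.tensor as tt
from pymc3.distributions.dist_math import bound

def gamma_like_logp(value, alpha, beta):
    # The logp expression is guarded by explicit conditions; if any of them
    # fails, bound() returns -inf, which is what shows up as bad initial energy.
    logp_expression = (alpha - 1) * tt.log(value) - beta * value  # unnormalized, for illustration
    return bound(
        logp_expression,
        value > 0,
        alpha > 0,
        beta > 0,
    )

Finding the violated bound would then amount to evaluating these conditions at the problematic point.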

Here is another example from the Discourse that illustrates the last point: https://discourse.pymc.io/t/why-am-i-getting-bad-initial-energy-with-this-simple-model/6630/1

import pymc3 as pm

with pm.Model() as basic_model:
    theta = pm.Uniform('theta', lower=0, upper=1)

    y = pm.Uniform('y', lower=0, upper=theta,
                   observed=[0.49131995252993826, 0.2774397121516236, 0.5381935236905224, 0.19753121107715765])

print(basic_model.check_test_point())

theta_interval__ -1.39
y -inf
Name: Log-probability of test_point, dtype: float64

And we would show:

variable           logp    issue           problematic_values
theta_interval__   -1.39   None            None
y                  -inf    value > upper   upper=0.5, observed=[0.538]

The user might be surprised that the upper value is initialized at 0.5, but at least they now have some pointers to get started.

Obviously, this wouldn't cover all issues; sometimes the problem is implicitly encoded in the logp expression or in helper functions called by it. However, I think most problematic values / priors / testvals tend to be explicitly caught by the bound checks.


Do you think something like this might be feasible? I don't have a clear idea of how to get there, but I think the end product could be really helpful and save a lot of users (me included) time during debugging.

@twiecki (Member, Author) commented Jan 24, 2021

This would be really cool; the question is how to extract which boundary was violated.

@ricardoV94 (Member) commented Jan 24, 2021

Perhaps we can at least find out which condition number is being violated by having bound take an optional extra index argument.

Current bound:
https://github.com/pymc-devs/pymc3/blob/823906a3efcf66897eac8a4c89052d9153bca49e/pymc3/distributions/dist_math.py#L89

New bound:
return tt.switch(alltrue(conditions[index]), logp, -np.inf)

Then we could at least print which condition number is being violated (by iterated testing?). But I have no idea if even something crude like that can be done. It sounds inefficient, but it's up to the user whether they want to try it; slow debugging can still be orders of magnitude faster / more illuminating than debugging in one's head.
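A minimal sketch of what that could look like (hypothetical: the index argument and the bound_with_index name are just for illustration, with alltrue_elemwise standing in for the existing condition-combining helper in dist_math):

import numpy as np
import theano.tensor as tt
from pymc3.distributions.dist_math import alltrue_elemwise

def bound_with_index(logp, *conditions, index=None):
    # If an index is given, enforce only that single condition; probing the
    # conditions one at a time reveals which of them drives logp to -inf.
    if index is not None:
        conditions = (conditions[index],)
    return tt.switch(alltrue_elemwise(conditions), logp, -np.inf)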

Whether from here we could get to a description of the condition... I don't know. I don't think Python offers enough introspection out-of-the-box for us to read back the original conditions (e.g., value <= upper). A developer-intensive alternative would be to label each condition manually, such as:

return bound(
  <expression>,
  (value > 0, "value <= 0"),
  (value <= upper, "value > upper"),
)

Then we could optionally return them?


Or can we somehow take advantage of Theano's graph abilities?

@twiecki (Member, Author) commented Jan 24, 2021

Yes, some clever graph analysis could help here. You could just debugprint the Theano graph and see if the information is in there in a parseable way. I think the answer should be yes, and it'd be great for you to learn a bit more about Theano internals.

@ricardoV94 (Member) commented Jan 26, 2021

So I was able to automatically find the bound switches in the logpt of each RV (using a very, very ugly hack based on the string representation of the graph nodes), and I am able to disable one bound check at a time to see whether it is behind any initial -infs:

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    x = pm.Normal('x', -1, 2)
    obs = pm.HalfNormal('obs', sd=x, observed=np.array([-12, 21]))

debug_bad_energy(model)

Which prints the following almost readable output:

The following bound(s) of obs ~ HalfNormal were violated:

Elemwise{ge,no_inplace}.0
 | TensorConstant{[-12.  21.]}
 | TensorConstant{0}

Elemwise{gt,no_inplace}.0
 | f(x ~ Normal) (?) = -1.0
 | TensorConstant{0}

The first one is due to the negative data, and the second one is due to the negative sd coming from x.

I don't know how to recognize that the TensorConstant{[-12. 21.]} is the observed data, but from here we are not very far from getting an output like:

The following bound(s) of obs ~ HalfNormal were violated:

obs ~ HalfNormal = [-12. 21.] >= 0
x ~ Normal = -1.0 > 0

***

This is a snippet of the code (sharing in the hope of getting some advice):

import numpy as np
import theano
import theano.tensor as tt

def debug_bad_energy(model):
    test_point = model.test_point
    for variable in model.basic_RVs:
        if not(np.isfinite(variable.logp(test_point))):
            debug_bounds(model, variable)

def debug_bounds(model, variable):
    first_bound = True
    bound_conditions = find_bound_conds(variable.logpt.owner)   # ugly hack used here (suggestions are welcome!)
    
    # Test bound conditions
    for enabled_bound in bound_conditions:
        # Disable all switches except one (another ugly hack, there must be a better way)
        new_logpt = theano.clone(
            variable.logpt, 
            {bound: tt.eq(bound.owner.inputs[0], bound.owner.inputs[0])  # input == input
             for bound in bound_conditions if bound != enabled_bound}
        )

        output = model.fn(new_logpt)(model.test_point)

        if not(np.all(np.isfinite(output))):
            if first_bound:
                first_bound = False
                print(f'The following bound(s) of {variable} were violated:')
                print('')

            print(enabled_bound)
            find_switch_IVs(enabled_bound)  # Tries to find and print the child nodes of the bound comparison (suggestions are welcome!)
            print('')

@twiecki (Member, Author) commented Jan 27, 2021

This is great progress, I think with some parsing we can turn this into the string you had above.

@ricardoV94 (Member)

BTW, we now have slightly more informative error messages when a -inf probability comes from violating a parameter constraint (though not when the value itself is outside the support). We could compile the function without compile_pymc so that those errors show up as exceptions:

import aesara
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", 3, 1)
    p = pm.Bernoulli("p", p=x, observed=0)

ip = model.initial_point()
fn = aesara.function(model.value_vars, model.logp())
fn(**ip)  # ParameterValueError: 0 <= p <= 1

We could perhaps also try to copy the stack trace from the RV to the value_var so that it shows which RV it comes from, but that would require more work.
