fix: make numpy_backend.tile() and jax_backend.tile() consistent #2587

Draft
wants to merge 1 commit into
base: main

ligerlac

@ligerlac ligerlac commented May 21, 2025

This fixes a bug in the jax_backend.tile() method. Consider the following minimal example:

import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

The last line fails with TypeError: tile requires ndarray or scalar arguments, got <class 'list'> at position 0. However, it works fine with the numpy backend. The problem stems from a difference between np.tile and jnp.tile:

import numpy as np
import jax.numpy as jnp

tensor_in = [[[0, 1, 2]]]
repeats = (0, 1, 1)

np.tile(tensor_in, repeats)  # works fine
jnp.tile(tensor_in, repeats)  # fails with same error message as above
jnp.tile(jnp.array(tensor_in), repeats)  # works fine

Unlike jnp.tile, np.tile implicitly converts list input to an array.
This PR converts tensor_in to a jnp.array inside jax_backend.tile() so that numpy_backend.tile() and jax_backend.tile() behave consistently.
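A minimal sketch of the kind of coercion described above, not the actual diff from this PR. The helper name `tile_consistent` and its `xp` parameter are hypothetical. NumPy is used so the snippet runs without JAX installed; passing `jax.numpy` as `xp` is assumed to work the same way, since it also provides `asarray` and `tile`.

```python
import numpy as np


def tile_consistent(xp, tensor_in, repeats):
    """Tile tensor_in after coercing it to an array of the namespace xp.

    xp is a hypothetical parameter: an array namespace such as numpy or
    jax.numpy. Coercing first means nested lists are accepted even by a
    namespace whose tile() rejects non-array input.
    """
    return xp.tile(xp.asarray(tensor_in), repeats)


# Same inputs as in the example above, plus a non-empty repeat count.
empty = tile_consistent(np, [[[0, 1, 2]]], (0, 1, 1))
tiled = tile_consistent(np, [[[0, 1, 2]]], (2, 1, 1))
print(empty.shape)  # (0, 1, 3)
print(tiled.shape)  # (2, 1, 3)
```

Coercing inside the wrapper keeps call sites unchanged, at the cost of repeating the conversion in every backend method that has the same issue.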

@matthewfeickert matthewfeickert changed the title from "Bugfix: make numpy_backend.tile() and jax_backend.tile() consistent" to "fix: make numpy_backend.tile() and jax_backend.tile() consistent" May 22, 2025
@matthewfeickert matthewfeickert added the fix A bug fix label May 22, 2025
@matthewfeickert
Member

@ligerlac Thanks for the PR. Today I have been clawing myself out of travel-related, time-dependent TODOs, but I can review this on Thursday (2025-05-22).

I haven't looked/thought about this yet, but I assume that this isn't something unique to tile but more generic to how things are being dealt with in spec validation of pyhf.Model (though maybe if I actually think about the PR the reason would be clear to me). Is this a general solution or more of a targeted use patch?

@matthewfeickert matthewfeickert requested review from matthewfeickert, a team and kratsg and removed request for a team May 22, 2025 06:29
@ligerlac
Author

It's more of a patch. You are right, the problem is not unique to tile(). There are similar problems with concatenate():

import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax", default=True)  # works without this line

spec = {
    "channels": [
        {
            "name": "singlechannel",
            "samples": [
                {
                    "name": "signal",
                    "data": jnp.array([0.0, 0.0, 0.0]),
                    "modifiers": [
                        {
                            "name": "mu",
                            "type": "normfactor",
                            "data": None,
                        },
                    ],
                },
                {
                    "name": "background",
                    "data": jnp.array([0.0, 0.0, 0.0]),  # dummy data
                    "modifiers": [
                        {
                            "name": "correlated_bkg_uncertainty",
                            "type": "histosys",
                            "data": {
                                "hi_data": jnp.array([0.0, 0.0, 0.0]),
                                "lo_data": jnp.array([0.0, 0.0, 0.0]),
                            },
                        },
                    ],
                },
            ],
        },
    ],
}

my_model = pyhf.Model(spec, validate=False)

The last line fails with TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0. Again, this can be traced back to a difference between np.concatenate() and jnp.concatenate():

import numpy as np
import jax.numpy as jnp

np.concatenate([[True, True, True]])  # works
jnp.concatenate([[True, True, True]])  # fails
jnp.concatenate(jnp.array([[True, True, True]]))  # works

We could also patch that in the jax backend. But I guess a more elegant solution would be to make sure that each backend only receives arguments of the correct type, by calling tensorlib.astensor in all the right places (like the _precompute() methods). I'll try to find some time over the weekend to have another look at this (including the missing tests).
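A rough sketch of the "convert at the boundary" idea, under stated assumptions. The helper `normalize_inputs` is hypothetical and is not part of pyhf. The `astensor` argument stands in for the active backend's conversion function (in pyhf that would be tensorlib.astensor); NumPy is used here only so the snippet is self-contained.

```python
import numpy as np


def normalize_inputs(arrays, astensor):
    """Convert every entry to the backend's tensor type exactly once.

    Hypothetical helper: `astensor` is whatever conversion function the
    active backend provides, so later calls such as concatenate only
    ever see tensors, never plain Python lists.
    """
    return [astensor(a) for a in arrays]


def numpy_astensor(values):
    # Stand-in for a backend's astensor, assumed here to return floats.
    return np.asarray(values, dtype=float)


# Mixed input: one plain list and one array, as a spec might contain.
nominal = normalize_inputs([[0.0, 0.0, 0.0], np.zeros(3)], numpy_astensor)
stacked = np.concatenate(nominal)
print(stacked.shape)  # (6,)
```

Doing the conversion once where the data enters the model would avoid patching tile(), concatenate(), and any other backend method individually.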

@ligerlac ligerlac marked this pull request as draft May 23, 2025 09:28
Labels
fix A bug fix

2 participants