diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 3c33434e56..c13afbd6fa 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -36,7 +36,6 @@
 from pytensor.utils import (
     apply_across_args,
     difference,
-    from_return_values,
     to_return_values,
 )
@@ -1081,6 +1080,16 @@ def real_out(type):
     return (type,)
 
 
+def _cast_to_promised_scalar_dtype(x, dtype):
+    try:
+        return x.astype(dtype)
+    except AttributeError:
+        if dtype == "bool":
+            return np.bool_(x)
+        else:
+            return getattr(np, dtype)(x)
+
+
 class ScalarOp(COp):
     nin = -1
     nout = 1
@@ -1134,28 +1143,18 @@ def output_types(self, types):
         else:
             raise NotImplementedError(f"Cannot calculate the output types for {self}")
 
-    @staticmethod
-    def _cast_scalar(x, dtype):
-        if hasattr(x, "astype"):
-            return x.astype(dtype)
-        elif dtype == "bool":
-            return np.bool_(x)
-        else:
-            return getattr(np, dtype)(x)
-
     def perform(self, node, inputs, output_storage):
         if self.nout == 1:
-            dtype = node.outputs[0].dtype
-            output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype)
+            output_storage[0][0] = _cast_to_promised_scalar_dtype(
+                self.impl(*inputs),
+                node.outputs[0].dtype,
+            )
         else:
-            variables = from_return_values(self.impl(*inputs))
-            assert len(variables) == len(output_storage)
             # strict=False because we are in a hot loop
             for out, storage, variable in zip(
-                node.outputs, output_storage, variables, strict=False
+                node.outputs, output_storage, self.impl(*inputs), strict=False
             ):
-                dtype = out.dtype
-                storage[0] = self._cast_scalar(variable, dtype)
+                storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype)
 
     def impl(self, *inputs):
         raise MethodNotDefined("impl", type(self), self.__class__.__name__)
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index a5512c6564..ec7eca76b9 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -9,8 +9,7 @@
 from textwrap import dedent
 
 import numpy as np
-import scipy.special
-import scipy.stats
+from scipy import special
 
 from pytensor.configdefaults import config
 from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -40,7 +39,6 @@
     true_div,
     upcast,
     upgrade_to_float,
-    upgrade_to_float64,
     upgrade_to_float_no_complex,
 )
 from pytensor.scalar.basic import abs as scalar_abs
@@ -54,7 +52,7 @@ class Erf(UnaryScalarOp):
     nfunc_spec = ("scipy.special.erf", 1, 1)
 
     def impl(self, x):
-        return scipy.special.erf(x)
+        return special.erf(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -88,7 +86,7 @@ class Erfc(UnaryScalarOp):
     nfunc_spec = ("scipy.special.erfc", 1, 1)
 
     def impl(self, x):
-        return scipy.special.erfc(x)
+        return special.erfc(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -115,7 +113,7 @@ def c_code(self, node, name, inp, out, sub):
         return f"{z} = erfc(({cast}){x});"
 
 
-# scipy.special.erfc don't support complex. Why?
+# special.erfc doesn't support complex. Why?
 erfc = Erfc(upgrade_to_float_no_complex, name="erfc")
@@ -137,7 +135,7 @@ class Erfcx(UnaryScalarOp):
     nfunc_spec = ("scipy.special.erfcx", 1, 1)
 
     def impl(self, x):
-        return scipy.special.erfcx(x)
+        return special.erfcx(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -193,7 +191,7 @@ class Erfinv(UnaryScalarOp):
     nfunc_spec = ("scipy.special.erfinv", 1, 1)
 
     def impl(self, x):
-        return scipy.special.erfinv(x)
+        return special.erfinv(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -228,7 +226,7 @@ class Erfcinv(UnaryScalarOp):
     nfunc_spec = ("scipy.special.erfcinv", 1, 1)
 
     def impl(self, x):
-        return scipy.special.erfcinv(x)
+        return special.erfcinv(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -262,12 +260,8 @@ def c_code(self, node, name, inp, out, sub):
 class Owens_t(BinaryScalarOp):
     nfunc_spec = ("scipy.special.owens_t", 2, 1)
 
-    @staticmethod
-    def st_impl(h, a):
-        return scipy.special.owens_t(h, a)
-
     def impl(self, h, a):
-        return Owens_t.st_impl(h, a)
+        return special.owens_t(h, a)
 
     def grad(self, inputs, grads):
         (h, a) = inputs
@@ -291,12 +285,8 @@ def c_code(self, *args, **kwargs):
 class Gamma(UnaryScalarOp):
     nfunc_spec = ("scipy.special.gamma", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.gamma(x)
-
     def impl(self, x):
-        return Gamma.st_impl(x)
+        return special.gamma(x)
 
     def L_op(self, inputs, outputs, gout):
         (x,) = inputs
@@ -330,12 +320,8 @@ class GammaLn(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.gammaln", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.gammaln(x)
-
     def impl(self, x):
-        return GammaLn.st_impl(x)
+        return special.gammaln(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -374,12 +360,8 @@ class Psi(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.psi", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.psi(x)
-
     def impl(self, x):
-        return Psi.st_impl(x)
+        return special.psi(x)
 
     def L_op(self, inputs, outputs, grads):
         (x,) = inputs
@@ -465,12 +447,8 @@ class TriGamma(UnaryScalarOp):
 
     """
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.polygamma(1, x)
-
     def impl(self, x):
-        return TriGamma.st_impl(x)
+        return special.polygamma(1, x)
 
     def L_op(self, inputs, outputs, outputs_gradients):
         (x,) = inputs
@@ -568,12 +546,8 @@ def output_types_preference(n_type, x_type):
         # Scipy doesn't support it
         return upgrade_to_float_no_complex(x_type)
 
-    @staticmethod
-    def st_impl(n, x):
-        return scipy.special.polygamma(n, x)
-
     def impl(self, n, x):
-        return PolyGamma.st_impl(n, x)
+        return special.polygamma(n, x)
 
     def L_op(self, inputs, outputs, output_gradients):
         (n, x) = inputs
@@ -592,50 +566,6 @@ def c_code(self, *args, **kwargs):
 polygamma = PolyGamma(name="polygamma")
 
 
-class Chi2SF(BinaryScalarOp):
-    """
-    Compute (1 - chi2_cdf(x))
-    ie. chi2 pvalue (chi2 'survival function')
-    """
-
-    nfunc_spec = ("scipy.stats.chi2.sf", 2, 1)
-
-    @staticmethod
-    def st_impl(x, k):
-        return scipy.stats.chi2.sf(x, k)
-
-    def impl(self, x, k):
-        return Chi2SF.st_impl(x, k)
-
-    def c_support_code(self, **kwargs):
-        return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
-
-    def c_code(self, node, name, inp, out, sub):
-        x, k = inp
-        (z,) = out
-        if node.inputs[0].type in float_types:
-            dtype = "npy_" + node.outputs[0].dtype
-            return f"""{z} =
-                ({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
-        raise NotImplementedError("only floatingpoint is implemented")
-
-    def __eq__(self, other):
-        return type(self) is type(other)
-
-    def __hash__(self):
-        return hash(type(self))
-
-    def c_code_cache_version(self):
-        v = super().c_code_cache_version()
-        if v:
-            return (2, *v)
-        else:
-            return v
-
-
-chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf")
-
-
 class GammaInc(BinaryScalarOp):
     """
     Compute the regularized lower gamma function (P).
@@ -643,12 +573,8 @@ class GammaInc(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.gammainc", 2, 1)
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammainc(k, x)
-
     def impl(self, k, x):
-        return GammaInc.st_impl(k, x)
+        return special.gammainc(k, x)
 
     def grad(self, inputs, grads):
         (k, x) = inputs
@@ -694,12 +620,8 @@ class GammaIncC(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.gammaincc", 2, 1)
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammaincc(k, x)
-
     def impl(self, k, x):
-        return GammaIncC.st_impl(k, x)
+        return special.gammaincc(k, x)
 
     def grad(self, inputs, grads):
         (k, x) = inputs
@@ -745,12 +667,8 @@ class GammaIncInv(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammaincinv(k, x)
-
     def impl(self, k, x):
-        return GammaIncInv.st_impl(k, x)
+        return special.gammaincinv(k, x)
 
     def grad(self, inputs, grads):
         (k, x) = inputs
@@ -774,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammainccinv(k, x)
-
     def impl(self, k, x):
-        return GammaIncCInv.st_impl(k, x)
+        return special.gammainccinv(k, x)
 
     def grad(self, inputs, grads):
         (k, x) = inputs
@@ -1013,12 +927,8 @@ class GammaU(BinaryScalarOp):
 
     # Note there is no basic SciPy version so no nfunc_spec.
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
-
     def impl(self, k, x):
-        return GammaU.st_impl(k, x)
+        return special.gammaincc(k, x) * special.gamma(k)
 
     def c_support_code(self, **kwargs):
         return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1049,12 +959,8 @@ class GammaL(BinaryScalarOp):
 
     # Note there is no basic SciPy version so no nfunc_spec.
 
-    @staticmethod
-    def st_impl(k, x):
-        return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
-
     def impl(self, k, x):
-        return GammaL.st_impl(k, x)
+        return special.gammainc(k, x) * special.gamma(k)
 
     def c_support_code(self, **kwargs):
         return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1085,12 +991,8 @@ class Jv(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.jv", 2, 1)
 
-    @staticmethod
-    def st_impl(v, x):
-        return scipy.special.jv(v, x)
-
     def impl(self, v, x):
-        return self.st_impl(v, x)
+        return special.jv(v, x)
 
     def grad(self, inputs, grads):
         v, x = inputs
@@ -1114,12 +1016,8 @@ class J1(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.j1", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.j1(x)
-
     def impl(self, x):
-        return self.st_impl(x)
+        return special.j1(x)
 
     def grad(self, inputs, grads):
         (x,) = inputs
@@ -1145,12 +1043,8 @@ class J0(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.j0", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.j0(x)
-
     def impl(self, x):
-        return self.st_impl(x)
+        return special.j0(x)
 
     def grad(self, inp, grads):
         (x,) = inp
@@ -1176,12 +1070,8 @@ class Iv(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.iv", 2, 1)
 
-    @staticmethod
-    def st_impl(v, x):
-        return scipy.special.iv(v, x)
-
     def impl(self, v, x):
-        return self.st_impl(v, x)
+        return special.iv(v, x)
 
     def grad(self, inputs, grads):
         v, x = inputs
@@ -1205,12 +1095,8 @@ class I1(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.i1", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.i1(x)
-
     def impl(self, x):
-        return self.st_impl(x)
+        return special.i1(x)
 
     def grad(self, inputs, grads):
         (x,) = inputs
@@ -1231,12 +1117,8 @@ class I0(UnaryScalarOp):
 
     nfunc_spec = ("scipy.special.i0", 1, 1)
 
-    @staticmethod
-    def st_impl(x):
-        return scipy.special.i0(x)
-
     def impl(self, x):
-        return self.st_impl(x)
+        return special.i0(x)
 
     def grad(self, inp, grads):
         (x,) = inp
@@ -1257,12 +1139,8 @@ class Ive(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.ive", 2, 1)
 
-    @staticmethod
-    def st_impl(v, x):
-        return scipy.special.ive(v, x)
-
     def impl(self, v, x):
-        return self.st_impl(v, x)
+        return special.ive(v, x)
 
     def grad(self, inputs, grads):
         v, x = inputs
@@ -1286,12 +1164,8 @@ class Kve(BinaryScalarOp):
 
     nfunc_spec = ("scipy.special.kve", 2, 1)
 
-    @staticmethod
-    def st_impl(v, x):
-        return scipy.special.kve(v, x)
-
     def impl(self, v, x):
-        return self.st_impl(v, x)
+        return special.kve(v, x)
 
     def L_op(self, inputs, outputs, output_grads):
         v, x = inputs
@@ -1321,7 +1195,7 @@ class Sigmoid(UnaryScalarOp):
     nfunc_spec = ("scipy.special.expit", 1, 1)
 
     def impl(self, x):
-        return scipy.special.expit(x)
+        return special.expit(x)
 
     def grad(self, inp, grads):
         (x,) = inp
@@ -1372,8 +1246,7 @@ class Softplus(UnaryScalarOp):
         "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
     """
 
-    @staticmethod
-    def static_impl(x):
+    def impl(self, x):
         # If x is an int8 or uint8, numpy.exp will compute the result in
         # half-precision (float16), where we want float32.
         not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
@@ -1388,9 +1261,6 @@ def static_impl(x):
         else:
             return x
 
-    def impl(self, x):
-        return Softplus.static_impl(x)
-
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
@@ -1453,16 +1323,12 @@ class Log1mexp(UnaryScalarOp):
         "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
     """
 
-    @staticmethod
-    def static_impl(x):
+    def impl(self, x):
         if x < np.log(0.5):
             return np.log1p(-np.exp(x))
         else:
             return np.log(-np.expm1(x))
 
-    def impl(self, x):
-        return Log1mexp.static_impl(x)
-
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
@@ -1496,7 +1362,7 @@ class BetaInc(ScalarOp):
     nfunc_spec = ("scipy.special.betainc", 3, 1)
 
     def impl(self, a, b, x):
-        return scipy.special.betainc(a, b, x)
+        return special.betainc(a, b, x)
 
     def grad(self, inp, grads):
         a, b, x = inp
@@ -1756,7 +1622,7 @@ class BetaIncInv(ScalarOp):
     nfunc_spec = ("scipy.special.betaincinv", 3, 1)
 
     def impl(self, a, b, x):
-        return scipy.special.betaincinv(a, b, x)
+        return special.betaincinv(a, b, x)
 
     def grad(self, inputs, grads):
         (a, b, x) = inputs
@@ -1794,12 +1660,8 @@ class Hyp2F1(ScalarOp):
     nin = 4
     nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
 
-    @staticmethod
-    def st_impl(a, b, c, z):
-        return scipy.special.hyp2f1(a, b, c, z)
-
     def impl(self, a, b, c, z):
-        return Hyp2F1.st_impl(a, b, c, z)
+        return special.hyp2f1(a, b, c, z)
 
     def grad(self, inputs, grads):
         a, b, c, z = inputs
diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py
index 76738fdb63..cb4476ede0 100644
--- a/pytensor/tensor/inplace.py
+++ b/pytensor/tensor/inplace.py
@@ -258,11 +258,6 @@ def tri_gamma_inplace(a):
     """second derivative of the log gamma function"""
 
 
-@scalar_elemwise
-def chi2sf_inplace(x, k):
-    """chi squared survival function"""
-
-
 @scalar_elemwise
 def gammainc_inplace(k, x):
     """regularized lower gamma function (P)"""
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index f11e33b41d..b185f686bc 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1154,9 +1154,10 @@ def polygamma(n, x):
     """Polygamma function of order n evaluated at x"""
 
 
-@scalar_elemwise
 def chi2sf(x, k):
     """chi squared survival function"""
+    warnings.warn("chi2sf is deprecated. Use `gammaincc(k / 2, x / 2)` instead")
+    return gammaincc(k / 2, x / 2)
 
 
 @scalar_elemwise
diff --git a/pytensor/tensor/xlogx.py b/pytensor/tensor/xlogx.py
index 8cc27de9fb..3709688e54 100644
--- a/pytensor/tensor/xlogx.py
+++ b/pytensor/tensor/xlogx.py
@@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp):
 
     """
 
-    @staticmethod
-    def st_impl(x):
+    def impl(self, x):
         if x == 0.0:
             return 0.0
         return x * np.log(x)
 
-    def impl(self, x):
-        return XlogX.st_impl(x)
-
     def grad(self, inputs, grads):
         (x,) = inputs
         (gz,) = grads
@@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp):
 
     """
 
-    @staticmethod
-    def st_impl(x, y):
+    def impl(self, x, y):
         if x == 0.0:
             return 0.0
         return x * np.log(y)
 
-    def impl(self, x, y):
-        return XlogY0.st_impl(x, y)
-
     def grad(self, inputs, grads):
         x, y = inputs
         (gz,) = grads
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index 921aae826b..8f70950206 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -306,16 +306,6 @@ def scipy_special_gammal(k, x):
     name="Chi2SF",
 )
 
-TestChi2SFInplaceBroadcast = makeBroadcastTester(
-    op=inplace.chi2sf_inplace,
-    expected=expected_chi2sf,
-    good=_good_broadcast_unary_chi2sf,
-    eps=2e-10,
-    mode=mode_no_scipy,
-    inplace=True,
-    name="Chi2SF",
-)
-
 rng = np.random.default_rng(seed=utt.fetch_seed())
 
 _good_broadcast_binary_gamma = dict(
     normal=(
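
Standalone sketch of what the new module-level `_cast_to_promised_scalar_dtype` helper in `pytensor/scalar/basic.py` does (the function body mirrors the diff; the example values are illustrative, not from the patch): ndarrays and NumPy scalars go through `.astype`, while plain Python scalars fall back to a NumPy scalar constructor, with `"bool"` special-cased via `np.bool_`.

import numpy as np

def _cast_to_promised_scalar_dtype(x, dtype):
    try:
        return x.astype(dtype)  # ndarrays and NumPy scalars expose .astype
    except AttributeError:
        # Plain Python scalars have no .astype; build a NumPy scalar instead.
        if dtype == "bool":
            return np.bool_(x)  # "bool" is special-cased; np.bool_ is the scalar type
        else:
            return getattr(np, dtype)(x)  # e.g. np.float32(x), np.int64(x)

print(type(_cast_to_promised_scalar_dtype(1.5, "float32")))         # numpy.float32
print(type(_cast_to_promised_scalar_dtype(1, "bool")))              # numpy.bool_
print(_cast_to_promised_scalar_dtype(np.arange(3), "int64").dtype)  # int64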
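
The `chi2sf` deprecation above relies on the identity that the chi-squared survival function with k degrees of freedom equals the regularized upper incomplete gamma function Q(k/2, x/2). A minimal sanity check of that equivalence at the SciPy level (the values of `x` and `k` are illustrative):

import numpy as np
from scipy import special, stats

x, k = 3.5, 4.0
# chi2.sf(x, k) == gammaincc(k / 2, x / 2), the basis for the rewrite in math.py
assert np.isclose(stats.chi2.sf(x, k), special.gammaincc(k / 2, x / 2))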