diff --git a/pytensor/tensor/ssignal.py b/pytensor/tensor/ssignal.py new file mode 100644 index 0000000000..dae02d62ba --- /dev/null +++ b/pytensor/tensor/ssignal.py @@ -0,0 +1,38 @@ +import scipy.signal as scipy_signal + +from pytensor.graph.basic import Apply +from pytensor.tensor import Op, as_tensor_variable +from pytensor.tensor.type import TensorType + + +class GaussSpline(Op): + __props__ = ("n",) + + def __init__(self, n: int): + self.n = n + + def make_node(self, knots): + knots = as_tensor_variable(knots) + if not isinstance(knots.type, TensorType): + raise TypeError("Input must be a TensorType") + + if not isinstance(self.n, int) or self.n is None or self.n < 0: + raise ValueError("n must be a non-negative integer") + + if knots.ndim < 1: + raise TypeError("Input must be at least 1-dimensional") + + out = knots.type() + return Apply(self, [knots], [out]) + + def perform(self, node, inputs, output_storage): + [x] = inputs + [out] = output_storage + out[0] = scipy_signal.gauss_spline(x, self.n) + + def infer_shape(self, fgraph, node, shapes): + return [shapes[0]] + + +def gauss_spline(x, n): + return GaussSpline(n)(x) diff --git a/tests/tensor/test_ssignal.py b/tests/tensor/test_ssignal.py new file mode 100644 index 0000000000..0b53ccfc9f --- /dev/null +++ b/tests/tensor/test_ssignal.py @@ -0,0 +1,33 @@ +import numpy as np +import pytest +import scipy.signal as scipy_signal + +from pytensor import function +from pytensor.tensor.ssignal import GaussSpline, gauss_spline +from pytensor.tensor.type import matrix +from tests import unittest_tools as utt + + +class TestGaussSpline(utt.InferShapeTester): + def setup_method(self): + super().setup_method() + self.op_class = GaussSpline + self.op = gauss_spline + + @pytest.mark.parametrize("n", [-1, 1.5, None, "string"]) + def test_make_node_raises(self, n): + a = matrix() + with pytest.raises(ValueError, match="n must be a non-negative integer"): + self.op(a, n=n) + + def test_perform(self): + a = matrix() + f = function([a], self.op(a, n=10)) + a = np.random.random((8, 6)) + assert np.allclose(f(a), scipy_signal.gauss_spline(a, 10)) + + def test_infer_shape(self): + a = matrix() + self._compile_and_check( + [a], [self.op(a, 16)], [np.random.random((12, 4))], self.op_class + )