Make GEMV more robust to zero strided inputs #1266

Draft: wants to merge 1 commit into base: main
128 changes: 73 additions & 55 deletions pytensor/tensor/blas_c.py
@@ -413,15 +413,15 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
char NOTRANS = 'N';
int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1];
/* This formula is needed in the case where A is actually a row or
* column matrix, because BLAS sometimes insists that the strides:
* - are not smaller than the number of elements in the array
* - are not 0.
int Nx = PyArray_DIMS(%(x)s)[0];
/* If A or x has a length-1 dimension, the corresponding stride does not matter.
 * However, BLAS often insists that the strides be neither zero nor smaller than
 * the number of elements in the array. We set them to 1 arbitrarily.
*/
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : 1;
int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : 1;
int Sx = (Nx > 1) ? PyArray_STRIDES(%(x)s)[0] / elemsize: 1;
int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;

dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
@@ -435,62 +435,49 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non

if (NA0 * NA1)
{
// If A is neither C- nor F-contiguous, we make a copy.
// TODO:
// - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
if ((PyArray_STRIDES(%(A)s)[0] < 0)
|| (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize)
&& (PyArray_STRIDES(%(A)s)[1] != elemsize)))
// Non-empty branch

if (Sx == 0)
{
npy_intp dims[2];
dims[0] = NA0;
dims[1] = NA1;
// This is a broadcasted vector with length > 1 and a stride of 0.
// We need to make a full copy of it.

PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
%(A)s);
if (!A_copy)
PyArrayObject * x_copy = (PyArrayObject *) PyArray_Copy(%(x)s);
if (!x_copy)
%(fail)s
Py_XDECREF(%(A)s);
%(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
Py_XDECREF(%(x)s);
%(x)s = x_copy;
x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
}

if (PyArray_STRIDES(%(A)s)[0] == elemsize)
if (
(PyArray_STRIDES(%(A)s)[0] < 0) || (PyArray_STRIDES(%(A)s)[1] < 0)
|| ((PyArray_STRIDES(%(A)s)[0] != elemsize) && (PyArray_STRIDES(%(A)s)[1] != elemsize))
)
{
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)x_data, &Sx,
&fbeta,
(float*)z_data, &Sz);
}
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
// If A is neither C- nor F-contiguous, we make a copy.
// TODO:
// - if one stride is equal to "- elemsize", we can still call
// gemv on reversed matrix and vectors
// - if the copy is too long, maybe call vector/vector dot on
// each row instead
PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
if (!A_copy)
%(fail)s
}
Py_XDECREF(%(A)s);
%(A)s = A_copy;
SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : 1;
SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : 1;
}
else if (PyArray_STRIDES(%(A)s)[1] == elemsize)


if (PyArray_STRIDES(%(A)s)[1] == elemsize)
{
// C-contiguous branch.
// A may also be F-contiguous, but we give this branch preference
// because it has special handling for row/column matrices A.

if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
@@ -554,10 +541,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
%(fail)s
}
}
else if (PyArray_STRIDES(%(A)s)[0] == elemsize)
{
// Fortran order branch
if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
{
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
sgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(float*)(PyArray_DATA(%(A)s)), &SA1,
(float*)x_data, &Sx,
&fbeta,
(float*)z_data, &Sz);
}
else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &NA0, &NA1,
&alpha,
(double*)(PyArray_DATA(%(A)s)), &SA1,
(double*)x_data, &Sx,
&dbeta,
(double*)z_data, &Sz);
}
else
{
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"xx is a double-strided matrix, and should have been "
"A is a double-strided matrix, and should have been "
"copied into a memory-contiguous one.");
%(fail)s
}
@@ -603,6 +620,7 @@ def c_code(self, node, name, inp, out, sub):
return code

def c_code_cache_version(self):
return None
return (14, blas_header_version(), check_force_gemv_init())


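For context, the new `Sx == 0` branch above handles inputs that NumPy produces by broadcasting, where an entire vector aliases a single element. Below is a minimal NumPy-only illustration of why a copy is needed; it is not part of the diff, and it relies on the fact that reference BLAS rejects a zero increment for x.

import numpy as np

# A broadcasted vector aliases one element, so its stride is 0 bytes.
x = np.broadcast_to(np.float32(0.5), (4,))
assert x.strides == (0,)

# gemv needs a nonzero increment, so the C code copies x before calling
# sgemv_/dgemv_; the dense copy has a normal element stride again.
x_dense = x.copy()
assert x_dense.strides == (4,)  # 4 bytes per float32 element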
110 changes: 68 additions & 42 deletions tests/tensor/test_blas_c.py
@@ -2,9 +2,11 @@

import numpy as np
import pytest
from numpy.lib.stride_tricks import as_strided

import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Ger
from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init
@@ -199,53 +201,77 @@ def test_force_gemv_init(self):
" degradation in performance for such calls."
)

def t_gemv1(self, m_shp):
"""test vector2 + dot(matrix, vector1)"""
@pytest.mark.skipif(config.blas__ldflags == "", reason="No blas")
@pytest.mark.parametrize(
"A_shape",
[(3, 2), (1, 2), (0, 2), (3, 1), (3, 0), (1, 0), (1, 1), (0, 1), (0, 0)],
ids=str,
)
@pytest.mark.parametrize("inplace", [True, False])
def test_gemv1(self, A_shape, inplace: bool):
"""test y + dot(A, x)"""
rng = np.random.default_rng(unittest_tools.fetch_seed())
v1 = pytensor.shared(np.array(rng.uniform(size=(m_shp[1],)), dtype="float32"))
v2_orig = np.array(rng.uniform(size=(m_shp[0],)), dtype="float32")
v2 = pytensor.shared(v2_orig)
m = pytensor.shared(np.array(rng.uniform(size=m_shp), dtype="float32"))

f = pytensor.function([], v2 + pt.dot(m, v1), mode=self.mode)

# Assert they produce the same output
assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig)
topo = [n.op for n in f.maker.fgraph.toposort()]
assert topo == [CGemv(inplace=False)], topo

# test the inplace version
g = pytensor.function(
[], [], updates=[(v2, v2 + pt.dot(m, v1))], mode=self.mode
y = pt.vector("y", dtype="float32")
x = pt.vector("x", dtype="float32")
A = pt.matrix("A", dtype="float32")
alpha = beta = 1.0

out = CGemv(inplace=inplace)(y, alpha, A, x, beta)
f = pytensor.function([y, A, x], out, mode=self.mode, accept_inplace=inplace)
f.dprint()
assert [node.op for node in f.maker.fgraph.toposort()] == [
CGemv(inplace=inplace)
]

def assert_expected_output(inplace, f, y_test, A_test, x_test):
# Copy y with the same strides as the original one
y_test_copy = y_test.copy()
y_test_copy = as_strided(
y_test_copy, shape=y_test.shape, strides=y_test.strides
)
res = f(y_test_copy, A_test, x_test)
if inplace:
res = y_test_copy
else:
np.testing.assert_array_equal(y_test, y_test_copy)
np.testing.assert_allclose(res, y_test + A_test @ x_test)

y_test = rng.uniform(size=A_shape[0]).astype("float32")
A_test = rng.uniform(size=A_shape).astype("float32")
x_test = rng.uniform(size=A_shape[1]).astype("float32")
assert_expected_output(inplace, f, y_test, A_test, x_test)

## Fortran order
y_test_fortran = np.asfortranarray(y_test)
A_test_fortran = np.asfortranarray(A_test)
x_test_fortran = np.asfortranarray(x_test)
assert_expected_output(
inplace, f, y_test_fortran, A_test_fortran, x_test_fortran
)

# Assert they produce the same output
g()
assert np.allclose(
v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig
)
topo = [n.op for n in g.maker.fgraph.toposort()]
assert topo == [CGemv(inplace=True)]

# Do the same tests with a matrix with strides in both dimensions
m.set_value(m.get_value(borrow=True)[::-1, ::-1], borrow=True)
v2.set_value(v2_orig)
assert np.allclose(f(), np.dot(m.get_value(), v1.get_value()) + v2_orig)
g()
assert np.allclose(
v2.get_value(), np.dot(m.get_value(), v1.get_value()) + v2_orig
## Negative strides (or zero when size is zero)
y_test_neg_strides = y_test[::-1]
assert y_test_neg_strides.strides[0] in (-4, 0)
A_test_neg_strides = A_test[::-1, ::-1]
assert A_test_neg_strides.strides[1] in (-4, 0)
x_test_neg_strides = x_test[::-1]
assert x_test_neg_strides.strides[0] in (-4, 0)
# assert_expected_output(inplace, f, y_test_neg_strides, A_test_neg_strides, x_test_neg_strides)

# Zero strides (by broadcasting)
y_test_0_strides = np.broadcast_to(np.array(np.pi, dtype="float32"), A_shape[0])
assert y_test_0_strides.strides == (0,)
A_test_0_strides = np.broadcast_to(np.array(np.e, dtype="float32"), A_shape)
assert A_test_0_strides.strides == (0, 0)
x_test_0_strides = np.broadcast_to(
np.array(np.euler_gamma, dtype="float32"), A_shape[1]
)

def test_gemv1(self):
skip_if_blas_ldflags_empty()
self.t_gemv1((3, 2))
self.t_gemv1((1, 2))
self.t_gemv1((0, 2))
self.t_gemv1((3, 1))
self.t_gemv1((3, 0))
self.t_gemv1((1, 0))
self.t_gemv1((0, 1))
self.t_gemv1((0, 0))
assert x_test_0_strides.strides == (0,)
# Test one input at a time so the outputs are unique
assert_expected_output(inplace, f, y_test, A_test, x_test_0_strides)
assert_expected_output(inplace, f, y_test, A_test_0_strides, x_test)
# assert_expected_output(inplace, f, y_test_0_strides, A_test, x_test)

def test_gemv_dimensions(self, dtype="float32"):
alpha = pytensor.shared(np.asarray(1.0, dtype=dtype), name="alpha")
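The test helper `assert_expected_output` copies `y` before each call, so the inplace variant does not clobber the reference values, and then uses `as_strided` so the copy keeps the layout of the original. A rough NumPy sketch of that mechanic for a zero-strided `y`, as in the broadcast case (illustrative values, not taken from the test):

import numpy as np
from numpy.lib.stride_tricks import as_strided

y = np.broadcast_to(np.float32(np.pi), (3,))   # broadcasted input, strides == (0,)
y_copy = y.copy()                              # writable dense copy, strides == (4,)
# Re-stride the copy so it presents the same zero-stride layout as the original.
y_copy = as_strided(y_copy, shape=y.shape, strides=y.strides)
assert y_copy.strides == (0,)

Without the `as_strided` step the plain copy would be contiguous, so the layout of the original `y` would not be reproduced for the call under test.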