Skip to content

ENH: pad: add delegation #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
expand_dims
kron
nunique
pad
setdiff1d
sinc
```
969 changes: 512 additions & 457 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ tests-backends = ["py310", "tests", "backends", "cuda-backends"]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = ["error"]
filterwarnings = [
"error",
# TODO: when Python 3.10 is dropped, use `enum.member` in `_delegation.py`
"ignore:functools.partial will be a method descriptor:FutureWarning",
]
log_cli_level = "INFO"
testpaths = ["tests"]
markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific backend"]
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import (
from ._delegation import pad
from ._lib._funcs import (
at,
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
nunique,
pad,
setdiff1d,
sinc,
)
Expand Down
83 changes: 83 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Delegation to existing implementations for Public API Functions."""

from types import ModuleType
from typing import Literal

from ._lib import Backend, _funcs
from ._lib._utils._compat import array_namespace
from ._lib._utils._typing import Array

__all__ = ["pad"]


def _delegate(xp: ModuleType, *backends: Backend) -> bool:
    """
    Check whether `xp` is one of the `backends` to delegate to.

    Parameters
    ----------
    xp : array_namespace
        Array namespace to check.
    *backends : IsNamespace
        Arbitrarily many backends (from the ``IsNamespace`` enum) to check.

    Returns
    -------
    bool
        ``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
    """
    # Each Backend member carries its own namespace predicate; succeed on
    # the first match.
    for backend in backends:
        if backend.is_namespace(xp):
            return True
    return False


def pad(
    x: Array,
    pad_width: int | tuple[int, int] | list[tuple[int, int]],
    mode: Literal["constant"] = "constant",
    *,
    constant_values: bool | int | float | complex = 0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Pad the input array.

    Parameters
    ----------
    x : array
        Input array.
    pad_width : int or tuple of ints or list of pairs of ints
        Pad the input array with this many elements from each side.
        If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
        each pair applies to the corresponding axis of ``x``.
        A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
        copies of this tuple.
    mode : str, optional
        Only "constant" mode is currently supported, which pads with
        the value passed to `constant_values`.
    constant_values : python scalar, optional
        Use this value to pad the input. Default is zero.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        The input array,
        padded with ``pad_width`` elements equal to ``constant_values``.
    """
    if xp is None:
        xp = array_namespace(x)

    if mode != "constant":
        msg = "Only `'constant'` mode is currently supported"
        raise NotImplementedError(msg)

    # torch.nn.functional.pad wants a flat sequence of (before, after) pairs
    # in reverse (last-axis-first) order; mirror torch._numpy's conversion:
    # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
    if _delegate(xp, Backend.TORCH):
        widths = xp.broadcast_to(xp.asarray(pad_width), (x.ndim, 2))
        widths = xp.flip(widths, axis=(0,)).flatten()
        return xp.nn.functional.pad(x, tuple(widths), value=constant_values)  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]

    # These namespaces expose a numpy-compatible `pad` directly.
    if _delegate(xp, Backend.NUMPY, Backend.JAX_NUMPY, Backend.CUPY):
        return xp.pad(x, pad_width, mode, constant_values=constant_values)

    # Fall back to the array-agnostic implementation.
    return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
6 changes: 5 additions & 1 deletion src/array_api_extra/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Modules housing private functions."""
"""Internals of array-api-extra."""

from ._backends import Backend

__all__ = ["Backend"]
59 changes: 59 additions & 0 deletions src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Backends with which array-api-extra interacts in delegation and testing."""

from collections.abc import Callable
from enum import Enum
from types import ModuleType
from typing import cast

from ._utils import _compat

__all__ = ["Backend"]


class Backend(Enum):  # numpydoc ignore=PR01,PR02  # type: ignore[no-subclass-any]
    """
    All array library backends explicitly tested by array-api-extra.

    Parameters
    ----------
    value : str
        String describing the backend.
    is_namespace : Callable[[ModuleType], bool]
        Function to check whether an input module is the array namespace
        corresponding to the backend.
    module_name : str
        Name of the backend's module.
    """

    # Each member's value is a (value, is_namespace, module_name) triple,
    # unpacked by __new__/__init__ below.
    ARRAY_API_STRICT = (
        "array_api_strict",
        _compat.is_array_api_strict_namespace,
        "array_api_strict",
    )
    NUMPY = "numpy", _compat.is_numpy_namespace, "numpy"
    # NUMPy_READONLY shares NUMPY's namespace predicate and module: read-only
    # numpy arrays live in the same namespace and are distinguished only by
    # the member's string value (e.g. for test parameterization).
    NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace, "numpy"
    CUPY = "cupy", _compat.is_cupy_namespace, "cupy"
    TORCH = "torch", _compat.is_torch_namespace, "torch"
    DASK_ARRAY = "dask.array", _compat.is_dask_namespace, "dask.array"
    SPARSE = "sparse", _compat.is_pydata_sparse_namespace, "sparse"
    JAX_NUMPY = "jax.numpy", _compat.is_jax_namespace, "jax.numpy"

    def __new__(
        cls, value: str, _is_namespace: Callable[[ModuleType], bool], _module_name: str
    ):  # numpydoc ignore=GL08
        # Only the string becomes the canonical Enum value (`_value_`);
        # the remaining tuple items are consumed by __init__ below.
        obj = object.__new__(cls)
        obj._value_ = value
        return obj

    def __init__(
        self,
        value: str,  # noqa: ARG002  # pylint: disable=unused-argument
        is_namespace: Callable[[ModuleType], bool],
        module_name: str,
    ):  # numpydoc ignore=GL08
        # `value` was already handled in __new__; store the remaining
        # per-member data as plain instance attributes.
        self.is_namespace = is_namespace
        self.module_name = module_name

    def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
        """Pretty-print parameterized test names."""
        return cast(str, self.value)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Public API Functions."""
"""Array-agnostic implementations for the public API."""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except... it isn't agnostic — see, for example, the special paths in `at` and `nunique`.

Copy link
Member Author

@lucascolley lucascolley Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up.

# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations
Expand All @@ -11,13 +11,9 @@
from types import ModuleType
from typing import ClassVar, cast

from ._lib import _compat, _utils
from ._lib._compat import (
array_namespace,
is_jax_array,
is_writeable_array,
)
from ._lib._typing import Array, Index
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
from ._utils._typing import Array, Index

__all__ = [
"at",
Expand Down Expand Up @@ -151,7 +147,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _utils.mean(m, axis=1, xp=xp)
avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand Down Expand Up @@ -467,7 +463,7 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down Expand Up @@ -562,54 +558,18 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
def pad(
x: Array,
pad_width: int | tuple[int, int] | list[tuple[int, int]],
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
) -> Array:
"""
Pad the input array.

Parameters
----------
x : array
Input array.
pad_width : int or tuple of ints or list of pairs of ints
Pad the input array with this many elements from each side.
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
each pair applies to the corresponding axis of ``x``.
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
copies of this tuple.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.

Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

value = constant_values

xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
# make pad_width a list of length-2 tuples of ints
x_ndim = cast(int, x.ndim)
if isinstance(pad_width, int):
pad_width = [(pad_width, pad_width)] * x_ndim
if isinstance(pad_width, tuple):
pad_width = [pad_width] * x_ndim

if xp is None:
xp = array_namespace(x)

# https://github.com/python/typeshed/issues/13376
slices: list[slice] = [] # type: ignore[no-any-explicit]
newshape: list[int] = []
Expand All @@ -633,7 +593,7 @@ def pad(

padded = xp.full(
tuple(newshape),
fill_value=value,
fill_value=constant_values,
dtype=x.dtype,
device=_compat.device(x),
)
Expand Down
6 changes: 4 additions & 2 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Note that this is private API; don't expect it to be stable.
"""

from ._compat import (
from types import ModuleType

from ._utils._compat import (
array_namespace,
is_cupy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array, ModuleType
from ._utils._typing import Array

__all__ = ["xp_assert_close", "xp_assert_equal"]

Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Modules housing private utility functions."""
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
# `array-api-compat` to override the import location

try:
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
from ...._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace,
device,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
Expand All @@ -18,9 +21,12 @@
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace,
device,
is_array_api_strict_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
is_writeable_array,
Expand All @@ -30,9 +36,12 @@
__all__ = [
"array_namespace",
"device",
"is_array_api_strict_namespace",
"is_cupy_namespace",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_namespace",
"is_pydata_sparse_namespace",
"is_torch_namespace",
"is_writeable_array",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ...
def device(x: Array, /) -> Device: ...
def is_cupy_namespace(x: object, /) -> bool: ...
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
def is_jax_namespace(x: object, /) -> bool: ...
def is_pydata_sparse_namespace(x: object, /) -> bool: ...
def is_torch_namespace(x: object, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Utility functions used by `array_api_extra/_funcs.py`."""
"""Helper functions used by `array_api_extra/_funcs.py`."""

# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

from types import ModuleType

from . import _compat
from ._typing import Array, ModuleType
from ._typing import Array

__all__ = ["in1d", "mean"]

Expand Down
Loading
Loading