Skip to content

Support copy and device keywords in from_dlpack #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 14, 2024
41 changes: 36 additions & 5 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __dlpack__(
*,
stream: Optional[Union[int, Any]] = None,
max_version: Optional[tuple[int, int]] = None,
dl_device: Optional[Tuple[Enum, int]] = None,
copy: Optional[bool] = None
) -> PyCapsule:
"""
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
Expand Down Expand Up @@ -324,6 +326,12 @@ def __dlpack__(
- ``> 2``: stream number represented as a Python integer.
- Using ``1`` and ``2`` is not supported.

.. note::
When ``dl_device`` is provided explicitly, ``stream`` must be a valid
construct for the specified device type. In particular, when ``kDLCPU``
is in use, ``stream`` must be ``None`` and a synchronization must be
performed to ensure data safety.

.. admonition:: Tip
:class: important

Expand All @@ -333,12 +341,30 @@ def __dlpack__(
not want to think about stream handling at all, potentially at the
cost of more synchronizations than necessary.
max_version: Optional[tuple[int, int]]
The maximum DLPack version that the *consumer* (i.e., the caller of
the maximum DLPack version that the *consumer* (i.e., the caller of
``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
This method may return a capsule of version ``max_version`` (recommended
if it does support that), or of a different version.
This means the consumer must verify the version even when
`max_version` is passed.
dl_device: Optional[Tuple[Enum, int]]
the DLPack device type. Default is ``None``, meaning the exported capsule
should be on the same device as ``self`` is. When specified, the format
must be a 2-tuple, following that of the return value of :meth:`array.__dlpack_device__`.
If the device type cannot be handled by the producer, this function must
raise ``BufferError``.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the
function must always copy (paerformed by the producer), potentially allowing
data movement across the library (and/or device) boundary. If ``False``,
the function must never copy. If ``None``, the function must reuse existing
memory buffer if possible and copy otherwise. Default: ``None``.

When a copy happens, the ``DLPACK_FLAG_BITMASK_IS_COPIED`` flag must be set.

.. note::
If a copy happens, and if the consumer-provided ``stream`` and ``dl_device``
can be understood by the producer, the copy must be performed over ``stream``.

Returns
-------
Expand Down Expand Up @@ -394,22 +420,25 @@ def __dlpack__(
# here to tell users that the consumer's max_version is too
# old to allow the data exchange to happen.

And this logic for the consumer in ``from_dlpack``:
And this logic for the consumer in :func:`~array_api.from_dlpack`:

.. code:: python

try:
x.__dlpack__(max_version=(1, 0))
x.__dlpack__(max_version=(1, 0), ...)
# if it succeeds, store info from the capsule named "dltensor_versioned",
# and need to set the name to "used_dltensor_versioned" when we're done
except TypeError:
x.__dlpack__()
x.__dlpack__(...)

This logic is also applicable to handling of the new ``dl_device`` and ``copy``
keywords.

.. versionchanged:: 2022.12
Added BufferError.

.. versionchanged:: 2023.12
Added the ``max_version`` keyword.
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
"""

def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
Expand All @@ -436,6 +465,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
METAL = 8
VPI = 9
ROCM = 10
CUDA_MANAGED = 13
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
Expand Down
31 changes: 26 additions & 5 deletions src/array_api_stubs/_draft/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


from ._types import (
Any,
List,
NestedSequence,
Optional,
Expand Down Expand Up @@ -214,19 +215,36 @@ def eye(
"""


def from_dlpack(x: object, /) -> array:
def from_dlpack(
x: object, /, *,
device: Optional[device] = None,
copy: Optional[bool] = None
) -> Union[array, Any]:
"""
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.

Parameters
----------
x: object
input (array) object.
device: Optional[device]
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array device must be inferred from ``x.device``. Default: ``None``.

The v2023.12 standard only mandates that a compliant library must offer a way for ``from_dlpack`` to create an array on CPU (using
a library-specifc way to represent the CPU device (``kDLCPU`` in DLPack) e.g. a ``"CPU"`` string or a ``Device("CPU")`` object).
If the array library does not support the CPU device and needs to outsource to another (compliant) array library, it may do so
with a clear user documentation and/or run-time warning. If a copy must be made to enable this, and ``copy`` is set to ``False``,
the function must raise ``ValueError``.

Other kinds of devices will be considered for standardization in a future version of this API standard.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy and must raise a ``BufferError`` in case a copy is deemed necessary (e.g. the producer disallows views). If ``None``, the function must reuse the existing memory buffer if possible and copy otherwise. Default: ``None``.

Returns
-------
out: array
an array containing the data in `x`.
out: Union[array, Any]
an array containing the data in ``x``. In the case that the compliant library does not support the given ``device`` out of box
and must oursource to another (compliant) library, the output will be that library's compliant array object.

.. admonition:: Note
:class: note
Expand All @@ -238,9 +256,9 @@ def from_dlpack(x: object, /) -> array:
BufferError
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
may raise ``BufferError`` when the data cannot be exported as DLPack
(e.g., incompatible dtype or strides). It may also raise other errors
(e.g., incompatible dtype, strides, or device). It may also raise other errors
when export fails for other reasons (e.g., not enough memory available
to materialize the data). ``from_dlpack`` must propagate such
to materialize the data, a copy must made, etc). ``from_dlpack`` must propagate such
exceptions.
AttributeError
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
Expand All @@ -251,6 +269,9 @@ def from_dlpack(x: object, /) -> array:
-----
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
order to handle DLPack versioning correctly.

.. versionchanged:: 2023.12
Added device and copy support.
"""


Expand Down