Commit 83420d2

rgommers, leofang, kgryte, and seberg authored
Add versioning support to DLPack APIs (#602)
* Add versioning support to DLPack APIs
  xref dmlc/dlpack#116
* Address review comment, replace ">=2 years" by "from March 2025"
* nit: re-order
* improvements & fixes
* Satisfy linter
* Satisfy linter
* Update src/array_api_stubs/_draft/array_object.py
  Co-authored-by: Sebastian Berg <[email protected]>

---------

Co-authored-by: Leo Fang <[email protected]>
Co-authored-by: Athan <[email protected]>
Co-authored-by: Sebastian Berg <[email protected]>
1 parent 425e9eb commit 83420d2

2 files changed: +83 -24 lines changed

Diff for: src/array_api_stubs/_draft/array_object.py

+78 -24
@@ -288,7 +288,11 @@ def __complex__(self: array, /) -> complex:
         """
 
     def __dlpack__(
-        self: array, /, *, stream: Optional[Union[int, Any]] = None
+        self: array,
+        /,
+        *,
+        stream: Optional[Union[int, Any]] = None,
+        max_version: Optional[tuple[int, int]] = None,
     ) -> PyCapsule:
         """
         Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
@@ -298,45 +302,43 @@ def __dlpack__(
         self: array
             array instance.
         stream: Optional[Union[int, Any]]
-            for CUDA and ROCm, a Python integer representing a pointer to a stream, on devices that support streams. ``stream`` is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via "wait for event"). The pointer must be a positive integer or ``-1``. If ``stream`` is ``-1``, the value may be used by the consumer to signal "producer must not perform any synchronization". The ownership of the stream stays with the consumer. On CPU and other device types without streams, only ``None`` is accepted.
-
-            For other device types which do have a stream, queue or similar synchronization mechanism, the most appropriate type to use for ``stream`` is not yet determined. E.g., for SYCL one may want to use an object containing an in-order ``cl::sycl::queue``. This is allowed when libraries agree on such a convention, and may be standardized in a future version of this API standard.
-
-
-            .. note::
-                Support for a ``stream`` value other than ``None`` is optional and implementation-dependent.
+            for CUDA and ROCm, a Python integer representing a pointer to a stream, on devices that support streams. ``stream`` is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via "wait for event"). The pointer must be an integer larger than or equal to ``-1`` (see below for allowed values on each platform). If ``stream`` is ``-1``, the value may be used by the consumer to signal "producer must not perform any synchronization". The ownership of the stream stays with the consumer. On CPU and other device types without streams, only ``None`` is accepted.
 
+            For other device types which do have a stream, queue, or similar synchronization/ordering mechanism, the most appropriate type to use for ``stream`` is not yet determined. E.g., for SYCL one may want to use an object containing an in-order ``cl::sycl::queue``. This is allowed when libraries agree on such a convention, and may be standardized in a future version of this API standard.
 
-            Device-specific notes:
+            .. note::
+                Support for a ``stream`` value other than ``None`` is optional and implementation-dependent.
 
-
-            .. admonition:: CUDA
-               :class: note
+            Device-specific values of ``stream`` for CUDA:
 
             - ``None``: producer must assume the legacy default stream (default).
             - ``1``: the legacy default stream.
             - ``2``: the per-thread default stream.
             - ``> 2``: stream number represented as a Python integer.
             - ``0`` is disallowed due to its ambiguity: ``0`` could mean either ``None``, ``1``, or ``2``.
 
-
-            .. admonition:: ROCm
-               :class: note
+            Device-specific values of ``stream`` for ROCm:
 
             - ``None``: producer must assume the legacy default stream (default).
             - ``0``: the default stream.
             - ``> 2``: stream number represented as a Python integer.
             - Using ``1`` and ``2`` is not supported.
 
-
-            .. admonition:: Tip
-               :class: important
-
-               It is recommended that implementers explicitly handle streams. If
-               they use the legacy default stream, specifying ``1`` (CUDA) or ``0``
-               (ROCm) is preferred. ``None`` is a safe default for developers who do
-               not want to think about stream handling at all, potentially at the
-               cost of more synchronization than necessary.
+            .. admonition:: Tip
+                :class: important
+
+                It is recommended that implementers explicitly handle streams. If
+                they use the legacy default stream, specifying ``1`` (CUDA) or ``0``
+                (ROCm) is preferred. ``None`` is a safe default for developers who do
+                not want to think about stream handling at all, potentially at the
+                cost of more synchronizations than necessary.
+        max_version: Optional[tuple[int, int]]
+            The maximum DLPack version that the *consumer* (i.e., the caller of
+            ``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
+            This method may return a capsule of version ``max_version`` (recommended
+            if it does support that), or of a different version.
+            This means the consumer must verify the version even when
+            `max_version` is passed.
 
         Returns
         -------
@@ -353,9 +355,61 @@ def __dlpack__(
 
         Notes
         -----
+        The DLPack version scheme is SemVer, where the major DLPack versions
+        represent ABI breaks, and minor versions represent ABI-compatible additions
+        (e.g., new enum values for new data types or device types).
+
+        The ``max_version`` keyword was introduced in v2023.12, and goes
+        together with the ``DLManagedTensorVersioned`` struct added in DLPack
+        1.0. This keyword may not be used by consumers until a later time after
+        introduction, because producers may implement the support at a different
+        point in time.
+
+        It is recommended for the producer to use this logic in the implementation
+        of ``__dlpack__``:
+
+        .. code:: python
+
+            if max_version is None:
+                # Keep and use the DLPack 0.X implementation
+                # Note: from March 2025 onwards (but ideally as late as
+                # possible), it's okay to raise BufferError here
+            else:
+                # We get to produce `DLManagedTensorVersioned` now. Note that
+                # our_own_dlpack_version is the max version that the *producer*
+                # supports and fills in the `DLManagedTensorVersioned::version`
+                # field
+                if max_version >= our_own_dlpack_version:
+                    # Consumer understands us, just return a Capsule with our max version
+                elif max_version[0] == our_own_dlpack_version[0]:
+                    # major versions match, we should still be fine here -
+                    # return our own max version
+                else:
+                    # if we're at a higher major version internally, did we
+                    # keep an implementation of the older major version around?
+                    # For example, if the producer is on DLPack 1.x and the consumer
+                    # is 0.y, can the producer still export a capsule containing
+                    # DLManagedTensor and not DLManagedTensorVersioned?
+                    # If so, use that. Else, the producer should raise a BufferError
+                    # here to tell users that the consumer's max_version is too
+                    # old to allow the data exchange to happen.
+
+        And this logic for the consumer in ``from_dlpack``:
+
+        .. code:: python
+
+            try:
+                x.__dlpack__(max_version=(1, 0))
+                # if it succeeds, store info from the capsule named "dltensor_versioned",
+                # and need to set the name to "used_dltensor_versioned" when we're done
+            except TypeError:
+                x.__dlpack__()
 
         .. versionchanged:: 2022.12
             Added BufferError.
+
+        .. versionchanged:: 2023.12
+            Added the ``max_version`` keyword.
         """
 
     def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
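
The producer-side logic recommended in the `__dlpack__` notes is written as comments only, so it may help to see the same decision tree as runnable code. The following is a minimal sketch, not part of the commit: `choose_capsule`, `OUR_DLPACK_VERSION`, and `keeps_legacy_export` are hypothetical names, and the function returns only the capsule name and version a producer would pick rather than building a real capsule.

```python
from typing import Optional, Tuple

Version = Tuple[int, int]

# Hypothetical: the newest DLPack version this producer can emit.
OUR_DLPACK_VERSION: Version = (1, 0)


def choose_capsule(
    max_version: Optional[Version],
    our_version: Version = OUR_DLPACK_VERSION,
    keeps_legacy_export: bool = True,
) -> Tuple[str, Optional[Version]]:
    """Return (capsule_name, version) for the export, or raise BufferError."""
    if max_version is None:
        # Consumer did not pass max_version: use the DLPack 0.x path if kept.
        if keeps_legacy_export:
            return ("dltensor", None)
        raise BufferError("unversioned DLPack export is no longer supported")

    # Tuples compare element by element, so (2, 0) >= (1, 5) is True.
    if max_version >= our_version or max_version[0] == our_version[0]:
        # Consumer understands our ABI: export at our own max version.
        return ("dltensor_versioned", our_version)

    # Remaining case: the consumer's major version is older than ours.
    if max_version[0] == 0 and keeps_legacy_export:
        return ("dltensor", None)
    raise BufferError(
        f"consumer max_version {max_version} is too old for DLPack {our_version}"
    )
```

For example, `choose_capsule((1, 0), our_version=(1, 1))` selects the versioned capsule at `(1, 1)`, which is why the consumer still has to check the version it receives.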

Diff for: src/array_api_stubs/_draft/creation_functions.py

+5
@@ -246,6 +246,11 @@ def from_dlpack(x: object, /) -> array:
         If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
         on the input array. This may happen for libraries that are never able
         to export their data with DLPack.
+
+    Notes
+    -----
+    See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
+    order to handle DLPack versioning correctly.
     """
 
 
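
The consumer-side fallback that the new `from_dlpack` note points to can be exercised without a real array library. The sketch below is illustrative only: `LegacyProducer`, `VersionedProducer`, and `request_capsule` are hypothetical names, and the producers return a capsule name as a plain string instead of a `PyCapsule`.

```python
class LegacyProducer:
    """Stand-in for a producer that predates the max_version keyword."""

    def __dlpack__(self, *, stream=None):
        return "dltensor"


class VersionedProducer:
    """Stand-in for a producer that accepts max_version."""

    def __dlpack__(self, *, stream=None, max_version=None):
        # Placeholder for building DLManagedTensor or DLManagedTensorVersioned.
        return "dltensor" if max_version is None else "dltensor_versioned"


def request_capsule(x, max_version=(1, 0)):
    """Ask for a versioned capsule first, then fall back to the old call."""
    try:
        return x.__dlpack__(max_version=max_version)
    except TypeError:
        # The producer does not know the max_version keyword yet.
        return x.__dlpack__()


print(request_capsule(LegacyProducer()))     # legacy path
print(request_capsule(VersionedProducer()))  # versioned path
```

A real consumer would still need to inspect the returned capsule's name and version, since the producer may return a version other than the one requested. Note also that catching `TypeError` this broadly could hide a `TypeError` raised for an unrelated reason inside the producer.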