
ENH: cache helper functions #308

Open
crusaderky wants to merge 2 commits into main

Conversation

crusaderky (Contributor) commented Apr 15, 2025

Speed up helper functions through caching.
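
The broad idea, sketched below with illustrative names rather than the code from this PR, is to key a functools cache on hashable inputs: the namespace module itself for the is_*_namespace checks, and type(x) for the per-array checks, so that the module lookups and issubclass work run once per module or class instead of on every call.

from functools import cache
import sys

@cache
def _is_numpy_namespace_sketch(xp) -> bool:
    # Module objects are hashable, so repeat calls become a dict lookup.
    return xp.__name__ in {"numpy", "array_api_compat.numpy"}

@cache
def _is_numpy_class_sketch(cls: type) -> bool:
    # Keyed on the class, never on the (unhashable) array instance.
    np = sys.modules.get("numpy")
    return np is not None and issubclass(cls, np.ndarray)

def is_numpy_array_sketch(x) -> bool:
    return _is_numpy_class_sketch(type(x))

Because the cache only ever sees modules and classes, the steady-state cost collapses to a dictionary lookup, which is where the improvements below come from.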

>>> import sys
>>> import array_api_compat.numpy as np
>>> import array_api_compat.torch as torch
>>> from array_api_compat import (
...     is_numpy_namespace, is_numpy_array, is_torch_array,
...     is_array_api_obj, is_lazy_array, is_writeable_array,
... )
>>> a = np.asarray(1)
>>> b = torch.asarray(1)

>>> %timeit is_numpy_namespace(np)
BEFORE 333 ns ± 1.83 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 62 ns ± 0.227 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_namespace(sys)
BEFORE 334 ns ± 4.79 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 69.3 ns ± 1.18 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_array(a)
BEFORE 382 ns ± 2.01 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 281 ns ± 4.03 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_numpy_array(1)
BEFORE 272 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 175 ns ± 5.22 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_numpy_array([1])
BEFORE 288 ns ± 1.16 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 213 ns ± 7.33 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_torch_array(b)
BEFORE 214 ns ± 0.244 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 126 ns ± 1.05 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_torch_array(1)
BEFORE 249 ns ± 4.49 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 121 ns ± 0.724 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_array_api_obj(a)
BEFORE 423 ns ± 1.35 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 99.5 ns ± 1.47 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_array_api_obj(1)
BEFORE 773 ns ± 2.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 142 ns ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_lazy_array(a)
BEFORE 437 ns ± 7.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 381 ns ± 5.56 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_lazy_array(1)
BEFORE 2.12 μs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
AFTER 624 ns ± 8.06 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

>>> %timeit is_writeable_array(a)
BEFORE 491 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 183 ns ± 0.446 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_writeable_array(b)
BEFORE 1.15 μs ± 5.57 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 185 ns ± 0.766 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

>>> %timeit is_writeable_array(1)
BEFORE 1.52 μs ± 13.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
AFTER 189 ns ± 5.21 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

Note: is_numpy_array and is_jax_array remain slower than the otherwise equivalent functions because they must account for the reclassification of JAX zero-gradient arrays.
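
For context, a hedged sketch of what that reclassification involves (approximate, not lifted from the PR): a JAX zero-gradient array is an ordinary numpy.ndarray carrying JAX's special float0 dtype, so telling the two libraries' arrays apart needs a per-instance dtype inspection that cannot be memoized on type(x) alone.

import sys

def _looks_like_jax_zero_gradient(x) -> bool:
    # Only relevant when both libraries have already been imported.
    if "numpy" not in sys.modules or "jax" not in sys.modules:
        return False
    import jax
    import numpy as np
    # The dtype differs per array instance, so this branch stays uncached.
    return isinstance(x, np.ndarray) and x.dtype == jax.float0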

Copilot AI review requested due to automatic review settings (April 15, 2025 16:27)
Copilot AI left a comment

Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (1)

array_api_compat/common/_helpers.py:944

  • [nitpick] Consider adding a comment to explain why 'cache' is included in _all_ignore to improve clarity for future maintainers.
_all_ignore = ['cache', 'sys', 'math', 'inspect', 'warnings']

from typing import Optional, Union, Any

from ._typing import Array, Device, Namespace


@cache
def _is_jax_zero_gradient_array(x: object) -> bool:

crusaderky (Contributor, Author) commented:

This in theory could lead to a memory leak for a user that somehow dynamically defines and then forgets a lot of classes. I don't think it's something to worry about in real life?
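
A hedged illustration of that scenario (illustrative names, not code from this PR): an unbounded functools.cache keyed on classes keeps a strong reference to every class it has ever been called with, so classes created on the fly can never be garbage collected.

from functools import cache

@cache
def _is_ndarray_subclass(cls: type) -> bool:
    import numpy as np
    return issubclass(cls, np.ndarray)

for i in range(10_000):
    # Every throwaway class created here stays referenced by the cache
    # after the loop ends, which is the (mostly theoretical) leak above.
    _is_ndarray_subclass(type(f"Throwaway{i}", (), {}))

A bounded functools.lru_cache(maxsize=...) would cap the retained references at the cost of occasional recomputation, if this ever turned out to matter in practice.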
