From b6d43d11e96a76269852b8d603bcfd27d1d4472f Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Fri, 3 Jan 2025 13:29:20 +0000 Subject: [PATCH] TYP: import annotations for sklearn --- src/array_api_extra/_funcs.py | 5 ++++- src/array_api_extra/_lib/_compat.pyi | 3 +++ src/array_api_extra/_lib/_utils.py | 3 +++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 03b4829..e0fa5f5 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -1,5 +1,8 @@ """Public API Functions.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + import operator import warnings @@ -719,7 +722,7 @@ def __init__( self._x = x self._idx = idx - def __getitem__(self, idx: Index, /) -> "at": # numpydoc ignore=PR01,RT01 + def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01 """ Allow for the alternate syntax ``at(x)[start:stop:step]``. diff --git a/src/array_api_extra/_lib/_compat.pyi b/src/array_api_extra/_lib/_compat.pyi index f65a28f..4d06a7f 100644 --- a/src/array_api_extra/_lib/_compat.pyi +++ b/src/array_api_extra/_lib/_compat.pyi @@ -1,5 +1,8 @@ """Static type stubs for `_compat.py`.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + from types import ModuleType from ._typing import Array, Device diff --git a/src/array_api_extra/_lib/_utils.py b/src/array_api_extra/_lib/_utils.py index 523c21b..1191b4f 100644 --- a/src/array_api_extra/_lib/_utils.py +++ b/src/array_api_extra/_lib/_utils.py @@ -1,5 +1,8 @@ """Utility functions used by `array_api_extra/_funcs.py`.""" +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + from . import _compat from ._typing import Array, ModuleType