-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Minimal tensorflow stub structure #7319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
version = "2.8.*" | ||
requires = ["numpy"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Alias for bool is used because tensorflow name shadows bool with tf.bool. | ||
from builtins import bool as _bool | ||
from typing import Any, Iterable, Iterator, NoReturn, overload | ||
|
||
import numpy as np | ||
from tensorflow._aliases import _TensorCompatible | ||
|
||
# Most tf.math functions are exported from tf., but not all of them are. | ||
from tensorflow.math import abs as abs | ||
|
||
def __getattr__(name: str) -> Any: ... # incomplete | ||
|
||
class Tensor: | ||
@property | ||
def shape(self) -> TensorShape: ... | ||
def get_shape(self) -> TensorShape: ... | ||
@property | ||
def name(self) -> str: ... | ||
def numpy(self) -> np.ndarray[Any, Any]: ... | ||
def __int__(self) -> int: ... | ||
def __abs__(self) -> Tensor: ... | ||
def __add__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __radd__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __sub__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rsub__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __mul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __matmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __truediv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __neg__(self) -> Tensor: ... | ||
def __and__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __rand__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __or__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __ror__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __eq__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __ne__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __ge__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __gt__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __le__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __lt__(self, other: _TensorCompatible) -> Tensor: ... | ||
def __bool__(self) -> NoReturn: ... | ||
def __getitem__(self, slice_spec: int | slice | tuple[int | slice, ...]) -> Tensor: ... | ||
def __len__(self) -> int: ... | ||
# This only works for rank 0 tensors. | ||
def __index__(self) -> int: ... | ||
def __getattr__(self, name: str) -> Any: ... # incomplete | ||
|
||
class TensorShape: | ||
def __init__(self, dims: Iterable[int | None]): ... | ||
@property | ||
def rank(self) -> int: ... | ||
def as_list(self) -> list[int | None]: ... | ||
def assert_has_rank(self, rank: int) -> None: ... | ||
def __bool__(self) -> _bool: ... | ||
@overload | ||
def __getitem__(self, key: int) -> int | None: ... | ||
@overload | ||
def __getitem__(self, key: slice) -> TensorShape: ... | ||
def __iter__(self) -> Iterator[int | None]: ... | ||
def __len__(self) -> int: ... | ||
def __add__(self, other: Iterable[int | None]) -> TensorShape: ... | ||
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ... | ||
def __eq__(self, other: Iterable[int | None]) -> _bool: ... # type: ignore | ||
def __getattr__(self, name: str) -> Any: ... # incomplete |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Collection of commonly need type aliases. These are all private | ||
# and do not exist at runtime. | ||
|
||
from typing import Iterable, Mapping, Optional, Sequence, TypeVar, Union | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
# These aliases mostly ignore rank/shape/dtype information as that | ||
# will complicate the types heavily and can be a follow up problem. | ||
_FloatDataSequence = Union[Sequence[float], Sequence["_FloatDataSequence"]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Several of these type aliases are recursive. What would be recommended way to write them? Arbitrarily nested sequences or json like containers (nested maps/sequences) are very common types needed. Should I leave them like this, use a recursive protocol, avoid recursion and fallback to Any beyond given depth, or something else? _ContainerGeneric and _DataSequence alises are the main recursive ones here. edit: It looks like recursion produces this error for pytype, stubs/tensorflow/tensorflow/init.pyi (3.9): ParseError: Sequence['_FloatDataSequence'] not supported I'm surprised mypy passes but I think mypy treats recursive aliases like Any currently. |
||
_StrDataSequence = Union[Sequence[str], Sequence["_StrDataSequence"]] | ||
_ScalarTensorConvertible = Union[str, float, np.number, np.ndarray] | ||
_ScalarTensorCompatible = Union[tf.Tensor, _ScalarTensorConvertible] | ||
_TensorConvertible = Union[_ScalarTensorConvertible, _FloatDataSequence, _StrDataSequence] | ||
_TensorCompatible = Union[tf.Tensor, _TensorConvertible] | ||
|
||
# Sparse tensors need to be treated carefully. Most functions do | ||
# not document if they handle sparse tensors. Most functions do | ||
# not support them. Ragged tensors usually work and are documented | ||
# here, https://www.tensorflow.org/api_docs/python/tf/ragged | ||
_SparseTensorCompatible = Union[_TensorCompatible, tf.SparseTensor] | ||
_RaggedTensorCompatible = Union[_TensorCompatible, tf.RaggedTensor] | ||
_AnyTensorCompatible = Union[_TensorCompatible, tf.Tensor, tf.Variable] | ||
|
||
_SparseTensorLike = Union[tf.Tensor, tf.SparseTensor] | ||
_RaggedTensorLike = Union[tf.Tensor, tf.RaggedTensor] | ||
_AnyTensorLike = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] | ||
|
||
_T1 = TypeVar("_T1", covariant=True) | ||
_ContainerGeneric = Union[Mapping[str, "_ContainerGeneric[_T1]"], Sequence["_ContainerGeneric[_T1]"], _T1] | ||
_ContainerTensors = _ContainerGeneric[tf.Tensor] | ||
_ContainerTensorCompatible = _ContainerGeneric[_TensorCompatible] | ||
|
||
_ShapeLike = Union[tf.TensorShape, Iterable[Optional[int]], int, tf.Tensor] | ||
_DTypeLike = Union[tf.DType, str, np.dtype] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from typing import overload | ||
|
||
from tensorflow import RaggedTensor, SparseTensor, Tensor | ||
from tensorflow._aliases import _TensorCompatible | ||
|
||
@overload | ||
def abs(x: _TensorCompatible, name: str | None = ...) -> Tensor: ... | ||
@overload | ||
def abs(x: SparseTensor, name: str | None = ...) -> SparseTensor: ... | ||
@overload | ||
def abs(x: RaggedTensor, name: str | None = ...) -> RaggedTensor: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do I handle numpy errors? I thought requires line in METADATA.toml would allow using numpy?
numpy is only library I think needed as most tensorflow functions accept numpy arrays too.
I think all of unknown pyright errors are about np.ndarray/np.number.
edit: @jakebailey Any advice on pyright action and handling a stub package that depends on a separate python package? Would it be best have numpy installed as part of pyright check, add this folder to exclude list for pyrightconfig.json, or something else?