Minimal tensorflow stub structure #7319

Closed
2 changes: 2 additions & 0 deletions stubs/tensorflow/METADATA.toml
@@ -0,0 +1,2 @@
version = "2.8.*"
requires = ["numpy"]
68 changes: 68 additions & 0 deletions stubs/tensorflow/tensorflow/__init__.pyi
@@ -0,0 +1,68 @@
# An alias for bool is used because the tensorflow namespace shadows the builtin bool with tf.bool.
from builtins import bool as _bool
from typing import Any, Iterable, Iterator, NoReturn, overload

import numpy as np
@hmc-cs-mdrissi (Contributor, Author) commented on Feb 20, 2022:

How do I handle the numpy errors? I thought the requires line in METADATA.toml would allow using numpy?

numpy is the only library I think is needed, since most tensorflow functions accept numpy arrays too.

I think all of the unknown pyright errors are about np.ndarray/np.number.

edit: @jakebailey Any advice on the pyright action and handling a stub package that depends on a separate python package? Would it be best to have numpy installed as part of the pyright check, add this folder to the exclude list in pyrightconfig.json, or something else? (A usage sketch illustrating the numpy interaction follows this file's diff.)

from tensorflow._aliases import _TensorCompatible

# Most tf.math functions are also exported in the top-level tf namespace, but not all of them are.
from tensorflow.math import abs as abs

def __getattr__(name: str) -> Any: ... # incomplete

class Tensor:
@property
def shape(self) -> TensorShape: ...
def get_shape(self) -> TensorShape: ...
@property
def name(self) -> str: ...
def numpy(self) -> np.ndarray[Any, Any]: ...
def __int__(self) -> int: ...
def __abs__(self) -> Tensor: ...
def __add__(self, other: _TensorCompatible) -> Tensor: ...
def __radd__(self, other: _TensorCompatible) -> Tensor: ...
def __sub__(self, other: _TensorCompatible) -> Tensor: ...
def __rsub__(self, other: _TensorCompatible) -> Tensor: ...
def __mul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmul__(self, other: _TensorCompatible) -> Tensor: ...
def __matmul__(self, other: _TensorCompatible) -> Tensor: ...
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ...
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ...
def __truediv__(self, other: _TensorCompatible) -> Tensor: ...
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ...
def __neg__(self) -> Tensor: ...
def __and__(self, other: _TensorCompatible) -> Tensor: ...
def __rand__(self, other: _TensorCompatible) -> Tensor: ...
def __or__(self, other: _TensorCompatible) -> Tensor: ...
def __ror__(self, other: _TensorCompatible) -> Tensor: ...
def __eq__(self, other: _TensorCompatible) -> Tensor: ...
def __ne__(self, other: _TensorCompatible) -> Tensor: ...
def __ge__(self, other: _TensorCompatible) -> Tensor: ...
def __gt__(self, other: _TensorCompatible) -> Tensor: ...
def __le__(self, other: _TensorCompatible) -> Tensor: ...
def __lt__(self, other: _TensorCompatible) -> Tensor: ...
def __bool__(self) -> NoReturn: ...
def __getitem__(self, slice_spec: int | slice | tuple[int | slice, ...]) -> Tensor: ...
def __len__(self) -> int: ...
# This only works for rank 0 tensors.
def __index__(self) -> int: ...
def __getattr__(self, name: str) -> Any: ... # incomplete

class TensorShape:
def __init__(self, dims: Iterable[int | None]) -> None: ...
@property
def rank(self) -> int: ...
def as_list(self) -> list[int | None]: ...
def assert_has_rank(self, rank: int) -> None: ...
def __bool__(self) -> _bool: ...
@overload
def __getitem__(self, key: int) -> int | None: ...
@overload
def __getitem__(self, key: slice) -> TensorShape: ...
def __iter__(self) -> Iterator[int | None]: ...
def __len__(self) -> int: ...
def __add__(self, other: Iterable[int | None]) -> TensorShape: ...
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ...
def __eq__(self, other: Iterable[int | None]) -> _bool: ... # type: ignore
def __getattr__(self, name: str) -> Any: ... # incomplete
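A hypothetical usage sketch (not part of the PR, assuming only the names defined in this diff): the arithmetic dunders above accept _TensorCompatible operands, so plain Python and numpy scalars/arrays type-check against them, which is also why the stubs depend on numpy.

from __future__ import annotations

import numpy as np
import tensorflow as tf

def scaled_abs_diff(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    # Tensor.__sub__/__mul__ accept _TensorCompatible, so the float literal is fine.
    return tf.abs(x - y) * 0.5

def leading_dim(t: tf.Tensor) -> int | None:
    # TensorShape.as_list() is typed as list[int | None] above.
    return t.shape.as_list()[0]

def add_numpy(t: tf.Tensor, arr: np.ndarray) -> tf.Tensor:
    # numpy arrays fall under _TensorCompatible in _aliases.pyi.
    return t + arr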
36 changes: 36 additions & 0 deletions stubs/tensorflow/tensorflow/_aliases.pyi
@@ -0,0 +1,36 @@
# Collection of commonly needed type aliases. These are all private
# and do not exist at runtime.

from typing import Iterable, Mapping, Optional, Sequence, TypeVar, Union

import numpy as np
import tensorflow as tf

# These aliases mostly ignore rank/shape/dtype information, as that
# would complicate the types heavily and can be a follow-up problem.
_FloatDataSequence = Union[Sequence[float], Sequence["_FloatDataSequence"]]
@hmc-cs-mdrissi (Contributor, Author) commented on Feb 20, 2022:

Several of these type aliases are recursive. What would be the recommended way to write them? Arbitrarily nested sequences and JSON-like containers (nested maps/sequences) are very commonly needed types. Should I leave them like this, use a recursive protocol, avoid recursion and fall back to Any beyond a given depth, or something else? (A sketch of the recursive-protocol alternative follows this file's diff.)

_ContainerGeneric and _DataSequence aliases are the main recursive ones here.

edit: It looks like the recursion produces this error for pytype:

stubs/tensorflow/tensorflow/__init__.pyi (3.9): ParseError: Sequence['_FloatDataSequence'] not supported

I'm surprised mypy passes, but I think mypy treats recursive aliases like Any currently.

_StrDataSequence = Union[Sequence[str], Sequence["_StrDataSequence"]]
_ScalarTensorConvertible = Union[str, float, np.number, np.ndarray]
_ScalarTensorCompatible = Union[tf.Tensor, _ScalarTensorConvertible]
_TensorConvertible = Union[_ScalarTensorConvertible, _FloatDataSequence, _StrDataSequence]
_TensorCompatible = Union[tf.Tensor, _TensorConvertible]

# Sparse tensors need to be treated carefully. Most functions do
# not document if they handle sparse tensors. Most functions do
# not support them. Ragged tensors usually work and are documented
# here, https://www.tensorflow.org/api_docs/python/tf/ragged
_SparseTensorCompatible = Union[_TensorCompatible, tf.SparseTensor]
_RaggedTensorCompatible = Union[_TensorCompatible, tf.RaggedTensor]
_AnyTensorCompatible = Union[_TensorCompatible, tf.Tensor, tf.Variable]

_SparseTensorLike = Union[tf.Tensor, tf.SparseTensor]
_RaggedTensorLike = Union[tf.Tensor, tf.RaggedTensor]
_AnyTensorLike = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]

_T1 = TypeVar("_T1", covariant=True)
_ContainerGeneric = Union[Mapping[str, "_ContainerGeneric[_T1]"], Sequence["_ContainerGeneric[_T1]"], _T1]
_ContainerTensors = _ContainerGeneric[tf.Tensor]
_ContainerTensorCompatible = _ContainerGeneric[_TensorCompatible]

_ShapeLike = Union[tf.TensorShape, Iterable[Optional[int]], int, tf.Tensor]
_DTypeLike = Union[tf.DType, str, np.dtype]
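As referenced in the review comment above, a minimal sketch of the recursive-protocol alternative (the names here are hypothetical and not part of the PR):

from __future__ import annotations

from typing import Protocol, TypeVar

_T_co = TypeVar("_T_co", covariant=True)

class _NestedSequence(Protocol[_T_co]):
    # A protocol may refer to itself in its method signatures, which avoids
    # the recursive-alias ParseError that pytype reports above.
    def __getitem__(self, index: int) -> _T_co | _NestedSequence[_T_co]: ...
    def __len__(self) -> int: ...

# These would stand in for the Union-based aliases, e.g.:
_FloatDataSequenceAlt = _NestedSequence[float]
_StrDataSequenceAlt = _NestedSequence[str]

The other option raised in the comment is a depth-limited alias that falls back to Any past a fixed nesting level; the protocol keeps full recursion while staying parseable.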
11 changes: 11 additions & 0 deletions stubs/tensorflow/tensorflow/math.pyi
@@ -0,0 +1,11 @@
from typing import overload

from tensorflow import RaggedTensor, SparseTensor, Tensor
from tensorflow._aliases import _TensorCompatible

@overload
def abs(x: _TensorCompatible, name: str | None = ...) -> Tensor: ...
@overload
def abs(x: SparseTensor, name: str | None = ...) -> SparseTensor: ...
@overload
def abs(x: RaggedTensor, name: str | None = ...) -> RaggedTensor: ...
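A hypothetical call-site sketch (not part of the PR) of what these overloads express: dense-compatible inputs produce a Tensor, while sparse and ragged inputs keep their container type.

import tensorflow as tf

# _TensorCompatible input -> Tensor (first overload)
dense: tf.Tensor = tf.math.abs([-1.0, 2.0])

# SparseTensor input -> SparseTensor (second overload)
sp = tf.SparseTensor(indices=[[0, 0]], values=[-3.0], dense_shape=[2, 2])
sparse_result: tf.SparseTensor = tf.math.abs(sp)

# RaggedTensor input -> RaggedTensor (third overload)
rt = tf.ragged.constant([[-1.0], [2.0, -3.0]])
ragged_result: tf.RaggedTensor = tf.math.abs(rt)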