Skip to content

Fix svd function return dtype #619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array_api_stubs/_2021_12/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def solve(x1: array, x2: array, /) -> array:
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
"""

def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
"""
Returns a singular value decomposition A = USVh of a matrix (or a stack of matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.

Expand All @@ -379,7 +379,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,
-------
..
NOTE: once complex numbers are supported, each square matrix must be Hermitian.
out: Union[array, Tuple[array, ...]]
out: Tuple[array, array, array]
a namedtuple ``(U, S, Vh)`` whose

- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``.
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_stubs/_2022_12/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def solve(x1: array, x2: array, /) -> array:
"""


def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
r"""
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) ``x``.

Expand Down Expand Up @@ -565,7 +565,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,

Returns
-------
out: Union[array, Tuple[array, ...]]
out: Tuple[array, array, array]
a namedtuple ``(U, S, Vh)`` whose

- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_stubs/_draft/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def solve(x1: array, x2: array, /) -> array:
"""


def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
r"""
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) ``x``.

Expand Down Expand Up @@ -649,7 +649,7 @@ def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array,

Returns
-------
out: Union[array, Tuple[array, ...]]
out: Tuple[array, array, array]
a namedtuple ``(U, S, Vh)`` whose

- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.
Expand Down