Skip to content

Commit fdfa5bd

Browse files
authored
Add complex number support to linalg.svd (#561)
1 parent a1d7edc commit fdfa5bd

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

Diff for: spec/API_specification/array_api/linalg.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -465,28 +465,51 @@ def solve(x1: array, x2: array, /) -> array:
465465
"""
466466

467467
def svd(x: array, /, *, full_matrices: bool = True) -> Union[array, Tuple[array, ...]]:
468-
"""
469-
Returns a singular value decomposition A = USVh of a matrix (or a stack of matrices) ``x``, where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
468+
r"""
469+
Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) ``x``.
470+
471+
If ``x`` is real-valued, let :math:`\mathbb{K}` be the set of real numbers :math:`\mathbb{R}`, and, if ``x`` is complex-valued, let :math:`\mathbb{K}` be the set of complex numbers :math:`\mathbb{C}`.
472+
473+
The full **singular value decomposition** of an :math:`m \times n` matrix :math:`x \in\ \mathbb{K}^{m \times n}` is a factorization of the form
474+
475+
.. math::
476+
x = U \Sigma V^H
477+
478+
where :math:`U \in\ \mathbb{K}^{m \times m}`, :math:`\Sigma \in\ \mathbb{K}^{m \times\ n}`, :math:`\operatorname{diag}(\Sigma) \in\ \mathbb{R}^{k}` with :math:`k = \operatorname{min}(m, n)`, :math:`V^H \in\ \mathbb{K}^{n \times n}`, and where :math:`V^H` is the conjugate transpose when :math:`V` is complex and the transpose when :math:`V` is real-valued. When ``x`` is real-valued, :math:`U`, :math:`V` (and thus :math:`V^H`) are orthogonal, and, when ``x`` is complex, :math:`U`, :math:`V` (and thus :math:`V^H`) are unitary.
479+
480+
When :math:`m \gt n` (tall matrix), we can drop the last :math:`m - n` columns of :math:`U` to form the reduced SVD
481+
482+
.. math::
483+
x = U \Sigma V^H
484+
485+
where :math:`U \in\ \mathbb{K}^{m \times k}`, :math:`\Sigma \in\ \mathbb{K}^{k \times\ k}`, :math:`\operatorname{diag}(\Sigma) \in\ \mathbb{R}^{k}`, and :math:`V^H \in\ \mathbb{K}^{k \times n}`. In this case, :math:`U` and :math:`V` have orthonormal columns.
486+
487+
Similarly, when :math:`n \gt m` (wide matrix), we can drop the last :math:`n - m` columns of :math:`V` to also form a reduced SVD.
488+
489+
This function returns the decomposition :math:`U`, :math:`S`, and :math:`V^H`, where :math:`S = \operatorname{diag}(\Sigma)`.
490+
491+
When ``x`` is a stack of matrices, the function must compute the singular value decomposition for each matrix in the stack.
492+
493+
.. warning::
494+
The returned arrays :math:`U` and :math:`V` are neither unique nor continuous with respect to ``x``. Because :math:`U` and :math:`V` are not unique, different hardware and software may compute different singular vectors.
495+
496+
Non-uniqueness stems from the fact that multiplying any pair of singular vectors :math:`u_k`, :math:`v_k` by :math:`-1` when ``x`` is real-valued and by :math:`e^{\phi j}` (:math:`\phi \in \mathbb{R}`) when ``x`` is complex produces another two valid singular vectors of the matrix.
470497
471498
Parameters
472499
----------
473500
x: array
474-
input array having shape ``(..., M, N)`` and whose innermost two dimensions form matrices on which to perform singular value decomposition. Should have a real-valued floating-point data type.
501+
input array having shape ``(..., M, N)`` and whose innermost two dimensions form matrices on which to perform singular value decomposition. Should have a floating-point data type.
475502
full_matrices: bool
476503
If ``True``, compute full-sized ``U`` and ``Vh``, such that ``U`` has shape ``(..., M, M)`` and ``Vh`` has shape ``(..., N, N)``. If ``False``, compute on the leading ``K`` singular vectors, such that ``U`` has shape ``(..., M, K)`` and ``Vh`` has shape ``(..., K, N)`` and where ``K = min(M, N)``. Default: ``True``.
477504
478505
Returns
479506
-------
480-
..
481-
NOTE: once complex numbers are supported, each square matrix must be Hermitian.
482507
out: Union[array, Tuple[array, ...]]
483508
a namedtuple ``(U, S, Vh)`` whose
484509
485-
- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``.
486-
- second element must have the field name ``S`` and must be an array with shape ``(..., K)`` that contains the vector(s) of singular values of length ``K``, where ``K = min(M, N)``. For each vector, the singular values must be sorted in descending order by magnitude, such that ``s[..., 0]`` is the largest value, ``s[..., 1]`` is the second largest value, et cetera. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``.
487-
- third element must have the field name ``Vh`` and must be an array whose shape depends on the value of ``full_matrices`` and contain orthonormal rows (i.e., the rows are the right singular vectors and the array is the adjoint). If ``full_matrices`` is ``True``, the array must have shape ``(..., N, N)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., K, N)`` where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``.
488-
489-
Each returned array must have the same real-valued floating-point data type as ``x``.
510+
- first element must have the field name ``U`` and must be an array whose shape depends on the value of ``full_matrices`` and contain matrices with orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True``, the array must have shape ``(..., M, M)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.
511+
- second element must have the field name ``S`` and must be an array with shape ``(..., K)`` that contains the vector(s) of singular values of length ``K``, where ``K = min(M, N)``. For each vector, the singular values must be sorted in descending order by magnitude, such that ``s[..., 0]`` is the largest value, ``s[..., 1]`` is the second largest value, et cetera. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have a real-valued floating-point data type having the same precision as ``x`` (e.g., if ``x`` is ``complex64``, ``S`` must have a ``float32`` data type).
512+
- third element must have the field name ``Vh`` and must be an array whose shape depends on the value of ``full_matrices`` and contain orthonormal rows (i.e., the rows are the right singular vectors and the array is the adjoint). If ``full_matrices`` is ``True``, the array must have shape ``(..., N, N)``. If ``full_matrices`` is ``False``, the array must have shape ``(..., K, N)`` where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x``. Must have the same data type as ``x``.
490513
"""
491514

492515
def svdvals(x: array, /) -> array:

0 commit comments

Comments
 (0)