Skip to content

Commit 61694d0

Browse files
authored
Add complex number support to linalg.qr (#548)
1 parent 775dc7a commit 61694d0

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

Diff for: spec/API_specification/array_api/linalg.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -328,16 +328,42 @@ def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
328328
"""
329329

330330
def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tuple[array, array]:
331-
"""
332-
Returns the qr decomposition x = QR of a full column rank matrix (or a stack of matrices), where ``Q`` is an orthonormal matrix (or a stack of matrices) and ``R`` is an upper-triangular matrix (or a stack of matrices).
331+
r"""
332+
Returns the QR decomposition of a full column rank matrix (or a stack of matrices).
333+
334+
If ``x`` is real-valued, let :math:`\mathbb{K}` be the set of real numbers :math:`\mathbb{R}`, and, if ``x`` is complex-valued, let :math:`\mathbb{K}` be the set of complex numbers :math:`\mathbb{C}`.
335+
336+
The **complete QR decomposition** of a matrix :math:`x \in\ \mathbb{K}^{n \times n}` is defined as
337+
338+
.. math::
339+
x = QR
340+
341+
where :math:`Q \in\ \mathbb{K}^{m \times m}` is orthogonal when ``x`` is real-valued and unitary when ``x`` is complex-valued and where :math:`R \in\ \mathbb{K}^{m \times n}` is an upper triangular matrix with real diagonal (even when ``x`` is complex-valued).
342+
343+
When :math:`m \gt n` (tall matrix), as :math:`R` is upper triangular, the last :math:`m - n` rows are zero. In this case, the last :math:`m - n` columns of :math:`Q` can be dropped to form the **reduced QR decomposition**.
344+
345+
.. math::
346+
x = QR
347+
348+
where :math:`Q \in\ \mathbb{K}^{m \times n}` and :math:`R \in\ \mathbb{K}^{n \times n}`.
349+
350+
The reduced QR decomposition equals with the complete QR decomposition when :math:`n \qeq m` (wide matrix).
351+
352+
When ``x`` is a stack of matrices, the function must compute the QR decomposition for each matrix in the stack.
333353
334354
.. note::
335355
Whether an array library explicitly checks whether an input array is a full column rank matrix (or a stack of full column rank matrices) is implementation-defined.
336356
357+
.. warning::
358+
The elements in the diagonal of :math:`R` are not necessarily positive. Accordingly, the returned QR decomposition is only unique up to the sign of the diagonal of :math:`R`, and different libraries or inputs on different devices may produce different valid decompositions.
359+
360+
.. warning::
361+
The QR decomposition is only well-defined if the first ``k = min(m,n)`` columns of every matrix in ``x`` are linearly independent.
362+
337363
Parameters
338364
----------
339365
x: array
340-
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices of rank ``N``. Should have a real-valued floating-point data type.
366+
input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices of rank ``N``. Should have a floating-point data type.
341367
mode: Literal['reduced', 'complete']
342368
decomposition mode. Should be one of the following modes:
343369
@@ -354,7 +380,7 @@ def qr(x: array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> Tupl
354380
- first element must have the field name ``Q`` and must be an array whose shape depends on the value of ``mode`` and contain matrices with orthonormal columns. If ``mode`` is ``'complete'``, the array must have shape ``(..., M, M)``. If ``mode`` is ``'reduced'``, the array must have shape ``(..., M, K)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same size as those of the input array ``x``.
355381
- second element must have the field name ``R`` and must be an array whose shape depends on the value of ``mode`` and contain upper-triangular matrices. If ``mode`` is ``'complete'``, the array must have shape ``(..., M, N)``. If ``mode`` is ``'reduced'``, the array must have shape ``(..., K, N)``, where ``K = min(M, N)``. The first ``x.ndim-2`` dimensions must have the same size as those of the input ``x``.
356382
357-
Each returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
383+
Each returned array must have a floating-point data type determined by :ref:`type-promotion`.
358384
"""
359385

360386
def slogdet(x: array, /) -> Tuple[array, array]:

0 commit comments

Comments
 (0)