Skip to content

Commit 8034b88

Browse files
committed
feat: add cumulative_prod specification
Ref: data-apis#598
1 parent c305b82 commit 8034b88

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

Diff for: spec/draft/API_specification/statistical_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Objects in API
1818
:toctree: generated
1919
:template: method.rst
2020

21+
cumulative_prod
2122
cumulative_sum
2223
max
2324
mean

Diff for: src/array_api_stubs/_draft/statistical_functions.py

+61-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
1+
__all__ = [
2+
"cumulative_sum",
3+
"cumulative_prod",
4+
"max",
5+
"mean",
6+
"min",
7+
"prod",
8+
"std",
9+
"sum",
10+
"var",
11+
]
212

313

414
from ._types import Optional, Tuple, Union, array, dtype
@@ -56,6 +66,56 @@ def cumulative_sum(
5666
"""
5767

5868

69+
def cumulative_prod(
70+
x: array,
71+
/,
72+
*,
73+
axis: Optional[int] = None,
74+
dtype: Optional[dtype] = None,
75+
include_initial: bool = False,
76+
) -> array:
77+
"""
78+
Calculates the cumulative product of elements in the input array ``x``.
79+
80+
Parameters
81+
----------
82+
x: array
83+
input array. Should have a numeric data type.
84+
axis: Optional[int]
85+
axis along which a cumulative product must be computed. If ``axis`` is negative, the function must determine the axis along which to compute a cumulative product by counting from the last dimension.
86+
87+
If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required.
88+
89+
dtype: Optional[dtype]
90+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
91+
92+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
93+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
94+
95+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
96+
97+
include_initial: bool
98+
boolean indicating whether to include the initial value as the first value in the output. By convention, the initial value must be the multiplicative identity (i.e., one). Default: ``False``.
99+
100+
Returns
101+
-------
102+
out: array
103+
an array containing the cumulative products. The returned array must have a data type as described by the ``dtype`` parameter above.
104+
105+
Let ``N`` be the size of the axis along which to compute the cumulative product. The returned array must have a shape determined according to the following rules:
106+
107+
- if ``include_initial`` is ``True``, the returned array must have the same shape as ``x``, except the size of the axis along which to compute the cumulative product must be ``N+1``.
108+
- if ``include_initial`` is ``False``, the returned array must have the same shape as ``x``.
109+
110+
Notes
111+
-----
112+
113+
**Special Cases**
114+
115+
For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.multiply`.
116+
"""
117+
118+
59119
def max(
60120
x: array,
61121
/,

0 commit comments

Comments
 (0)