Skip to content

Commit c492972

Browse files
authored
feat: add cumulative_prod specification
PR-URL: #793 Closes: #598
1 parent 5ffffee commit c492972

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

Diff for: spec/draft/API_specification/statistical_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Objects in API
1818
:toctree: generated
1919
:template: method.rst
2020

21+
cumulative_prod
2122
cumulative_sum
2223
max
2324
mean

Diff for: src/array_api_stubs/_draft/statistical_functions.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,71 @@
1-
__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
1+
__all__ = [
2+
"cumulative_sum",
3+
"cumulative_prod",
4+
"max",
5+
"mean",
6+
"min",
7+
"prod",
8+
"std",
9+
"sum",
10+
"var",
11+
]
212

313

414
from ._types import Optional, Tuple, Union, array, dtype
515

616

17+
def cumulative_prod(
18+
x: array,
19+
/,
20+
*,
21+
axis: Optional[int] = None,
22+
dtype: Optional[dtype] = None,
23+
include_initial: bool = False,
24+
) -> array:
25+
"""
26+
Calculates the cumulative product of elements in the input array ``x``.
27+
28+
Parameters
29+
----------
30+
x: array
31+
input array. Should have one or more dimensions (axes). Should have a numeric data type.
32+
axis: Optional[int]
33+
axis along which a cumulative product must be computed. If ``axis`` is negative, the function must determine the axis along which to compute a cumulative product by counting from the last dimension.
34+
35+
If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required.
36+
37+
dtype: Optional[dtype]
38+
data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases:
39+
40+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
41+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
42+
43+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``.
44+
45+
include_initial: bool
46+
boolean indicating whether to include the initial value as the first value in the output. By convention, the initial value must be the multiplicative identity (i.e., one). Default: ``False``.
47+
48+
Returns
49+
-------
50+
out: array
51+
an array containing the cumulative products. The returned array must have a data type as described by the ``dtype`` parameter above.
52+
53+
Let ``N`` be the size of the axis along which to compute the cumulative product. The returned array must have a shape determined according to the following rules:
54+
55+
- if ``include_initial`` is ``True``, the returned array must have the same shape as ``x``, except the size of the axis along which to compute the cumulative product must be ``N+1``.
56+
- if ``include_initial`` is ``False``, the returned array must have the same shape as ``x``.
57+
58+
Notes
59+
-----
60+
61+
- When ``x`` is a zero-dimensional array, behavior is unspecified and thus implementation-defined.
62+
63+
**Special Cases**
64+
65+
For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.multiply`.
66+
"""
67+
68+
769
def cumulative_sum(
870
x: array,
971
/,

0 commit comments

Comments
 (0)