Skip to content

Commit 753331a

Browse files
author
Jake Simones
committed
Add unfold
1 parent f34a534 commit 753331a

File tree

4 files changed

+73
-2
lines changed

4 files changed

+73
-2
lines changed

doc/source/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ Functoolz
7171
pipe
7272
thread_first
7373
thread_last
74+
unfold
75+
unfold_
7476

7577
Dicttoolz
7678
---------

toolz/curried/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
update_in = toolz.curry(toolz.update_in)
9999
valfilter = toolz.curry(toolz.valfilter)
100100
valmap = toolz.curry(toolz.valmap)
101+
unfold = toolz.curry(toolz.unfold)
102+
unfold_ = toolz.curry(toolz.unfold_)
101103

102104
del exceptions
103105
del toolz

toolz/functoolz.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
__all__ = ('identity', 'apply', 'thread_first', 'thread_last', 'memoize',
1414
'compose', 'compose_left', 'pipe', 'complement', 'juxt', 'do',
15-
'curry', 'flip', 'excepts')
15+
'curry', 'flip', 'excepts', 'unfold', 'unfold_')
1616

1717

1818
def identity(x):
@@ -825,6 +825,58 @@ def __name__(self):
825825
return 'excepting'
826826

827827

828+
def unfold(func, x):
829+
""" Generate values from a seed value
830+
831+
Each iteration, the generator yields ``func(x)[0]`` and evaluates
832+
``func(x)[1]`` to determine the next ``x`` value. Iteration proceeds as
833+
long as ``func(x)`` is not None.
834+
835+
>>> def doubles(x):
836+
... if x > 10:
837+
... return None
838+
... else:
839+
... return (x * 2, x + 1)
840+
...
841+
>>> list(unfold(doubles, 1))
842+
[2, 4, 6, 9, 10, 12, 14, 16, 18, 20]
843+
844+
If ``x`` has type ``A`` and the generator yields values of type ``B``,
845+
then ``func`` has type ``Callable[[A], Optional[Tuple[B, A]]]``.
846+
847+
"""
848+
while True:
849+
t = func(x)
850+
if t is None:
851+
break
852+
else:
853+
yield t[0]
854+
x = t[1]
855+
856+
857+
def unfold_(predicate, func, succ, x):
858+
""" Alternative formulation of unfold
859+
860+
Each iteration, the generator yields ``func(x)`` and evaluates
861+
``succ(x)`` to determine the next ``x`` value. Iteration proceeds as long
862+
as ``predicate(x)`` is True.
863+
864+
>>> lte10 = lambda x: x <= 10
865+
>>> double = lambda x: x * 2
866+
>>> inc = lambda x: x + 1
867+
>>> list(unfold(lte10, double, inc, 1))
868+
[2, 4, 6, 9, 10, 12, 14, 16, 18, 20]
869+
870+
If ``x`` has type ``A`` and the generator yields values of type ``B``,
871+
then ``predicate`` has type ``Callable[[A], bool]``, ``func`` has type
872+
``Callable[[A], B]``, and ``succ`` has type ``Callable[[A], A]``.
873+
874+
"""
875+
while predicate(x):
876+
yield func(x)
877+
x = succ(x)
878+
879+
828880
if PY3: # pragma: py2 no cover
829881
def _check_sigspec(sigspec, func, builtin_func, *builtin_args):
830882
if sigspec is None:

toolz/tests/test_functoolz.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import toolz
33
from toolz.functoolz import (thread_first, thread_last, memoize, curry,
44
compose, compose_left, pipe, complement, do, juxt,
5-
flip, excepts, apply)
5+
flip, excepts, apply, unfold, unfold_)
66
from toolz.compatibility import PY3
77
from operator import add, mul, itemgetter
88
from toolz.utils import raises
@@ -796,3 +796,18 @@ def raise_(a):
796796
excepting = excepts(object(), object(), object())
797797
assert excepting.__name__ == 'excepting'
798798
assert excepting.__doc__ == excepts.__doc__
799+
800+
801+
def test_unfold():
802+
expected = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
803+
804+
def doubles(x):
805+
if x > 10:
806+
return None
807+
else:
808+
return (x * 2, x + 1)
809+
assert list(unfold(doubles, 1)) == expected
810+
811+
def lte10(x):
812+
return x <= 10
813+
assert list(unfold_(lte10, double, inc, 1)) == expected

0 commit comments

Comments
 (0)