Skip to content

Commit 01bf821

Browse files
Merge pull request #30 from oscarbenjamin/pr_binexpand
feat: Add Transformer class and bin_expand function
2 parents af470f1 + dcaa75e commit 01bf821

File tree

5 files changed

+156
-11
lines changed

5 files changed

+156
-11
lines changed

src/protosym/core/evaluate.py

+71-10
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from typing import TYPE_CHECKING as _TYPE_CHECKING
88
from typing import TypeVar
99

10+
from protosym.core.exceptions import NoEvaluationRuleError
1011
from protosym.core.tree import forward_graph
1112
from protosym.core.tree import TreeAtom
13+
from protosym.core.tree import TreeExpr
1214

1315

1416
if _TYPE_CHECKING:
15-
from protosym.core.tree import TreeExpr
1617
from protosym.core.atom import AnyValue as _AnyValue
1718
from protosym.core.atom import AtomType
1819

@@ -61,7 +62,17 @@ def add_opn(self, head: TreeExpr, func: OpN[_T]) -> None:
6162
"""Add an evaluation rule for a particular head."""
6263
self.operations[head] = (func, False)
6364

64-
def call(self, head: TreeExpr, argvals: Iterable[_T]) -> _T:
65+
def eval_atom(self, atom: TreeAtom[_S]) -> _T:
66+
"""Evaluate an atom."""
67+
atom_value = atom.value
68+
atom_func = self.atoms.get(atom_value.atom_type) # type: ignore
69+
if atom_func is not None:
70+
return atom_func(atom_value.value)
71+
else:
72+
msg = "No rule for AtomType: " + atom_value.atom_type.name
73+
raise NoEvaluationRuleError(msg)
74+
75+
def eval_operation(self, head: TreeExpr, argvals: Iterable[_T]) -> _T:
6576
"""Evaluate one function with some values."""
6677
op_func, star_args = self.operations[head]
6778
if star_args:
@@ -81,15 +92,13 @@ def eval_recursive(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
8192
return values[expr]
8293
elif isinstance(expr, TreeAtom):
8394
# Convert an Atom to _T
84-
value = expr.value
85-
atom_func = self.atoms[value.atom_type]
86-
return atom_func(value.value)
95+
return self.eval_atom(expr)
8796
else:
8897
# Recursively evaluate children and then apply this operation.
8998
head = expr.children[0]
9099
children = expr.children[1:]
91100
argvals = [self.eval_recursive(c, values) for c in children]
92-
return self.call(head, argvals)
101+
return self.eval_operation(head, argvals)
93102

94103
def eval_forward(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
95104
"""Evaluate the expression using forward evaluation."""
@@ -103,15 +112,13 @@ def eval_forward(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
103112
if value_get is not None:
104113
value = value_get
105114
else:
106-
atom_value = atom.value # type:ignore
107-
atom_func = self.atoms[atom_value.atom_type]
108-
value = atom_func(atom_value.value)
115+
value = self.eval_atom(atom) # type: ignore
109116
stack.append(value)
110117

111118
# Run forward evaluation through the operations
112119
for head, indices in graph.operations:
113120
argvals = [stack[i] for i in indices]
114-
stack.append(self.call(head, argvals))
121+
stack.append(self.eval_operation(head, argvals))
115122

116123
# Now stack is the values of the topological sort of expr and stack[-1]
117124
# is the value of expr.
@@ -124,3 +131,57 @@ def __call__(
124131
if values is None:
125132
values = {}
126133
return self.evaluate(expr, values)
134+
135+
136+
class Transformer(Evaluator[TreeExpr]):
137+
"""Specialized Evaluator for TreeExpr -> TreeExpr operations.
138+
139+
Whereas :class:`Evaluator` is used to evaluate an expression into a
140+
different type of object like ``float`` or ``str`` a :class:`Transformer`
141+
is used to transform a :class:`TreeExpr` into a new :class:`TreeExpr`.
142+
143+
The difference between using ``Transformer`` and using
144+
``Evaluator[TreeExpr]`` is that ``Transformer`` allows processing
145+
operations that have no associated rules leaving the expression unmodified.
146+
147+
Examples
148+
========
149+
150+
We first import the pieces and define some functions and symbols.
151+
152+
>>> from protosym.core.tree import funcs_symbols
153+
>>> from protosym.core.evaluate import Evaluator, Transformer
154+
>>> [f, g], [x, y] = funcs_symbols(['f', 'g'], ['x', 'y'])
155+
156+
Now make a :class:`Transformer` to replace ``f(...)`` with ``g(...)``.
157+
158+
>>> f2g = Transformer()
159+
>>> f2g.add_opn(f, lambda args: g(*args))
160+
>>> expr = f(g(x, f(y)), y)
161+
>>> print(expr)
162+
f(g(x, f(y)), y)
163+
>>> print(f2g(expr))
164+
g(g(x, g(y)), y)
165+
166+
By contrast with ``Evaluator[TreeExpr]`` the above would fail because no
167+
rule has been defined for the head ``g`` or for ``Symbol`` (the
168+
:class:`AtomType` of ``x`` and ``y``).
169+
170+
>>> f2g_eval = Evaluator[TreeExpr]()
171+
>>> f2g_eval.add_opn(f, lambda args: g(*args))
172+
>>> f2g_eval(expr)
173+
Traceback (most recent call last):
174+
...
175+
protosym.core.exceptions.NoEvaluationRuleError: No rule for AtomType: Symbol
176+
"""
177+
178+
def eval_atom(self, atom: TreeAtom[_S]) -> TreeExpr:
179+
"""Return the atom as is."""
180+
return atom
181+
182+
def eval_operation(self, head: TreeExpr, argvals: Iterable[TreeExpr]) -> TreeExpr:
183+
"""Return unevaluated operation if no rule supplied."""
184+
if head not in self.operations:
185+
return head(*argvals)
186+
else:
187+
return super().eval_operation(head, argvals)

src/protosym/core/exceptions.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Module for all protosym exceptions."""
2+
3+
4+
class ProtoSymError(Exception):
5+
"""Superclass for all protosym exceptions."""
6+
7+
pass
8+
9+
10+
class NoEvaluationRuleError(ProtoSymError):
11+
"""Raised when an :class:`Evaluator` has no rule for an expression."""
12+
13+
pass

src/protosym/simplecas.py

+18
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import math
5+
from functools import reduce
56
from functools import wraps
67
from typing import Any
78
from typing import Callable
@@ -16,6 +17,7 @@
1617

1718
from protosym.core.atom import AtomType
1819
from protosym.core.evaluate import Evaluator
20+
from protosym.core.evaluate import Transformer
1921
from protosym.core.tree import forward_graph
2022
from protosym.core.tree import topological_sort
2123
from protosym.core.tree import TreeAtom
@@ -242,6 +244,18 @@ def diff(self, sym: Expr, ntimes: int = 1) -> Expr:
242244
deriv_rep = _diff_forward(deriv_rep, sym_rep)
243245
return Expr(deriv_rep)
244246

247+
def bin_expand(self) -> Expr:
248+
"""Expand associative operators to binary operations.
249+
250+
>>> from protosym.simplecas import Add
251+
>>> expr = Add(1, 2, 3, 4)
252+
>>> expr
253+
(1 + 2 + 3 + 4)
254+
>>> expr.bin_expand()
255+
(((1 + 2) + 3) + 4)
256+
"""
257+
return Expr(_bin_expand(self.rep))
258+
245259

246260
# Avoid importing SymPy if possible.
247261
_eval_to_sympy: Evaluator[Any] | None = None
@@ -345,6 +359,10 @@ def from_sympy(expr: Any) -> Expr:
345359
eval_latex.add_opn(Add.rep, lambda args: f'({" + ".join(args)})')
346360
eval_latex.add_opn(Mul.rep, lambda args: "(%s)" % r" \times ".join(args))
347361

362+
_bin_expand = Transformer()
363+
_bin_expand.add_opn(Add.rep, lambda args: reduce(Add.rep, args))
364+
_bin_expand.add_opn(Mul.rep, lambda args: reduce(Mul.rep, args))
365+
348366

349367
def _op1(a: int | str) -> int:
350368
return 1

tests/core/test_evaluate.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import math
22
from typing import Callable
33

4+
from pytest import raises
5+
46
from protosym.core.atom import AtomType
57
from protosym.core.evaluate import Evaluator
8+
from protosym.core.evaluate import Transformer
9+
from protosym.core.exceptions import NoEvaluationRuleError
10+
from protosym.core.tree import funcs_symbols
611
from protosym.core.tree import TreeAtom
712
from protosym.core.tree import TreeExpr
813

@@ -49,3 +54,18 @@ def test_Evaluator() -> None:
4954
for expr, vals, expected in test_cases:
5055
for func in eval_funcs:
5156
assert func(expr, vals) == expected
57+
58+
59+
def test_Transformer() -> None:
60+
"""Test defining and using a Transformer."""
61+
[f, g], [x, y] = funcs_symbols(["f", "g"], ["x", "y"])
62+
63+
f2g = Transformer()
64+
f2g.add_opn(f, lambda args: g(*args))
65+
expr = f(g(x, f(y)), y)
66+
assert f2g(expr) == g(g(x, g(y)), y)
67+
68+
# With Evaluator the above would fail without rules for Symbol and g:
69+
f2g_eval = Evaluator[TreeExpr]()
70+
f2g_eval.add_opn(f, lambda args: g(*args))
71+
raises(NoEvaluationRuleError, lambda: f2g_eval(expr))

tests/test_simplecas.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from pytest import raises
23
from pytest import skip
34

@@ -7,6 +8,7 @@
78
from protosym.simplecas import ExprAtomType
89
from protosym.simplecas import expressify
910
from protosym.simplecas import ExpressifyError
11+
from protosym.simplecas import Function
1012
from protosym.simplecas import Integer
1113
from protosym.simplecas import Mul
1214
from protosym.simplecas import negone
@@ -20,6 +22,7 @@
2022

2123

2224
two = Integer(2)
25+
f = Function("f")
2326

2427

2528
def test_simplecas_types() -> None:
@@ -105,6 +108,7 @@ def test_simplecas_repr() -> None:
105108
assert str(Integer) == "Integer"
106109
assert str(x) == "x"
107110
assert str(y) == "y"
111+
assert str(f) == "f"
108112
assert str(sin) == "sin"
109113
assert str(sin(cos(x))) == "sin(cos(x))"
110114
assert str(x + y) == "(x + y)"
@@ -114,6 +118,21 @@ def test_simplecas_repr() -> None:
114118
assert str(x + x + x) == "((x + x) + x)"
115119

116120

121+
@pytest.mark.xfail
122+
def test_simplecas_repr_xfail() -> None:
123+
"""Test printing an undefined function."""
124+
#
125+
# This fails because Evaluator expects every function to be defined. There
126+
# is a printing rule for Function but that only handles an uncalled
127+
# function like f rather than f(x). There needs to be a way to give a
128+
# default rule to an Evaluator for handling the cases where there is no
129+
# operation rule defined for the head.
130+
#
131+
assert str(f(x)) == "f(x)"
132+
assert repr(f(x)) == "f(x)"
133+
assert f(x).eval_latex() == "f(x)"
134+
135+
117136
def test_simplecas_latex() -> None:
118137
"""Test basic operations with simplecas."""
119138
assert x.eval_latex() == r"x"
@@ -135,7 +154,7 @@ def test_simplecas_repr_latex() -> None:
135154
def test_simplecas_to_sympy() -> None:
136155
"""Test converting a simplecas expression to a SymPy expression."""
137156
try:
138-
sympy = __import__("sympy")
157+
import sympy
139158
except ImportError:
140159
skip("SymPy not installed")
141160

@@ -230,3 +249,17 @@ def test_simplecas_differentation() -> None:
230249
assert (sin(x) + cos(x)).diff(x) == cos(x) + -1 * sin(x)
231250
assert (sin(x) ** 2).diff(x) == 2 * sin(x) ** Add(2, -1) * cos(x)
232251
assert (x * sin(x)).diff(x) == 1 * sin(x) + x * cos(x)
252+
253+
254+
def test_simplecas_bin_expand() -> None:
255+
"""Test Expr.bin_expand()."""
256+
expr1 = Add(1, 2, 3, 4)
257+
assert expr1.bin_expand() == Add(Add(Add(1, 2), 3), 4)
258+
assert str(expr1) == "(1 + 2 + 3 + 4)"
259+
assert str(expr1.bin_expand()) == "(((1 + 2) + 3) + 4)"
260+
261+
expr2 = Add(x, y, Mul(x, y, 1, f(x)))
262+
assert expr2.bin_expand() == Add(Add(x, y), Mul(Mul(Mul(x, y), 1), f(x)))
263+
# Fails because f(x) cannot be printed:
264+
# assert str(expr2) == "(x + y + x*y*1*f(x))"
265+
# assert str(expr2.bin_expand()) == "((x + y) + ((x*y)*1)*f(x))"

0 commit comments

Comments
 (0)