Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Transformer class and bin_expand function #30

Merged
merged 3 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions src/protosym/core/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from typing import TYPE_CHECKING as _TYPE_CHECKING
from typing import TypeVar

from protosym.core.exceptions import NoEvaluationRuleError
from protosym.core.tree import forward_graph
from protosym.core.tree import TreeAtom
from protosym.core.tree import TreeExpr


if _TYPE_CHECKING:
from protosym.core.tree import TreeExpr
from protosym.core.atom import AnyValue as _AnyValue
from protosym.core.atom import AtomType

Expand Down Expand Up @@ -61,7 +62,17 @@ def add_opn(self, head: TreeExpr, func: OpN[_T]) -> None:
"""Add an evaluation rule for a particular head."""
self.operations[head] = (func, False)

def call(self, head: TreeExpr, argvals: Iterable[_T]) -> _T:
def eval_atom(self, atom: TreeAtom[_S]) -> _T:
"""Evaluate an atom."""
atom_value = atom.value
atom_func = self.atoms.get(atom_value.atom_type) # type: ignore
if atom_func is not None:
return atom_func(atom_value.value)
else:
msg = "No rule for AtomType: " + atom_value.atom_type.name
raise NoEvaluationRuleError(msg)

def eval_operation(self, head: TreeExpr, argvals: Iterable[_T]) -> _T:
"""Evaluate one function with some values."""
op_func, star_args = self.operations[head]
if star_args:
Expand All @@ -81,15 +92,13 @@ def eval_recursive(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
return values[expr]
elif isinstance(expr, TreeAtom):
# Convert an Atom to _T
value = expr.value
atom_func = self.atoms[value.atom_type]
return atom_func(value.value)
return self.eval_atom(expr)
else:
# Recursively evaluate children and then apply this operation.
head = expr.children[0]
children = expr.children[1:]
argvals = [self.eval_recursive(c, values) for c in children]
return self.call(head, argvals)
return self.eval_operation(head, argvals)

def eval_forward(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
"""Evaluate the expression using forward evaluation."""
Expand All @@ -103,15 +112,13 @@ def eval_forward(self, expr: TreeExpr, values: dict[TreeExpr, _T]) -> _T:
if value_get is not None:
value = value_get
else:
atom_value = atom.value # type:ignore
atom_func = self.atoms[atom_value.atom_type]
value = atom_func(atom_value.value)
value = self.eval_atom(atom) # type: ignore
stack.append(value)

# Run forward evaluation through the operations
for head, indices in graph.operations:
argvals = [stack[i] for i in indices]
stack.append(self.call(head, argvals))
stack.append(self.eval_operation(head, argvals))

# Now stack is the values of the topological sort of expr and stack[-1]
# is the value of expr.
Expand All @@ -124,3 +131,57 @@ def __call__(
if values is None:
values = {}
return self.evaluate(expr, values)


class Transformer(Evaluator[TreeExpr]):
"""Specialized Evaluator for TreeExpr -> TreeExpr operations.

Whereas :class:`Evaluator` is used to evaluate an expression into a
different type of object like ``float`` or ``str`` a :class:`Transformer`
is used to transform a :class:`TreeExpr` into a new :class:`TreeExpr`.

The difference between using ``Transformer`` and using
``Evaluator[TreeExpr]`` is that ``Transformer`` allows processing
operations that have no associated rules leaving the expression unmodified.

Examples
========

We first import the pieces and define some functions and symbols.

>>> from protosym.core.tree import funcs_symbols
>>> from protosym.core.evaluate import Evaluator, Transformer
>>> [f, g], [x, y] = funcs_symbols(['f', 'g'], ['x', 'y'])

Now make a :class:`Transformer` to replace ``f(...)`` with ``g(...)``.

>>> f2g = Transformer()
>>> f2g.add_opn(f, lambda args: g(*args))
>>> expr = f(g(x, f(y)), y)
>>> print(expr)
f(g(x, f(y)), y)
>>> print(f2g(expr))
g(g(x, g(y)), y)

By contrast with ``Evaluator[TreeExpr]`` the above would fail because no
rule has been defined for the head ``g`` or for ``Symbol`` (the
:class:`AtomType` of ``x`` and ``y``).

>>> f2g_eval = Evaluator[TreeExpr]()
>>> f2g_eval.add_opn(f, lambda args: g(*args))
>>> f2g_eval(expr)
Traceback (most recent call last):
...
protosym.core.exceptions.NoEvaluationRuleError: No rule for AtomType: Symbol
"""

def eval_atom(self, atom: TreeAtom[_S]) -> TreeExpr:
"""Return the atom as is."""
return atom

def eval_operation(self, head: TreeExpr, argvals: Iterable[TreeExpr]) -> TreeExpr:
"""Return unevaluated operation if no rule supplied."""
if head not in self.operations:
return head(*argvals)
else:
return super().eval_operation(head, argvals)
13 changes: 13 additions & 0 deletions src/protosym/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Module for all protosym exceptions."""


class ProtoSymError(Exception):
"""Superclass for all protosym exceptions."""

pass


class NoEvaluationRuleError(ProtoSymError):
"""Raised when an :class:`Evaluator` has no rule for an expression."""

pass
18 changes: 18 additions & 0 deletions src/protosym/simplecas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import math
from functools import reduce
from functools import wraps
from typing import Any
from typing import Callable
Expand All @@ -16,6 +17,7 @@

from protosym.core.atom import AtomType
from protosym.core.evaluate import Evaluator
from protosym.core.evaluate import Transformer
from protosym.core.tree import forward_graph
from protosym.core.tree import topological_sort
from protosym.core.tree import TreeAtom
Expand Down Expand Up @@ -242,6 +244,18 @@ def diff(self, sym: Expr, ntimes: int = 1) -> Expr:
deriv_rep = _diff_forward(deriv_rep, sym_rep)
return Expr(deriv_rep)

def bin_expand(self) -> Expr:
"""Expand associative operators to binary operations.

>>> from protosym.simplecas import Add
>>> expr = Add(1, 2, 3, 4)
>>> expr
(1 + 2 + 3 + 4)
>>> expr.bin_expand()
(((1 + 2) + 3) + 4)
"""
return Expr(_bin_expand(self.rep))


# Avoid importing SymPy if possible.
_eval_to_sympy: Evaluator[Any] | None = None
Expand Down Expand Up @@ -345,6 +359,10 @@ def from_sympy(expr: Any) -> Expr:
eval_latex.add_opn(Add.rep, lambda args: f'({" + ".join(args)})')
eval_latex.add_opn(Mul.rep, lambda args: "(%s)" % r" \times ".join(args))

_bin_expand = Transformer()
_bin_expand.add_opn(Add.rep, lambda args: reduce(Add.rep, args))
_bin_expand.add_opn(Mul.rep, lambda args: reduce(Mul.rep, args))


def _op1(a: int | str) -> int:
return 1
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import math
from typing import Callable

from pytest import raises

from protosym.core.atom import AtomType
from protosym.core.evaluate import Evaluator
from protosym.core.evaluate import Transformer
from protosym.core.exceptions import NoEvaluationRuleError
from protosym.core.tree import funcs_symbols
from protosym.core.tree import TreeAtom
from protosym.core.tree import TreeExpr

Expand Down Expand Up @@ -49,3 +54,18 @@ def test_Evaluator() -> None:
for expr, vals, expected in test_cases:
for func in eval_funcs:
assert func(expr, vals) == expected


def test_Transformer() -> None:
"""Test defining and using a Transformer."""
[f, g], [x, y] = funcs_symbols(["f", "g"], ["x", "y"])

f2g = Transformer()
f2g.add_opn(f, lambda args: g(*args))
expr = f(g(x, f(y)), y)
assert f2g(expr) == g(g(x, g(y)), y)

# With Evaluator the above would fail without rules for Symbol and g:
f2g_eval = Evaluator[TreeExpr]()
f2g_eval.add_opn(f, lambda args: g(*args))
raises(NoEvaluationRuleError, lambda: f2g_eval(expr))
35 changes: 34 additions & 1 deletion tests/test_simplecas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from pytest import raises
from pytest import skip

Expand All @@ -7,6 +8,7 @@
from protosym.simplecas import ExprAtomType
from protosym.simplecas import expressify
from protosym.simplecas import ExpressifyError
from protosym.simplecas import Function
from protosym.simplecas import Integer
from protosym.simplecas import Mul
from protosym.simplecas import negone
Expand All @@ -20,6 +22,7 @@


two = Integer(2)
f = Function("f")


def test_simplecas_types() -> None:
Expand Down Expand Up @@ -105,6 +108,7 @@ def test_simplecas_repr() -> None:
assert str(Integer) == "Integer"
assert str(x) == "x"
assert str(y) == "y"
assert str(f) == "f"
assert str(sin) == "sin"
assert str(sin(cos(x))) == "sin(cos(x))"
assert str(x + y) == "(x + y)"
Expand All @@ -114,6 +118,21 @@ def test_simplecas_repr() -> None:
assert str(x + x + x) == "((x + x) + x)"


@pytest.mark.xfail
def test_simplecas_repr_xfail() -> None:
"""Test printing an undefined function."""
#
# This fails because Evaluator expects every function to be defined. There
# is a printing rule for Function but that only handles an uncalled
# function like f rather than f(x). There needs to be a way to give a
# default rule to an Evaluator for handling the cases where there is no
# operation rule defined for the head.
#
assert str(f(x)) == "f(x)"
assert repr(f(x)) == "f(x)"
assert f(x).eval_latex() == "f(x)"


def test_simplecas_latex() -> None:
"""Test basic operations with simplecas."""
assert x.eval_latex() == r"x"
Expand All @@ -135,7 +154,7 @@ def test_simplecas_repr_latex() -> None:
def test_simplecas_to_sympy() -> None:
"""Test converting a simplecas expression to a SymPy expression."""
try:
sympy = __import__("sympy")
import sympy
except ImportError:
skip("SymPy not installed")

Expand Down Expand Up @@ -230,3 +249,17 @@ def test_simplecas_differentation() -> None:
assert (sin(x) + cos(x)).diff(x) == cos(x) + -1 * sin(x)
assert (sin(x) ** 2).diff(x) == 2 * sin(x) ** Add(2, -1) * cos(x)
assert (x * sin(x)).diff(x) == 1 * sin(x) + x * cos(x)


def test_simplecas_bin_expand() -> None:
"""Test Expr.bin_expand()."""
expr1 = Add(1, 2, 3, 4)
assert expr1.bin_expand() == Add(Add(Add(1, 2), 3), 4)
assert str(expr1) == "(1 + 2 + 3 + 4)"
assert str(expr1.bin_expand()) == "(((1 + 2) + 3) + 4)"

expr2 = Add(x, y, Mul(x, y, 1, f(x)))
assert expr2.bin_expand() == Add(Add(x, y), Mul(Mul(Mul(x, y), 1), f(x)))
# Fails because f(x) cannot be printed:
# assert str(expr2) == "(x + y + x*y*1*f(x))"
# assert str(expr2.bin_expand()) == "((x + y) + ((x*y)*1)*f(x))"