Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sym as a superclass for Expr #43

Merged
merged 8 commits into from
Mar 17, 2023
Merged

Add Sym as a superclass for Expr #43

merged 8 commits into from
Mar 17, 2023

Conversation

oscarbenjamin
Copy link
Owner

This PR adds a class Sym in the core to be the superclass for user-facing classes like Expr. This makes it possible to define a nicer interface that can be used to specify the rules of Evaluators so instead of

eval_f64.add_op(Pow, math.pow)

we can have something that looks more like pattern matching:

eval_f64[a**b] = f64_pow(a, b)

Currently making this work in a type-safe way needs lots of function wrappers and is a bit complicated. I had been working on a way to represent Python code symbolically that could have been used to make that part nicer but I'm going to postpone that for now because this gets the syntax to where I want it to be for specifying evaluation rules.

The idea is to have one module somewhere that defines lots of type-specific operations like:

# f64.py
f64_pow = PyOp2[float](math.pow)
f64_sum = PyOpN[float](math.fsum)
...

And then these functions are reusable and can be used as evaluation rules in a symbolic-looking way:

eval_f64 = Expr.new_evaluator()
eval_f64[a**b] = f64_pow(a, b)
eval_f64[Add(star(a))] = f64_sum(a)

Then if the front-end uses things like f64_pow from the core a lower-level implementation in C or Rust could internally provide its own special versions of these and optimise them internally rather than using the Python level callables. Then the frontend code for simplecas can look more like a declarative DSL rather than having lots of Python callables and imperative code. This is the main goal that I'm striving for.

The need for the Sym class to be in core is so that it can define the common interface .rep that would be needed for implementing things in the core while making them directly usable with something like Expr. I intend to move the differentiation code into a generic Differentiator class in the core which is also made possible by this.

Making this work for Evaluator has been quite complex because of the interaction between the symbolic and lower-level types (the internal values of the atoms). In the case of symbolic rewrites it should be a lot easier but I would intend to have the same sort of syntax e.g.:

exp2trig = Rewriter()
exp2trig[exp(a)] = cosh(a) + sinh(a)
exp2trig[exp(I*a)] = cos(a) + I*sin(a)
...

There it is a lot easier though because all the types are simple.

) -> None:
...

def __setitem__( # noqa: C901
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flake8 complains that this function is too complicated which is probably true.


# e.g. eval_f64[cos(a)] = f64_cos(a)
@overload
def __setitem__( # noqa: D105
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed because of a flake8 bug.

Comment on lines 25 to 112
def test_Sym() -> None:
"""Test defining a simple subclass of Sym."""

class Expr(Sym):
def __repr__(self) -> str:
return to_str(self)

def __call__(self, *args: Expr) -> Expr:
args_rep = [arg.rep for arg in args]
return Expr(self.rep(*args_rep))

Integer = Expr.new_atom("Integer", int)
Symbol = Expr.new_atom("Symbol", str)
Function = Expr.new_atom("Function", str)
cos = Function("cos")
sin = Function("sin")
Add = Function("Add")
one = Integer(1)
x = Symbol("x")

assert str(Integer) == repr(Integer) == "Integer"

raises(TypeError, lambda: Expr(1)) # type:ignore
assert Expr(one.rep) is one

assert type(Integer) is SymAtomType
assert type(Function) is SymAtomType
assert type(cos) is Expr
assert type(Add) is Expr
assert type(cos.rep) is TreeAtom

a = Expr.new_wild("a")
b = Expr.new_wild("b")
assert type(a) == type(b) == Expr

to_str = Expr.new_evaluator("to_str", str)
to_str[Integer[a]] = PyFunc1[int, str](str)(a)
to_str[AtomRule[a]] = AtomFunc(str)(a)
to_str[cos(a)] = PyOp1(lambda s: f"cos({s})")(a)
to_str[Add(star(a))] = PyOpN(" + ".join)(a)
to_str[HeadRule(a, b)] = HeadOp(lambda f, a: f"{f}({', '.join(a)})")(a, b)

assert to_str(cos(one)) == "cos(1)"
assert to_str(Add(one, one, one)) == "1 + 1 + 1"

# Test the generic rules
assert to_str(sin) == "sin"
assert to_str(sin(one)) == "sin(1)"

assert type(to_str) == SymEvaluator
assert repr(to_str) == repr(to_str) == "to_str"

eval_f64 = Expr.new_evaluator("eval_f64", float)
eval_f64[Integer[a]] = PyFunc1[int, float](float)(a)
eval_f64[cos(a)] = PyOp1(math.cos)(a)
eval_f64[Add(a, b)] = PyOp2[float](lambda a, b: a + b)(a, b)

assert eval_f64(cos(one)) == approx(0.5403023058681398)
assert eval_f64(Add(one, one)) == 2.0
assert eval_f64(Add(x, one), {x: -1.0}) == 0.0

s_one = Sym.new_atom("Integer", int)(1)
assert str(s_one) == "1"
assert repr(s_one) == "Sym(TreeAtom(Integer(1)))"

bad_examples = [
(Integer[a], PyOp1(math.cos)(a)),
(cos(a), PyFunc1(math.cos)(a)),
(AtomRule[a], PyOp1(math.cos)(a)),
(HeadRule(a, b), PyOp2(math.atan2)(a, b)),
(Integer[a], PyFunc1(int)),
(Integer[a], PyFunc1(int)(b)),
(cos(a), PyOp2(math.atan2)(a, b)),
(AtomRule[a], PyOp2(math.atan2)(a, b)),
(HeadRule(a, b), PyOp1(math.cos)(a)),
(Add(a), PyOpN[float](sum)(a)),
]

def set_bad_rule(k: Any, v: Any) -> None:
eval_f64[k] = v

for key, value in bad_examples:
raises(BadRuleError, lambda: set_bad_rule(key, value))

def set_bad_op() -> None:
eval_f64.add_op(cos, lambda x: x) # type:ignore

raises(BadRuleError, set_bad_op)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of test logic inside a single test function. If there's a failure early on then you won't know about other potential failures further down until the original failure is addressed. Although I can appreciate that creating Expr and Integer/Symbol/Function are testing the codebase themselves so it's nontrivial to refactor this into smaller chunks.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good point. Probably best is to make a reusable helper function that can set up some of the pieces.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've broken this up into smaller functions. I had to move Expr to top-level so the type hints would work.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great refactor, thanks.

Comment on lines +79 to +80
This class should not be used directly but rather subclassed to make a
user-facing symbolic expression type:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the case could we have it subclass abc.ABC from the standard library to enforce this behaviour?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. I seem to remember that the ABCs are slow or that they slow down isinstance testing.

Right now it's just difficult to use Sym directly because it doesn't provide the methods that you would want. The main missing method is __call__ but I didn't want to export that to all potential subclasses so I didn't add it. It already has too many attributes and methods now.

I've never really used ABC so I'm not sure exactly what the benefits are. Does it just mean that you can't create instances?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also add something to make Sym easier to use although I'm not sure what the usecase would be for using it directly. It might be useful for Sym (and all subclasses) to accept a tuple as an argument like:

expr = Sym((cos, one))

That would at least make it a bit easier to create a Sym instance making it potentially more useful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never really used ABC so I'm not sure exactly what the benefits are. Does it just mean that you can't create instances?

It lets you specify which methods a concrete class must implement for it to be able to create instances. For example, if you wanted to ensure that all subclasses of Sym implement their own __init__ method then you could do something like

from abc import ABC, abstractmethod

class Sym(ABC):

    @abstractmethod
    def __init__(self):
        pass

Then if a user tries to create an instance of Sym directly they'll get an error: TypeError: Can't instantiate abstract class Sym with abstract method __init__

Similarly, if a user tries to implement a concrete subclass of Sym, e.g. Expr, without overriding the __init__ method : TypeError: Can't instantiate abstract class Expr with abstract method __init__

But this is only useful if there are specific methods/properties that you want to enforce have to be overridden. If there aren't and you just want to ensure that Sym is never instantiated then you'd likely need to do something like

class Sym:

    def __new__(cos, *args, **kwargs):
        if cls is BaseClass:
            raise TypeError(f"'{cls.__name__}' can't be instantiated directly")
        return object.__new__(cls, *args, **kwargs)

But in this case, just documenting that Sym is intended to be subclassed, as you've done, is probably the best approach.

I also don't know whether there's any performance penalty associated with having the base class inherit from abc.ABC.

Comment on lines +343 to +344
def __call__(self: T_op, args: T_sym) -> WildCall[T_sym, T_op]:
return WildCall(self, args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should args not be type hinted as Sequence[T_sym]? And the return WildCall(self, *args)?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it shouldn't. This is quite confusing with all these classes and I would like to simplify it. I thought I had a better approach but gave up (postponed) because it was taking too long and I thought I was overthinking things.

Here the rule is:

f64_add = PyOpN[float](math.fsum)
eval_f64[Add(star(a))] = f64_add(a)

The class we are looking at here is PyOpN so f64_add is an instance which wraps the function fsum that takes a single iterable argument. When f64_add itself is called it will be called with a single argument a which is a wild symbol representing the single iterable argument to fsum. So the argument is T_sym (meaning a). In the WildCall that argument a is like a placeholder and when the actual call to fsum happens it will be called with Iterable[float] in place of a.

It all looks very confusing in the internals but hopefully in the external part the star(a) captures the fact that we're mapping from Add(*args) to fsum(args):

eval_f64[Add(star(a))] = f64_add(a)

I did contemplate implementing __iter__ so that you could literally write it as

eval_f64[Add(*a)] = f64_add(a)

That would be nice but I didn't want to add __iter__ to all of Sym. If there was a way to keep it local to one class then that could work but currently a wild has the same type as the class that creates it (e.g. Expr.new_wild('a') is Expr).

Perhaps f64_sum is a better name to convey the fact that it takes an iterable like sum and fsum rather than e.g. operator.add or __add__ which are binary.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation, that makes it clear now.

if isinstance(call.op, PyOpN):
(callarg,) = call.args
if pattern.args != (star(callarg),):
raise BadRuleError("Nary function needs a star-rule.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took me a second to understand what "Nary" was referring to here. Thoughts on using multary or multiary instead?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe n-ary is better?

I've never heard of multary or multiary.

Another possibility would be "var-args" which is more like "polyadic". Perhaps these are actually more accurate because it is not a fixed "n" here.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this to "Varargs" which is perhaps more well known to Python programmers at least...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Wikipedia article on Arity states that the term for "more than 2-ary" can be either multary or multiary. I've come across both multary and multiary alongside unary and binary in textbooks before.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The article uses n-ary more than either of those terms but in any case what we have is actually what the article refers to as "variadic" i.e. the number of arguments is not any fixed n but can be different in each call.

I suppose it helps here to see the error message in context:

In [7]: from protosym.simplecas import *

In [8]: eval_f64[Add(a)] = PyOpN(sum)(a)
...
BadRuleError: varargs function needs a star-rule.

In [9]: eval_f64[Add(star(a))] = PyOpN(sum)(a)  # ok

I'm sure the message can be improved but it's a bit tricky to get it exactly right because it runtime it is hard to figure out exactly what the user is doing or trying to do. The Evaluator does not intrinsically have any idea that Add itself should be variadic so it's difficult to say "Add is variadic and needs a star rule". The problem is that the rhs uses a PyOpN with parameter a and so the lhs needs to have star(a) somewhere.

Testing this out I see other cases that are accepted at runtime for example:

In [10]: eval_f64[Add(a)] = PyOp1(sum)(a)  # should be an error

The type-checker would reject that because PyOp1(sum) implies a function that is not T -> T.

from __future__ import annotations

from protosym.core.sym import PyOp1
from protosym.simplecas import eval_f64, Add, a

eval_f64[Add(a)] = PyOp1(sum)(a)

Here mypy gives:

t.py:6:20: error: Cannot infer type argument 1 of "PyOp1"  [misc]
    eval_f64[Add(a)] = PyOp1(sum)(a)
                       ^~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)

It can't infer the type because PyOp1 expects Callable[[S], S] and sum is like Callable[[Iterable[T]], T] and it is not clear how to find an S that unifies Iterable[T] and T.

I saw a few cases like this where I thought that the type-checker was just not smart enough and tried using type: ignore. Sure enough when I ran the tests and traced the failures back it came to right where mypy had been complaining.

Copy link
Collaborator

@brocksam brocksam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the new syntax for specifying rules. Much clearer and more intuitive.

@oscarbenjamin oscarbenjamin added the enhancement New feature or request label Mar 16, 2023
@oscarbenjamin
Copy link
Owner Author

Thanks for the quick review!

@oscarbenjamin
Copy link
Owner Author

Thanks again for the quick review. I think I've addressed all comments.

Copy link
Collaborator

@brocksam brocksam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved provided you're still happy with var-args after reading about arity.

@oscarbenjamin
Copy link
Owner Author

Thanks for the reviews! I'll merge this for now.

@oscarbenjamin oscarbenjamin merged commit 48bc65c into main Mar 17, 2023
@oscarbenjamin oscarbenjamin deleted the pr_sym branch March 17, 2023 10:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants