Skip to content

Commit 4c771f6

Browse files
authored
Merge pull request #3619 from ColCarroll/matmul
Add matrix multiplication infix operator
2 parents e3b667c + 79dcea4 commit 4c771f6

File tree

4 files changed

+43
-4
lines changed

4 files changed

+43
-4
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
1010
- Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547).
1111
- Sampling from variational approximation now allows for alternative trace backends [#3550].
12+
- Infix `@` operator now works with random variables and deterministics [#3619](https://github.com/pymc-devs/pymc3/pull/3619).
1213
- [ArviZ](https://arviz-devs.github.io/arviz/) is now a requirement, and handles plotting, diagnostics, and statistical checks.
1314

1415
### Maintenance

Diff for: docs/source/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
weights = pm.Normal('weights', mu=0, sigma=1)
2727
noise = pm.Gamma('noise', alpha=2, beta=1)
2828
y_observed = pm.Normal('y_observed',
29-
mu=X.dot(weights),
29+
mu=X @ weights,
3030
sigma=noise,
3131
observed=y)
3232

Diff for: pymc3/model.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030
FlatView = collections.namedtuple('FlatView', 'input, replacements, view')
3131

3232

33+
class PyMC3Variable(TensorVariable):
34+
"""Class to wrap Theano TensorVariable for custom behavior."""
35+
36+
# Implement matrix multiplication infix operator: X @ w
37+
__matmul__ = tt.dot
38+
39+
def __rmatmul__(self, other):
40+
return tt.dot(other, self)
41+
42+
3343
class InstanceMethod:
3444
"""Class for hiding references to instance methods so they can be pickled.
3545
@@ -1245,7 +1255,7 @@ def _get_scaling(total_size, shape, ndim):
12451255
return tt.as_tensor(floatX(coef))
12461256

12471257

1248-
class FreeRV(Factor, TensorVariable):
1258+
class FreeRV(Factor, PyMC3Variable):
12491259
"""Unobserved random variable that a model is specified in terms of."""
12501260

12511261
def __init__(self, type=None, owner=None, index=None, name=None,
@@ -1354,7 +1364,7 @@ def as_tensor(data, name, model, distribution):
13541364
return data
13551365

13561366

1357-
class ObservedRV(Factor, TensorVariable):
1367+
class ObservedRV(Factor, PyMC3Variable):
13581368
"""Observed random variable that a model is specified in terms of.
13591369
Potentially partially observed.
13601370
"""
@@ -1525,7 +1535,7 @@ def Potential(name, var, model=None):
15251535
return var
15261536

15271537

1528-
class TransformedRV(TensorVariable):
1538+
class TransformedRV(PyMC3Variable):
15291539
"""
15301540
Parameters
15311541
----------

Diff for: pymc3/tests/test_model.py

+28
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,34 @@ def test_nested(self):
157157
assert theano.config.compute_test_value == 'ignore'
158158
assert theano.config.compute_test_value == 'off'
159159

160+
def test_matrix_multiplication():
161+
# Check matrix multiplication works between RVs, transformed RVs,
162+
# Deterministics, and numpy arrays
163+
with pm.Model() as linear_model:
164+
matrix = pm.Normal('matrix', shape=(2, 2))
165+
transformed = pm.Gamma('transformed', alpha=2, beta=1, shape=2)
166+
rv_rv = pm.Deterministic('rv_rv', matrix @ transformed)
167+
np_rv = pm.Deterministic('np_rv', np.ones((2, 2)) @ transformed)
168+
rv_np = pm.Deterministic('rv_np', matrix @ np.ones(2))
169+
rv_det = pm.Deterministic('rv_det', matrix @ rv_rv)
170+
det_rv = pm.Deterministic('det_rv', rv_rv @ transformed)
171+
172+
posterior = pm.sample(10,
173+
tune=0,
174+
compute_convergence_checks=False,
175+
progressbar=False)
176+
for point in posterior.points():
177+
npt.assert_almost_equal(point['matrix'] @ point['transformed'],
178+
point['rv_rv'])
179+
npt.assert_almost_equal(np.ones((2, 2)) @ point['transformed'],
180+
point['np_rv'])
181+
npt.assert_almost_equal(point['matrix'] @ np.ones(2),
182+
point['rv_np'])
183+
npt.assert_almost_equal(point['matrix'] @ point['rv_rv'],
184+
point['rv_det'])
185+
npt.assert_almost_equal(point['rv_rv'] @ point['transformed'],
186+
point['det_rv'])
187+
160188

161189
def test_duplicate_vars():
162190
with pytest.raises(ValueError) as err:

0 commit comments

Comments
 (0)