Skip to content

Commit 2f83818

Browse files
Merge pull request #5320 from ricardoV94/joint_logpt
Refactor `Factor` properties
2 parents 3c5cafe + 75ea232 commit 2f83818

35 files changed

+509
-613
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ jobs:
4949
--ignore=pymc/tests/test_updates.py
5050
--ignore=pymc/tests/test_gp.py
5151
--ignore=pymc/tests/test_model.py
52-
--ignore=pymc/tests/test_model_func.py
5352
--ignore=pymc/tests/test_ode.py
5453
--ignore=pymc/tests/test_posdef_sym.py
5554
--ignore=pymc/tests/test_quadpotential.py
@@ -82,7 +81,6 @@ jobs:
8281
pymc/tests/test_distributions_timeseries.py
8382
pymc/tests/test_gp.py
8483
pymc/tests/test_model.py
85-
pymc/tests/test_model_func.py
8684
pymc/tests/test_model_graph.py
8785
pymc/tests/test_ode.py
8886
pymc/tests/test_posdef_sym.py
@@ -166,7 +164,6 @@ jobs:
166164
pymc/tests/test_ode.py
167165
- |
168166
pymc/tests/test_model.py
169-
pymc/tests/test_model_func.py
170167
pymc/tests/test_modelcontext.py
171168
pymc/tests/test_model_graph.py
172169
pymc/tests/test_pickling.py

RELEASE-NOTES.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,13 @@ All of the above apply to:
3939
- `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)).
4040
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
4141
- `pm.sample(trace=...)` no longer accepts `MultiTrace` or `len(.) > 0` traces ([see 5019#](https://github.com/pymc-devs/pymc/pull/5019)).
42+
- `logpt`, `logpt_sum`, `logp_elemwiset` and `nojac` variations were removed. Use `Model.logpt(jacobian=True/False, sum=True/False)` instead.
43+
- `dlogp_nojact` and `d2logp_nojact` were removed. Use `Model.dlogpt` and `d2logpt` with `jacobian=False` instead.
44+
- `logp`, `dlogp`, and `d2logp` and `nojac` variations were removed. Use `Model.compile_logp`, `compile_dlgop` and `compile_d2logp` with `jacobian` keyword instead.
45+
- `model.makefn` is now called `Model.compile_fn`, and `model.fn` was removed.
46+
- Methods starting with `fast_*`, such as `Model.fast_logp`, were removed. Same applies to `PointFunc` classes
4247
- The GLM submodule was removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
4348
- `pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
44-
- `pm.logpt(transformed=...)` kwarg was removed (816b5f).
4549
- `Model(model=...)` kwarg was removed
4650
- `Model(theano_config=...)` kwarg was removed
4751
- `Model.size` property was removed (use `Model.ndim` instead).

pymc/backends/arviz.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,26 @@ def _extract_log_likelihood(self, trace):
246246
# TODO: We no longer need one function per observed variable
247247
if self.log_likelihood is True:
248248
cached = [
249-
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
249+
(
250+
var,
251+
self.model.compile_fn(
252+
self.model.logpt(var, sum=False)[0],
253+
inputs=self.model.value_vars,
254+
on_unused_input="ignore",
255+
),
256+
)
250257
for var in self.model.observed_RVs
251258
]
252259
else:
253260
cached = [
254-
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
261+
(
262+
var,
263+
self.model.compile_fn(
264+
self.model.logpt(var, sum=False)[0],
265+
inputs=self.model.value_vars,
266+
on_unused_input="ignore",
267+
),
268+
)
255269
for var in self.model.observed_RVs
256270
if var.name in self.log_likelihood
257271
]

pymc/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, name, model=None, vars=None, test_point=None):
6565

6666
self.vars = vars
6767
self.varnames = [var.name for var in vars]
68-
self.fn = model.fastfn(vars)
68+
self.fn = model.compile_fn(vars, inputs=model.value_vars, on_unused_input="ignore")
6969

7070
# Get variable shapes. Most backends will need this
7171
# information.

pymc/distributions/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
logcdf,
1717
logp,
1818
logp_transform,
19-
logpt,
20-
logpt_sum,
19+
joint_logpt,
2120
)
2221

2322
from pymc.distributions.bound import Bound
@@ -191,9 +190,8 @@
191190
"Censored",
192191
"CAR",
193192
"PolyaGamma",
194-
"logpt",
193+
"joint_logpt",
195194
"logp",
196195
"logp_transform",
197196
"logcdf",
198-
"logpt_sum",
199197
]

pymc/distributions/logprob.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
1514

1615
from collections.abc import Mapping
1716
from functools import singledispatch
@@ -118,7 +117,7 @@ def _get_scaling(total_size, shape, ndim):
118117
)
119118

120119

121-
def logpt(
120+
def joint_logpt(
122121
var: Union[TensorVariable, List[TensorVariable]],
123122
rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
124123
*,
@@ -264,17 +263,3 @@ def logcdf(rv, value):
264263

265264
value = at.as_tensor_variable(value, dtype=rv.dtype)
266265
return logcdf_aeppl(rv, value)
267-
268-
269-
def logpt_sum(*args, **kwargs):
270-
"""Return the sum of the logp values for the given observations.
271-
272-
Subclasses can use this to improve the speed of logp evaluations
273-
if only the sum of the logp values is needed.
274-
"""
275-
warnings.warn(
276-
"logpt_sum has been deprecated, you can use logpt instead, which now defaults"
277-
"to the same behavior of logpt_sum",
278-
DeprecationWarning,
279-
)
280-
return logpt(*args, sum=True, **kwargs)

0 commit comments

Comments
 (0)