Skip to content

Commit 30d28f4

Browse files
authored
add n_draws and t_sampling report to smc (#3931)
* add n_draws and t_sampling report to smc * add _n_tune * update release notes * resolve conflicts
1 parent 1522492 commit 30d28f4

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- `pm.Data` container can now be used for index variables, i.e with integer data and not only floats (issue [#3813](https://github.com/pymc-devs/pymc3/issues/3813), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
1818
- `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
1919
- Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
20+
- Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
2021

2122
### Maintenance
2223
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).

Diff for: pymc3/smc/sample_smc.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .smc import SMC
15+
import time
1616
import logging
17+
from .smc import SMC
1718

1819

1920
def sample_smc(
@@ -144,6 +145,7 @@ def sample_smc(
144145
random_seed=random_seed,
145146
)
146147

148+
t1 = time.time()
147149
_log = logging.getLogger("pymc3")
148150
_log.info("Sample initial stage: ...")
149151
stage = 0
@@ -170,5 +172,7 @@ def sample_smc(
170172
smc.pool.join()
171173

172174
trace = smc.posterior_to_trace()
173-
175+
trace.report._n_draws = smc.draws
176+
trace.report._n_tune = 0
177+
trace.report._t_sampling = time.time() - t1
174178
return trace

0 commit comments

Comments
 (0)