Skip to content

Commit b9704f3

Browse files
alcreneAlexandre Renémichaelosthege
authored
Fix an issue with empty preallocation (#113)
* Fix issue with NumPy chain and preallocation=0 Original logic in `grow_append` was to extend data by 10% of its length. This is a problem with the original data length is 0, since it then never extends. This commit amends `grow_append` to always extend by at least 10 elements. * [NumPyBackend] Prevent ``preallocate=0`` from creating object arrays. `grow_append` cannot know if ``preallocate = 0`` was used: it only looks at the `rigid` value to determine how to append. Because of this, will always fail when we use `preallocate = 0` with tensor variables, since then the shapes of `target` and `extension` don’t match. A simple fix is to simply deactivate the special behavior for `preallocate = 0`. - This commit extends `test_growing` with a `preallocate` parameter, so that we test both cases where it is 0 and positive. - We also fix `test_growing` to match the new behavior of `grow_append` introduced in #ea812b0, where data arrays always grow by at least 10. * Remove trailing whitespace --------- Co-authored-by: Alexandre René <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent b35be42 commit b9704f3

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

mcbackend/backends/numpy.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def grow_append(
2323
length = len(target)
2424
if length == draw_idx:
2525
# Grow the array by 10 %
26-
ngrow = math.ceil(0.1 * length)
26+
ngrow = max(10, math.ceil(0.1 * length))
2727
if rigid[vn]:
2828
extension = numpy.empty((ngrow,) + numpy.shape(v))
2929
else:
@@ -52,7 +52,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
5252
and grow the allocated memory by 10 % when needed.
5353
Exceptions are variables with non-rigid shapes (indicated by 0 in the shape tuple)
5454
where the correct amount of memory cannot be pre-allocated.
55-
In these cases, and when ``preallocate == 0`` object arrays are used.
55+
In these cases object arrays are used.
5656
"""
5757
self._var_is_rigid: Dict[str, bool] = {}
5858
self._samples: Dict[str, numpy.ndarray] = {}
@@ -68,11 +68,11 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
6868
for var in variables:
6969
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
7070
rigid_dict[var.name] = rigid
71-
if preallocate > 0 and rigid:
71+
if rigid:
7272
reserve = (preallocate, *var.shape)
7373
target_dict[var.name] = numpy.empty(reserve, var.dtype)
7474
else:
75-
target_dict[var.name] = numpy.array([None] * preallocate)
75+
target_dict[var.name] = numpy.array([None] * preallocate, dtype=object)
7676

7777
super().__init__(cmeta, rmeta)
7878

mcbackend/test_backend_numpy.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import hagelkorn
44
import numpy
5+
import pytest
56

67
from mcbackend.backends.numpy import NumPyBackend, NumPyChain, NumPyRun
78
from mcbackend.core import RunMeta
@@ -43,8 +44,9 @@ def test_targets(self):
4344
assert chain._samples["changeling"].dtype == object
4445
pass
4546

46-
def test_growing(self):
47-
imb = NumPyBackend(preallocate=15)
47+
@pytest.mark.parametrize("preallocate", [0, 75])
48+
def test_growing(self, preallocate):
49+
imb = NumPyBackend(preallocate=preallocate)
4850
rm = RunMeta(
4951
rid=hagelkorn.random(),
5052
variables=[
@@ -62,19 +64,27 @@ def test_growing(self):
6264
)
6365
run = imb.init_run(rm)
6466
chain = run.init_chain(0)
65-
assert chain._samples["A"].shape == (15, 2)
66-
assert chain._samples["B"].shape == (15,)
67-
for _ in range(22):
67+
assert chain._samples["A"].shape == (preallocate, 2)
68+
assert chain._samples["B"].shape == (preallocate,)
69+
for _ in range(130):
6870
draw = {
6971
"A": numpy.random.uniform(size=(2,)),
7072
"B": numpy.random.uniform(size=(random.randint(0, 10),)),
7173
}
7274
chain.append(draw)
73-
# Growth: 15 → 17 → 19 → 21 → 24
74-
assert chain._samples["A"].shape == (24, 2)
75-
assert chain._samples["B"].shape == (24,)
76-
assert chain.get_draws("A").shape == (22, 2)
77-
assert chain.get_draws("B").shape == (22,)
75+
# NB: Growth algorithm adds max(10, ceil(0.1*length))
76+
if preallocate == 75:
77+
# 75 → 85 → 95 → 105 → 116 → 128 → 141
78+
assert chain._samples["A"].shape == (141, 2)
79+
assert chain._samples["B"].shape == (141,)
80+
elif preallocate == 0:
81+
# 10 → 20 → ... → 90 → 100 → 110 → 121 → 134
82+
assert chain._samples["A"].shape == (134, 2)
83+
assert chain._samples["B"].shape == (134,)
84+
else:
85+
assert False, f"Missing test for {preallocate=}"
86+
assert chain.get_draws("A").shape == (130, 2)
87+
assert chain.get_draws("B").shape == (130,)
7888
pass
7989

8090

0 commit comments

Comments
 (0)