Skip to content

Commit a94c35a

Browse files
committed
add new bo test with manually set x0
1 parent 0405de8 commit a94c35a

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

tests/mmm/test_budget_optimizer.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,36 @@ def dummy_df():
5959

6060

6161
@pytest.mark.parametrize(
62-
argnames="total_budget, budget_bounds, parameters, minimize_kwargs, expected_optimal, expected_response",
62+
argnames="total_budget, budget_bounds, x0, parameters, minimize_kwargs, expected_optimal, expected_response",
6363
argvalues=[
6464
(
6565
100,
6666
None,
67+
None,
68+
{
69+
"saturation_params": {
70+
"lam": np.array(
71+
[[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]
72+
), # dims: chain, draw, channel
73+
"beta": np.array(
74+
[[[0.5, 1.0], [0.5, 1.0]], [[0.5, 1.0], [0.5, 1.0]]]
75+
), # dims: chain, draw, channel
76+
},
77+
"adstock_params": {
78+
"alpha": np.array(
79+
[[[0.5, 0.7], [0.5, 0.7]], [[0.5, 0.7], [0.5, 0.7]]]
80+
) # dims: chain, draw, channel
81+
},
82+
},
83+
None,
84+
{"channel_1": 54.78357587906867, "channel_2": 45.21642412093133},
85+
48.8,
86+
),
87+
# set x0 manually
88+
(
89+
100,
90+
None,
91+
np.array([50, 50]),
6792
{
6893
"saturation_params": {
6994
"lam": np.array(
@@ -91,6 +116,7 @@ def dummy_df():
91116
channel=["channel_1", "channel_2"],
92117
bound=["lower", "upper"],
93118
),
119+
None,
94120
{
95121
"saturation_params": {
96122
"lam": np.array(
@@ -121,6 +147,7 @@ def dummy_df():
121147
channel=["channel_1", "channel_2"],
122148
bound=["lower", "upper"],
123149
),
150+
None,
124151
{
125152
"saturation_params": {
126153
"lam": np.array(
@@ -142,11 +169,17 @@ def dummy_df():
142169
2.38e-10,
143170
),
144171
],
145-
ids=["default_minimizer_kwargs", "custom_minimizer_kwargs", "zero_total_budget"],
172+
ids=[
173+
"default_minimizer_kwargs",
174+
"manually_set_x0",
175+
"custom_minimizer_kwargs",
176+
"zero_total_budget",
177+
],
146178
)
147179
def test_allocate_budget(
148180
total_budget,
149181
budget_bounds,
182+
x0,
150183
parameters,
151184
minimize_kwargs,
152185
expected_optimal,
@@ -184,6 +217,7 @@ def test_allocate_budget(
184217
optimal_budgets, optimization_res = optimizer.allocate_budget(
185218
total_budget=total_budget,
186219
budget_bounds=budget_bounds,
220+
x0=x0,
187221
minimize_kwargs=minimize_kwargs,
188222
)
189223

0 commit comments

Comments
 (0)