@@ -59,11 +59,36 @@ def dummy_df():
59
59
60
60
61
61
@pytest .mark .parametrize (
62
- argnames = "total_budget, budget_bounds, parameters, minimize_kwargs, expected_optimal, expected_response" ,
62
+ argnames = "total_budget, budget_bounds, x0, parameters, minimize_kwargs, expected_optimal, expected_response" ,
63
63
argvalues = [
64
64
(
65
65
100 ,
66
66
None ,
67
+ None ,
68
+ {
69
+ "saturation_params" : {
70
+ "lam" : np .array (
71
+ [[[0.1 , 0.2 ], [0.3 , 0.4 ]], [[0.5 , 0.6 ], [0.7 , 0.8 ]]]
72
+ ), # dims: chain, draw, channel
73
+ "beta" : np .array (
74
+ [[[0.5 , 1.0 ], [0.5 , 1.0 ]], [[0.5 , 1.0 ], [0.5 , 1.0 ]]]
75
+ ), # dims: chain, draw, channel
76
+ },
77
+ "adstock_params" : {
78
+ "alpha" : np .array (
79
+ [[[0.5 , 0.7 ], [0.5 , 0.7 ]], [[0.5 , 0.7 ], [0.5 , 0.7 ]]]
80
+ ) # dims: chain, draw, channel
81
+ },
82
+ },
83
+ None ,
84
+ {"channel_1" : 54.78357587906867 , "channel_2" : 45.21642412093133 },
85
+ 48.8 ,
86
+ ),
87
+ # set x0 manually
88
+ (
89
+ 100 ,
90
+ None ,
91
+ np .array ([50 , 50 ]),
67
92
{
68
93
"saturation_params" : {
69
94
"lam" : np .array (
@@ -91,6 +116,7 @@ def dummy_df():
91
116
channel = ["channel_1" , "channel_2" ],
92
117
bound = ["lower" , "upper" ],
93
118
),
119
+ None ,
94
120
{
95
121
"saturation_params" : {
96
122
"lam" : np .array (
@@ -121,6 +147,7 @@ def dummy_df():
121
147
channel = ["channel_1" , "channel_2" ],
122
148
bound = ["lower" , "upper" ],
123
149
),
150
+ None ,
124
151
{
125
152
"saturation_params" : {
126
153
"lam" : np .array (
@@ -142,11 +169,17 @@ def dummy_df():
142
169
2.38e-10 ,
143
170
),
144
171
],
145
- ids = ["default_minimizer_kwargs" , "custom_minimizer_kwargs" , "zero_total_budget" ],
172
+ ids = [
173
+ "default_minimizer_kwargs" ,
174
+ "manually_set_x0" ,
175
+ "custom_minimizer_kwargs" ,
176
+ "zero_total_budget" ,
177
+ ],
146
178
)
147
179
def test_allocate_budget (
148
180
total_budget ,
149
181
budget_bounds ,
182
+ x0 ,
150
183
parameters ,
151
184
minimize_kwargs ,
152
185
expected_optimal ,
@@ -184,6 +217,7 @@ def test_allocate_budget(
184
217
optimal_budgets , optimization_res = optimizer .allocate_budget (
185
218
total_budget = total_budget ,
186
219
budget_bounds = budget_bounds ,
220
+ x0 = x0 ,
187
221
minimize_kwargs = minimize_kwargs ,
188
222
)
189
223
0 commit comments