@@ -131,7 +131,9 @@ void init_pyspiel_policy(py::module& m) {
            &open_spiel::PreferredActionPolicy::GetStatePolicy);
 
   py::class_<open_spiel::algorithms::CFRSolver>(m, "CFRSolver")
-      .def(py::init<const Game&>())
+      .def(py::init([](std::shared_ptr<const Game> game) {
+        return new algorithms::CFRSolver(*game);
+      }))
       .def("evaluate_and_update_policy",
            &open_spiel::algorithms::CFRSolver::EvaluateAndUpdatePolicy)
       .def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -147,7 +149,9 @@ void init_pyspiel_policy(py::module& m) {
       }));
 
   py::class_<open_spiel::algorithms::CFRPlusSolver>(m, "CFRPlusSolver")
-      .def(py::init<const Game&>())
+      .def(py::init([](std::shared_ptr<const Game> game) {
+        return new algorithms::CFRPlusSolver(*game);
+      }))
       .def("evaluate_and_update_policy",
            &open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
       .def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -163,7 +167,9 @@ void init_pyspiel_policy(py::module& m) {
       }));
 
   py::class_<open_spiel::algorithms::CFRBRSolver>(m, "CFRBRSolver")
-      .def(py::init<const Game&>())
+      .def(py::init([](std::shared_ptr<const Game> game) {
+        return new algorithms::CFRBRSolver(*game);
+      }))
       .def("evaluate_and_update_policy",
            &open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
       .def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -184,7 +190,11 @@ void init_pyspiel_policy(py::module& m) {
 
   py::class_<open_spiel::algorithms::ExternalSamplingMCCFRSolver>(
       m, "ExternalSamplingMCCFRSolver")
-      .def(py::init<const Game&, int, open_spiel::algorithms::AverageType>(),
+      .def(py::init([](std::shared_ptr<const Game> game, int seed,
+                       algorithms::AverageType average_type) {
+        return new algorithms::ExternalSamplingMCCFRSolver(*game, seed,
+                                                           average_type);
+      }),
            py::arg("game"), py::arg("seed") = 0,
            py::arg("avg_type") = open_spiel::algorithms::AverageType::kSimple)
       .def("run_iteration",
@@ -204,7 +214,12 @@ void init_pyspiel_policy(py::module& m) {
 
   py::class_<open_spiel::algorithms::OutcomeSamplingMCCFRSolver>(
       m, "OutcomeSamplingMCCFRSolver")
-      .def(py::init<const Game&, double, int>(), py::arg("game"),
+      .def(py::init(
+               [](std::shared_ptr<const Game> game, double epsilon, int seed) {
+                 return new algorithms::OutcomeSamplingMCCFRSolver(
+                     *game, epsilon, seed);
+               }),
+           py::arg("game"),
            py::arg("epsilon") = open_spiel::algorithms::
                OutcomeSamplingMCCFRSolver::kDefaultEpsilon,
            py::arg("seed") = -1)
@@ -267,45 +282,54 @@ void init_pyspiel_policy(py::module& m) {
            py::arg("use_infostate_get_policy"),
            py::arg("prob_cut_threshold") = 0.0);
 
-  m.def("exploitability",
-        py::overload_cast<const Game&, const Policy&>(&Exploitability),
-        "Returns the sum of the utility that a best responder wins when when "
-        "playing against 1) the player 0 policy contained in `policy` and 2) "
-        "the player 1 policy contained in `policy`."
-        "This only works for two player, zero- or constant-sum sequential "
-        "games, and raises a SpielFatalError if an incompatible game is passed "
-        "to it.");
+  m.def(
+      "exploitability",
+      [](std::shared_ptr<const Game> game, const Policy& policy) {
+        return Exploitability(*game, policy);
+      },
+      "Returns the sum of the utility that a best responder wins when "
+      "playing against 1) the player 0 policy contained in `policy` and 2) "
+      "the player 1 policy contained in `policy`. "
+      "This only works for two player, zero- or constant-sum sequential "
+      "games, and raises a SpielFatalError if an incompatible game is passed "
+      "to it.");
 
   m.def(
       "exploitability",
-      py::overload_cast<
-          const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
-          &Exploitability),
+      [](std::shared_ptr<const Game> game,
+         const std::unordered_map<std::string, ActionsAndProbs>& policy) {
+        return Exploitability(*game, policy);
+      },
       "Returns the sum of the utility that a best responder wins when when "
       "playing against 1) the player 0 policy contained in `policy` and 2) "
       "the player 1 policy contained in `policy`."
       "This only works for two player, zero- or constant-sum sequential "
       "games, and raises a SpielFatalError if an incompatible game is passed "
       "to it.");
 
-  m.def("nash_conv",
-        py::overload_cast<const Game&, const Policy&, bool>(&NashConv),
-        "Calculates a measure of how far the given policy is from a Nash "
-        "equilibrium by returning the sum of the improvements in the value "
-        "that each player could obtain by unilaterally changing their strategy "
-        "while the opposing player maintains their current strategy (which "
-        "for a Nash equilibrium, this value is 0). The third parameter is to "
-        "indicate whether to use the Policy::GetStatePolicy(const State&) "
-        "instead of Policy::GetStatePolicy(const std::string& info_state) for "
-        "computation of the on-policy expected values.",
-        py::arg("game"), py::arg("policy"),
-        py::arg("use_state_get_policy") = false);
+  m.def(
+      "nash_conv",
+      [](std::shared_ptr<const Game> game, const Policy& policy,
+         bool use_state_get_policy) {
+        return NashConv(*game, policy, use_state_get_policy);
+      },
+      "Calculates a measure of how far the given policy is from a Nash "
+      "equilibrium by returning the sum of the improvements in the value "
+      "that each player could obtain by unilaterally changing their strategy "
+      "while the opposing player maintains their current strategy (which "
+      "for a Nash equilibrium, this value is 0). The third parameter "
+      "indicates whether to use Policy::GetStatePolicy(const State&) "
+      "instead of Policy::GetStatePolicy(const std::string& info_state) for "
+      "computation of the on-policy expected values.",
+      py::arg("game"), py::arg("policy"),
+      py::arg("use_state_get_policy") = false);
 
   m.def(
       "nash_conv",
-      py::overload_cast<
-          const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
-          &NashConv),
+      [](std::shared_ptr<const Game> game,
+         const std::unordered_map<std::string, ActionsAndProbs>& policy) {
+        return NashConv(*game, policy);
+      },
       "Calculates a measure of how far the given policy is from a Nash "
       "equilibrium by returning the sum of the improvements in the value "
       "that each player could obtain by unilaterally changing their strategy "