Skip to content

Commit c9f2e37

Browse files
elkhrtlanctot
authored andcommitted
Change parameter passing style for Game objects in order to support Python games.
The pybind smart_holder logic will create a shared_ptr for Python-created objects only when required to do so. This means that if a Python-implemented game is passed from Python to C++ as Game& and then a C++ function calls shared_from_this() on it, this will fail unless there's already a C++ shared_ptr for some other reason. The fix is either: a - Amend the C++ interface to take shared_ptr instead of refs b - Introduce a lambda function in the pybind interface, taking a shared_ptr and dereferencing it to call the ref-based C++ implementation Either option will result in pybind creating a shared_ptr for us before calling our C++ code. To minimize disruption to existing code, and forestall future failures, I've applied change (b) everywhere I could see, even though not every case was failing (because not every case called shared_from_this in the C++ implementation). For further details of the relevant pybind internals, see pybind/pybind11#3023 fixes: #905 PiperOrigin-RevId: 469016236 Change-Id: I9467eeb992f3463a432cc7060c46404d2bbd4638
1 parent f4121ea commit c9f2e37

File tree

8 files changed

+168
-97
lines changed

8 files changed

+168
-97
lines changed

open_spiel/python/games/kuhn_poker_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def test_exploitability_uniform_random_cc(self):
8181
self.assertAlmostEqual(
8282
pyspiel.exploitability(game, test_policy), expected_nash_conv / 2)
8383

84+
def test_cfr_cc(self):
85+
"""Runs a C++ CFR algorithm on the game."""
86+
game = pyspiel.load_game("python_kuhn_poker")
87+
unused_results = pyspiel.CFRSolver(game)
88+
8489

8590
if __name__ == "__main__":
8691
absltest.main()

open_spiel/python/pybind11/algorithms_corr_dist.cc

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,37 @@ void init_pyspiel_algorithms_corr_dist(py::module& m) {
5050
.def_readonly("conditional_best_response_policies",
5151
&CorrDistInfo::conditional_best_response_policies);
5252

53-
m.def("cce_dist",
54-
py::overload_cast<const Game&, const CorrelationDevice&, int, float>(
55-
&open_spiel::algorithms::CCEDist),
56-
"Returns a player's distance to a coarse-correlated equilibrium.",
57-
py::arg("game"),
58-
py::arg("correlation_device"),
59-
py::arg("player"),
60-
py::arg("prob_cut_threshold") = -1.0);
53+
m.def(
54+
"cce_dist",
55+
[](std::shared_ptr<const Game> game,
56+
const CorrelationDevice& correlation_device, int player,
57+
float prob_cut_threshold) {
58+
return algorithms::CCEDist(*game, correlation_device, player,
59+
prob_cut_threshold);
60+
},
61+
"Returns a player's distance to a coarse-correlated equilibrium.",
62+
py::arg("game"), py::arg("correlation_device"), py::arg("player"),
63+
py::arg("prob_cut_threshold") = -1.0);
6164

62-
m.def("cce_dist",
63-
py::overload_cast<const Game&, const CorrelationDevice&, float>(
64-
&open_spiel::algorithms::CCEDist),
65-
"Returns the distance to a coarse-correlated equilibrium.",
66-
py::arg("game"),
67-
py::arg("correlation_device"),
68-
py::arg("prob_cut_threshold") = -1.0);
65+
m.def(
66+
"cce_dist",
67+
[](std::shared_ptr<const Game> game,
68+
const CorrelationDevice& correlation_device,
69+
float prob_cut_threshold) {
70+
return algorithms::CCEDist(*game, correlation_device,
71+
prob_cut_threshold);
72+
},
73+
"Returns the distance to a coarse-correlated equilibrium.",
74+
py::arg("game"), py::arg("correlation_device"),
75+
py::arg("prob_cut_threshold") = -1.0);
6976

70-
m.def("ce_dist",
71-
py::overload_cast<const Game&, const CorrelationDevice&>(
72-
&open_spiel::algorithms::CEDist),
73-
"Returns the distance to a correlated equilibrium.");
77+
m.def(
78+
"ce_dist",
79+
[](std::shared_ptr<const Game> game,
80+
const CorrelationDevice& correlation_device) {
81+
return algorithms::CEDist(*game, correlation_device);
82+
},
83+
"Returns the distance to a correlated equilibrium.");
7484

7585
// TODO(author5): expose the rest of the functions.
7686
}

open_spiel/python/pybind11/algorithms_trajectories.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,28 @@ void init_pyspiel_algorithms_trajectories(py::module& m) {
5252
.def("resize_fields",
5353
&open_spiel::algorithms::BatchedTrajectory::ResizeFields);
5454

55-
m.def("record_batched_trajectories",
56-
py::overload_cast<
57-
const Game&, const std::vector<open_spiel::TabularPolicy>&,
58-
const std::unordered_map<std::string, int>&, int, bool, int, int>(
59-
&open_spiel::algorithms::RecordBatchedTrajectory),
60-
"Records a batch of trajectories.");
55+
m.def(
56+
"record_batched_trajectories",
57+
[](std::shared_ptr<const Game> game,
58+
const std::vector<TabularPolicy>& policies,
59+
const std::unordered_map<std::string, int>& state_to_index,
60+
int batch_size, bool include_full_observations, int seed,
61+
int max_unroll_length) {
62+
return open_spiel::algorithms::RecordBatchedTrajectory(
63+
*game, policies, state_to_index, batch_size,
64+
include_full_observations, seed, max_unroll_length);
65+
},
66+
"Records a batch of trajectories.");
6167

6268
py::class_<open_spiel::algorithms::TrajectoryRecorder>(m,
6369
"TrajectoryRecorder")
64-
.def(py::init<const Game&, const std::unordered_map<std::string, int>&,
65-
int>())
70+
.def(py::init(
71+
[](std::shared_ptr<const Game> game,
72+
const std::unordered_map<std::string, int>& state_to_index,
73+
int seed) {
74+
return new algorithms::TrajectoryRecorder(*game, state_to_index,
75+
seed);
76+
}))
6677
.def("record_batch",
6778
&open_spiel::algorithms::TrajectoryRecorder::RecordBatch);
6879
}

open_spiel/python/pybind11/bots.cc

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,20 @@ void init_pyspiel_bots(py::module& m) {
187187
"Returns a list of registered bot names.");
188188
m.def(
189189
"bots_that_can_play_game",
190-
py::overload_cast<const Game&, Player>(&open_spiel::BotsThatCanPlayGame),
190+
[](std::shared_ptr<const Game> game, int player) {
191+
return BotsThatCanPlayGame(*game, player);
192+
},
191193
py::arg("game"), py::arg("player"),
192194
"Returns a list of bot names that can play specified game for the "
193195
"given player.");
194-
m.def("bots_that_can_play_game",
195-
py::overload_cast<const Game&>(&open_spiel::BotsThatCanPlayGame),
196-
py::arg("game"),
197-
"Returns a list of bot names that can play specified game for any "
198-
"player.");
196+
m.def(
197+
"bots_that_can_play_game",
198+
[](std::shared_ptr<const Game> game) {
199+
return BotsThatCanPlayGame(*game);
200+
},
201+
py::arg("game"),
202+
"Returns a list of bot names that can play specified game for any "
203+
"player.");
199204

200205
py::class_<algorithms::Evaluator,
201206
std::shared_ptr<algorithms::Evaluator>> mcts_evaluator(
@@ -223,14 +228,21 @@ void init_pyspiel_bots(py::module& m) {
223228
.def("children_str", &SearchNode::ChildrenStr);
224229

225230
py::class_<algorithms::MCTSBot, Bot>(m, "MCTSBot")
226-
.def(py::init<const Game&, std::shared_ptr<Evaluator>, double, int,
227-
int64_t, bool, int, bool,
228-
::open_spiel::algorithms::ChildSelectionPolicy>(),
229-
py::arg("game"), py::arg("evaluator"), py::arg("uct_c"),
230-
py::arg("max_simulations"), py::arg("max_memory_mb"),
231-
py::arg("solve"), py::arg("seed"), py::arg("verbose"),
232-
py::arg("child_selection_policy") =
233-
algorithms::ChildSelectionPolicy::UCT)
231+
.def(
232+
py::init([](std::shared_ptr<const Game> game,
233+
std::shared_ptr<Evaluator> evaluator, double uct_c,
234+
int max_simulations, int64_t max_memory_mb, bool solve,
235+
int seed, bool verbose,
236+
algorithms::ChildSelectionPolicy child_selection_policy) {
237+
return new algorithms::MCTSBot(
238+
*game, evaluator, uct_c, max_simulations, max_memory_mb, solve,
239+
seed, verbose, child_selection_policy);
240+
}),
241+
py::arg("game"), py::arg("evaluator"), py::arg("uct_c"),
242+
py::arg("max_simulations"), py::arg("max_memory_mb"),
243+
py::arg("solve"), py::arg("seed"), py::arg("verbose"),
244+
py::arg("child_selection_policy") =
245+
algorithms::ChildSelectionPolicy::UCT)
234246
.def("step", &algorithms::MCTSBot::Step)
235247
.def("mcts_search", &algorithms::MCTSBot::MCTSearch);
236248

@@ -270,10 +282,13 @@ void init_pyspiel_bots(py::module& m) {
270282

271283
m.def("make_stateful_random_bot", open_spiel::MakeStatefulRandomBot,
272284
"A stateful random bot, for test purposes.");
273-
m.def("make_policy_bot",
274-
py::overload_cast<const Game&, Player, int, std::shared_ptr<Policy>>(
275-
open_spiel::MakePolicyBot),
276-
"A bot that samples from a policy.");
285+
m.def(
286+
"make_policy_bot",
287+
[](std::shared_ptr<const Game> game, Player player_id, int seed,
288+
std::shared_ptr<Policy> policy) {
289+
return MakePolicyBot(*game, player_id, seed, policy);
290+
},
291+
"A bot that samples from a policy.");
277292

278293
#if OPEN_SPIEL_BUILD_WITH_ROSHAMBO
279294
m.attr("ROSHAMBO_NUM_THROWS") = py::int_(open_spiel::roshambo::kNumThrows);

open_spiel/python/pybind11/game_transforms.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,35 @@ namespace py = ::pybind11;
2828

2929
void init_pyspiel_game_transforms(py::module& m) {
3030
m.def("load_game_as_turn_based",
31-
py::overload_cast<const std::string&>(&open_spiel::LoadGameAsTurnBased),
31+
py::overload_cast<const std::string&>(&LoadGameAsTurnBased),
3232
"Converts a simultaneous game into an turn-based game with infosets.");
3333

3434
m.def("load_game_as_turn_based",
3535
py::overload_cast<const std::string&, const GameParameters&>(
36-
&open_spiel::LoadGameAsTurnBased),
36+
&LoadGameAsTurnBased),
3737
"Converts a simultaneous game into an turn-based game with infosets.");
3838

39-
m.def("extensive_to_tensor_game", open_spiel::ExtensiveToTensorGame,
39+
m.def("extensive_to_tensor_game", ExtensiveToTensorGame,
4040
"Converts an extensive-game to its equivalent tensor game, "
4141
"which is exponentially larger. Use only with small games.");
4242

43-
m.def("convert_to_turn_based",
44-
[](const std::shared_ptr<open_spiel::Game>& game) {
45-
return open_spiel::ConvertToTurnBased(*game);
46-
},
47-
"Returns a turn-based version of the given game.");
43+
m.def(
44+
"convert_to_turn_based",
45+
[](std::shared_ptr<const Game> game) {
46+
return ConvertToTurnBased(*game);
47+
},
48+
"Returns a turn-based version of the given game.");
4849

49-
m.def("create_repeated_game",
50-
py::overload_cast<const Game&, const GameParameters&>(
51-
&open_spiel::CreateRepeatedGame),
52-
"Creates a repeated game from a stage game.");
50+
m.def(
51+
"create_repeated_game",
52+
[](std::shared_ptr<const Game> game, const GameParameters& params) {
53+
return CreateRepeatedGame(*game, params);
54+
},
55+
"Creates a repeated game from a stage game.");
5356

5457
m.def("create_repeated_game",
5558
py::overload_cast<const std::string&, const GameParameters&>(
56-
&open_spiel::CreateRepeatedGame),
59+
&CreateRepeatedGame),
5760
"Creates a repeated game from a stage game.");
5861
}
5962
} // namespace open_spiel

open_spiel/python/pybind11/observer.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ void init_pyspiel_observer(py::module& m) {
5858
// C++ Observation, intended only for the Python Observation class, not
5959
// for general Python code.
6060
py::class_<Observation>(m, "_Observation", py::buffer_protocol())
61-
.def(py::init<const Game&, std::shared_ptr<Observer>>(), py::arg("game"),
62-
py::arg("observer"))
61+
.def(py::init([](std::shared_ptr<const Game> game,
62+
std::shared_ptr<Observer> observer) {
63+
return new Observation(*game, observer);
64+
}),
65+
py::arg("game"), py::arg("observer"))
6366
.def("tensors", &Observation::tensors)
6467
.def("tensors_info", &Observation::tensors_info)
6568
.def("string_from", &Observation::StringFrom)

open_spiel/python/pybind11/policy.cc

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ void init_pyspiel_policy(py::module& m) {
131131
&open_spiel::PreferredActionPolicy::GetStatePolicy);
132132

133133
py::class_<open_spiel::algorithms::CFRSolver>(m, "CFRSolver")
134-
.def(py::init<const Game&>())
134+
.def(py::init([](std::shared_ptr<const Game> game) {
135+
return new algorithms::CFRSolver(*game);
136+
}))
135137
.def("evaluate_and_update_policy",
136138
&open_spiel::algorithms::CFRSolver::EvaluateAndUpdatePolicy)
137139
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -147,7 +149,9 @@ void init_pyspiel_policy(py::module& m) {
147149
}));
148150

149151
py::class_<open_spiel::algorithms::CFRPlusSolver>(m, "CFRPlusSolver")
150-
.def(py::init<const Game&>())
152+
.def(py::init([](std::shared_ptr<const Game> game) {
153+
return new algorithms::CFRPlusSolver(*game);
154+
}))
151155
.def("evaluate_and_update_policy",
152156
&open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
153157
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -163,7 +167,9 @@ void init_pyspiel_policy(py::module& m) {
163167
}));
164168

165169
py::class_<open_spiel::algorithms::CFRBRSolver>(m, "CFRBRSolver")
166-
.def(py::init<const Game&>())
170+
.def(py::init([](std::shared_ptr<const Game> game) {
171+
return new algorithms::CFRBRSolver(*game);
172+
}))
167173
.def("evaluate_and_update_policy",
168174
&open_spiel::algorithms::CFRPlusSolver::EvaluateAndUpdatePolicy)
169175
.def("current_policy", &open_spiel::algorithms::CFRSolver::CurrentPolicy)
@@ -184,7 +190,11 @@ void init_pyspiel_policy(py::module& m) {
184190

185191
py::class_<open_spiel::algorithms::ExternalSamplingMCCFRSolver>(
186192
m, "ExternalSamplingMCCFRSolver")
187-
.def(py::init<const Game&, int, open_spiel::algorithms::AverageType>(),
193+
.def(py::init([](std::shared_ptr<const Game> game, int seed,
194+
algorithms::AverageType average_type) {
195+
return new algorithms::ExternalSamplingMCCFRSolver(*game, seed,
196+
average_type);
197+
}),
188198
py::arg("game"), py::arg("seed") = 0,
189199
py::arg("avg_type") = open_spiel::algorithms::AverageType::kSimple)
190200
.def("run_iteration",
@@ -204,7 +214,12 @@ void init_pyspiel_policy(py::module& m) {
204214

205215
py::class_<open_spiel::algorithms::OutcomeSamplingMCCFRSolver>(
206216
m, "OutcomeSamplingMCCFRSolver")
207-
.def(py::init<const Game&, double, int>(), py::arg("game"),
217+
.def(py::init(
218+
[](std::shared_ptr<const Game> game, double epsilon, int seed) {
219+
return new algorithms::OutcomeSamplingMCCFRSolver(
220+
*game, epsilon, seed);
221+
}),
222+
py::arg("game"),
208223
py::arg("epsilon") = open_spiel::algorithms::
209224
OutcomeSamplingMCCFRSolver::kDefaultEpsilon,
210225
py::arg("seed") = -1)
@@ -267,45 +282,54 @@ void init_pyspiel_policy(py::module& m) {
267282
py::arg("use_infostate_get_policy"),
268283
py::arg("prob_cut_threshold") = 0.0);
269284

270-
m.def("exploitability",
271-
py::overload_cast<const Game&, const Policy&>(&Exploitability),
272-
"Returns the sum of the utility that a best responder wins when when "
273-
"playing against 1) the player 0 policy contained in `policy` and 2) "
274-
"the player 1 policy contained in `policy`."
275-
"This only works for two player, zero- or constant-sum sequential "
276-
"games, and raises a SpielFatalError if an incompatible game is passed "
277-
"to it.");
285+
m.def(
286+
"exploitability",
287+
[](std::shared_ptr<const Game> game, const Policy& policy) {
288+
return Exploitability(*game, policy);
289+
},
290+
"Returns the sum of the utility that a best responder wins when when "
291+
"playing against 1) the player 0 policy contained in `policy` and 2) "
292+
"the player 1 policy contained in `policy`."
293+
"This only works for two player, zero- or constant-sum sequential "
294+
"games, and raises a SpielFatalError if an incompatible game is passed "
295+
"to it.");
278296

279297
m.def(
280298
"exploitability",
281-
py::overload_cast<
282-
const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
283-
&Exploitability),
299+
[](std::shared_ptr<const Game> game,
300+
const std::unordered_map<std::string, ActionsAndProbs>& policy) {
301+
return Exploitability(*game, policy);
302+
},
284303
"Returns the sum of the utility that a best responder wins when when "
285304
"playing against 1) the player 0 policy contained in `policy` and 2) "
286305
"the player 1 policy contained in `policy`."
287306
"This only works for two player, zero- or constant-sum sequential "
288307
"games, and raises a SpielFatalError if an incompatible game is passed "
289308
"to it.");
290309

291-
m.def("nash_conv",
292-
py::overload_cast<const Game&, const Policy&, bool>(&NashConv),
293-
"Calculates a measure of how far the given policy is from a Nash "
294-
"equilibrium by returning the sum of the improvements in the value "
295-
"that each player could obtain by unilaterally changing their strategy "
296-
"while the opposing player maintains their current strategy (which "
297-
"for a Nash equilibrium, this value is 0). The third parameter is to "
298-
"indicate whether to use the Policy::GetStatePolicy(const State&) "
299-
"instead of Policy::GetStatePolicy(const std::string& info_state) for "
300-
"computation of the on-policy expected values.",
301-
py::arg("game"), py::arg("policy"),
302-
py::arg("use_state_get_policy") = false);
310+
m.def(
311+
"nash_conv",
312+
[](std::shared_ptr<const Game> game, const Policy& policy,
313+
bool use_state_get_policy) {
314+
return NashConv(*game, policy, use_state_get_policy);
315+
},
316+
"Calculates a measure of how far the given policy is from a Nash "
317+
"equilibrium by returning the sum of the improvements in the value "
318+
"that each player could obtain by unilaterally changing their strategy "
319+
"while the opposing player maintains their current strategy (which "
320+
"for a Nash equilibrium, this value is 0). The third parameter is to "
321+
"indicate whether to use the Policy::GetStatePolicy(const State&) "
322+
"instead of Policy::GetStatePolicy(const std::string& info_state) for "
323+
"computation of the on-policy expected values.",
324+
py::arg("game"), py::arg("policy"),
325+
py::arg("use_state_get_policy") = false);
303326

304327
m.def(
305328
"nash_conv",
306-
py::overload_cast<
307-
const Game&, const std::unordered_map<std::string, ActionsAndProbs>&>(
308-
&NashConv),
329+
[](std::shared_ptr<const Game> game,
330+
const std::unordered_map<std::string, ActionsAndProbs>& policy) {
331+
return NashConv(*game, policy);
332+
},
309333
"Calculates a measure of how far the given policy is from a Nash "
310334
"equilibrium by returning the sum of the improvements in the value "
311335
"that each player could obtain by unilaterally changing their strategy "

0 commit comments

Comments
 (0)