forked from google-deepmind/hanabi-learning-environment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
game_example.cc
153 lines (139 loc) · 5.29 KB
/
game_example.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <unordered_map>
#include "hanabi_game.h"
#include "hanabi_state.h"
// Summary statistics collected from one simulated game.
struct GameResult {
  // Final score of the game.
  int score;
  // Number of successful Play moves.
  int fireworks_played;
  // Number of moves by a player.
  int num_steps;
};
// Command-line prefix that marks an argument as a Hanabi game parameter.
// ParseArguments strips this prefix and treats the remainder as "key=value".
constexpr const char* kGameParamArgPrefix = "--config.hanabi.";
// Plays one complete game of `game`, picking every move at random, and returns
// summary statistics for it.
//
// Chance moves are sampled from the weights reported by
// state.ChanceOutcomes(); player moves are sampled uniformly from
// state.LegalMoves(). When `verbose` is true, the candidate moves, each
// sampled move and the terminal state are written to std::cout.
//
// Args:
//   game: the game to play; its address is handed to the HanabiState created
//     here.
//   verbose: whether to log every decision to std::cout.
//   rng: random number generator used for all sampling; must not be null.
//
// Returns:
//   The final score, the number of player moves whose history entry is marked
//   `scored`, and the number of player moves applied.
GameResult SimulateGame(const hanabi_learning_env::HanabiGame& game,
                        bool verbose, std::mt19937* rng) {
  assert(rng != nullptr);
  hanabi_learning_env::HanabiState state(&game);
  GameResult result = {0, 0, 0};
  while (!state.IsTerminal()) {
    if (state.CurPlayer() == hanabi_learning_env::kChancePlayerId) {
      // All of this could be replaced with state.ApplyRandomChance().
      // Only done this way to demonstrate picking specific chance moves.
      auto chance_outcomes = state.ChanceOutcomes();
      // With no outcomes, discrete_distribution would still yield index 0,
      // which would index past the end of the (empty) move list.
      assert(!chance_outcomes.first.empty());
      std::discrete_distribution<std::mt19937::result_type> dist(
          chance_outcomes.second.begin(), chance_outcomes.second.end());
      auto move = chance_outcomes.first[dist(*rng)];
      if (verbose) {
        std::cout << "Legal chance:";
        for (std::size_t i = 0; i < chance_outcomes.first.size(); ++i) {
          std::cout << " <" << chance_outcomes.first[i].ToString() << ", "
                    << chance_outcomes.second[i] << ">";
        }
        std::cout << "\n";
        std::cout << "Sampled move: " << move.ToString() << "\n\n";
      }
      state.ApplyMove(move);
      continue;
    }

    auto legal_moves = state.LegalMoves(state.CurPlayer());
    // size() is unsigned: an empty list would wrap `size() - 1` around to a
    // huge upper bound and the sampled index would be out of range.
    assert(!legal_moves.empty());
    std::uniform_int_distribution<std::mt19937::result_type> dist(
        0, legal_moves.size() - 1);
    auto move = legal_moves[dist(*rng)];
    if (verbose) {
      std::cout << "Current player: " << state.CurPlayer() << "\n";
      std::cout << state.ToString() << "\n\n";
      std::cout << "Legal moves:";
      for (std::size_t i = 0; i < legal_moves.size(); ++i) {
        std::cout << " " << legal_moves[i].ToString();
      }
      std::cout << "\n";
      std::cout << "Sampled move: " << move.ToString() << "\n\n";
    }
    state.ApplyMove(move);
    ++result.num_steps;
    // The most recent history entry describes the move that was just applied.
    if (state.MoveHistory().back().scored) {
      ++result.fireworks_played;
    }
  }

  if (verbose) {
    std::cout << "Game done, terminal state:\n" << state.ToString() << "\n\n";
    std::cout << "score = " << state.Score() << "\n\n";
  }

  result.score = state.Score();
  return result;
}
void SimulateGames(
const std::unordered_map<std::string, std::string>& game_params,
int num_trials = 1, bool verbose = true) {
std::mt19937 rng;
rng.seed(std::random_device()());
hanabi_learning_env::HanabiGame game(game_params);
auto params = game.Parameters();
std::cout << "Hanabi game created, with parameters:\n";
for (const auto& item : params) {
std::cout << " " << item.first << "=" << item.second << "\n";
}
std::vector<GameResult> results;
results.reserve(num_trials);
for (int trial = 0; trial < num_trials; ++trial) {
results.push_back(SimulateGame(game, verbose, &rng));
}
if (num_trials > 1) {
GameResult avg_score = std::accumulate(
results.begin(), results.end(), GameResult(),
[](const GameResult& lhs, const GameResult& rhs) {
GameResult result = {lhs.score + rhs.score,
lhs.fireworks_played + rhs.fireworks_played,
lhs.num_steps + rhs.num_steps};
return result;
});
std::cout << "Average score: "
<< static_cast<double>(avg_score.score) / results.size()
<< " average number of fireworks played: "
<< static_cast<double>(avg_score.fireworks_played) /
results.size()
<< " average num_steps: "
<< static_cast<double>(avg_score.num_steps) / results.size()
<< "\n";
}
}
// Collects game parameters from the command line.
//
// Every argument of the form "<prefix><key>=<value>", where <prefix> is
// kGameParamArgPrefix, stores key -> value in the returned map. A prefixed
// argument without '=' stores the remaining text as the key with an empty
// value. Arguments that lack the prefix, or that consist of the prefix alone,
// are ignored, and argv[0] is never examined. If a key appears more than once,
// the last occurrence wins.
std::unordered_map<std::string, std::string> ParseArguments(int argc,
                                                            char** argv) {
  const std::size_t prefix_len = std::strlen(kGameParamArgPrefix);
  std::unordered_map<std::string, std::string> game_params;

  for (int arg_index = 1; arg_index < argc; ++arg_index) {
    const std::string arg = argv[arg_index];
    const bool has_prefix =
        arg.compare(0, prefix_len, kGameParamArgPrefix) == 0;
    if (!has_prefix || arg.size() <= prefix_len) {
      continue;
    }

    // Everything after the prefix is either "key=value" or a bare "key".
    const std::string assignment = arg.substr(prefix_len);
    const std::string::size_type equals_pos = assignment.find('=');
    if (equals_pos == std::string::npos) {
      game_params[assignment] = std::string();
    } else {
      game_params[assignment.substr(0, equals_pos)] =
          assignment.substr(equals_pos + 1);
    }
  }

  return game_params;
}
// Entry point: reads game parameters from the command line and runs
// SimulateGames with its default settings (one trial, verbose output).
int main(int argc, char** argv) {
  SimulateGames(ParseArguments(argc, argv));
  return 0;
}