aboutsummaryrefslogtreecommitdiff
path: root/diplomacy/utils/export.py
blob: 183386b2b72e3b834ba6e53801afb8f7ae61a81a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# ==============================================================================
# Copyright (C) 2019 - Philip Paquette, Steven Bocco
#
#  This program is free software: you can redistribute it and/or modify it under
#  the terms of the GNU Affero General Public License as published by the Free
#  Software Foundation, either version 3 of the License, or (at your option) any
#  later version.
#
#  This program is distributed in the hope that it will be useful, but WITHOUT
#  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more
#  details.
#
#  You should have received a copy of the GNU Affero General Public License along
#  with this program.  If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
""" Exporter
    - Responsible for exporting games in a standardized format to disk
"""
import logging
import os
import ujson as json
from diplomacy.engine.game import Game
from diplomacy.engine.map import Map
from diplomacy.utils import strings
from diplomacy.utils.game_phase_data import GamePhaseData

# Constants
LOGGER = logging.getLogger(__name__)
RULES_TO_SKIP = ['SOLITAIRE', 'NO_DEADLINE', 'CD_DUMMIES', 'ALWAYS_WAIT', 'IGNORE_ERRORS']

def to_saved_game_format(game, output_path=None, output_mode='a'):
    """ Converts a game to a standardized JSON format

        :param game: game to convert.
        :param output_path: Optional path to file. If set, the json.dumps() of the saved_game is written to that file.
        :param output_mode: Optional. The mode to use to write to the output_path (if provided). Defaults to 'a'
        :return: A game in the standard format used to saved game, that can be converted to JSON for serialization
        :type game: diplomacy.engine.game.Game
        :type output_path: str | None, optional
        :type output_mode: str, optional
        :rtype: Dict
    """
    phases = Game.get_phase_history(game)                                       # Get phase history.
    phases.append(Game.get_phase_data(game))                                    # Add current game phase.
    rules = [rule for rule in game.rules if rule not in RULES_TO_SKIP]          # Filter rules.

    # Extend states fields.
    phases_to_dict = [phase.to_dict() for phase in phases]
    for phase_dct in phases_to_dict:
        phase_dct['state']['game_id'] = game.game_id
        phase_dct['state']['map'] = game.map_name
        phase_dct['state']['rules'] = rules

    # Building saved game
    saved_game = {'id': game.game_id,
                  'map': game.map_name,
                  'rules': rules,
                  'phases': phases_to_dict}

    # Writing to disk
    if output_path:
        with open(output_path, output_mode) as output_file:
            output_file.write(json.dumps(saved_game) + '\n')

    # Returning
    return saved_game

def from_saved_game_format(saved_game):
    """ Rebuilds a :class:`diplomacy.engine.game.Game` object from the saved game (python :class:`Dict`)
        saved_game is the dictionary. It can be built by calling json.loads(json_line).

        :param saved_game: The saved game exported from :meth:`to_saved_game_format`
        :type saved_game: Dict
        :rtype: diplomacy.engine.game.Game
        :return: The game object restored from the saved game
    """
    game_id = saved_game.get('id', None)
    kwargs = {strings.MAP_NAME: saved_game.get('map', 'standard'),
              strings.RULES: saved_game.get('rules', [])}

    # Building game
    game = Game(game_id=game_id, **kwargs)
    phase_history = []

    # Restoring every phase
    for phase_dct in saved_game.get('phases', []):
        phase_history.append(GamePhaseData.from_dict(phase_dct))
    game.set_phase_data(phase_history, clear_history=True)

    # Returning game
    return game

def load_saved_games_from_disk(input_path, on_error='raise'):
    """ Rebuids multiple :class:`diplomacy.engine.game.Game` from each line in a .jsonl file

        :param input_path: The path to the input file. Expected content is one saved_game json per line.
        :param on_error: Optional. What to do if a game conversion fails. Either 'raise', 'warn', 'ignore'
        :type input_path: str
        :rtype: List[diplomacy.Game]
        :return: A list of :class:`diplomacy.engine.game.Game` objects.
    """
    loaded_games = []
    assert on_error in ('raise', 'warn', 'ignore'), 'Expected values for on_error are "raise", "warn", "ignore".'

    # File does not exist
    if not os.path.exists(input_path):
        LOGGER.warning('File %s does not exist. Aborting.', input_path)
        return loaded_games

    # Importing each game
    with open(input_path, 'r') as file:
        for line in file:
            try:
                saved_game = json.loads(line.rstrip('\n'))
                game = from_saved_game_format(saved_game)
                loaded_games.append(game)
            except Exception as exc:                    # pylint: disable=broad-except
                if on_error == 'raise':
                    raise exc
                if on_error == 'warn':
                    LOGGER.warning(exc)

    # Returning
    return loaded_games

def is_valid_saved_game(saved_game):
    """ Checks if the saved game is valid.
        This is an expensive operation because it replays the game.

        :param saved_game: The saved game (from to_saved_game_format)
        :return: A boolean that indicates if the game is valid
    """
    # pylint: disable=too-many-return-statements, too-many-nested-blocks, too-many-branches
    nb_forced_phases = 0
    max_nb_forced_phases = 1 if 'DIFFERENT_ADJUDICATION' in saved_game.get('rules', []) else 0

    # Validating default fields
    if 'id' not in saved_game or not saved_game['id']:
        return False
    if 'map' not in saved_game:
        return False
    map_object = Map(saved_game['map'])
    if map_object.name != saved_game['map']:
        return False
    if 'rules' not in saved_game:
        return False
    if 'phases' not in saved_game:
        return False

    # Validating each phase
    nb_messages = 0
    nb_phases = len(saved_game['phases'])
    last_time_sent = -1
    for phase_ix in range(nb_phases):
        current_phase = saved_game['phases'][phase_ix]
        state = current_phase['state']
        phase_orders = current_phase['orders']
        previous_phase_name = 'FORMING' if phase_ix == 0 else saved_game['phases'][phase_ix - 1]['name']
        next_phase_name = 'COMPLETED' if phase_ix == nb_phases - 1 else saved_game['phases'][phase_ix + 1]['name']
        power_names = list(state['units'].keys())

        # Validating messages
        for message in saved_game['phases'][phase_ix]['messages']:
            nb_messages += 1
            if map_object.compare_phases(previous_phase_name, message['phase']) >= 0:
                return False
            if map_object.compare_phases(message['phase'], next_phase_name) > 0:
                return False
            if message['sender'] not in power_names + ['SYSTEM']:
                return False
            if message['recipient'] not in power_names + ['GLOBAL']:
                return False
            if message['time_sent'] < last_time_sent:
                return False
            last_time_sent = message['time_sent']

        # Validating phase
        if phase_ix < (nb_phases - 1):
            is_forced_phase = False

            # Setting game state
            game = Game(saved_game['id'], map_name=saved_game['map'], rules=['SOLITAIRE'] + saved_game['rules'])
            game.set_phase_data(GamePhaseData.from_dict(current_phase))

            # Determining what phase we should expect from the dataset.
            next_state = saved_game['phases'][phase_ix + 1]['state']

            # Setting orders
            game.clear_orders()
            for power_name in phase_orders:
                game.set_orders(power_name, phase_orders[power_name])

            # Validating orders
            orders = game.get_orders()
            possible_orders = game.get_all_possible_orders()
            for power_name in orders:
                if sorted(orders[power_name]) != sorted(current_phase['orders'][power_name]):
                    return False
                if 'NO_CHECK' not in game.rules:
                    for order in orders[power_name]:
                        loc = order.split()[1]
                        if order not in possible_orders[loc]:
                            return False

            # Validating resulting state
            game.process()

            # Checking phase name
            if game.get_current_phase() != next_state['name']:
                is_forced_phase = True

            # Checking zobrist hash
            if game.get_hash() != next_state['zobrist_hash']:
                is_forced_phase = True

            # Checking units
            units = game.get_units()
            for power_name in units:
                if sorted(units[power_name]) != sorted(next_state['units'][power_name]):
                    is_forced_phase = True

            # Checking centers
            centers = game.get_centers()
            for power_name in centers:
                if sorted(centers[power_name]) != sorted(next_state['centers'][power_name]):
                    is_forced_phase = True

            # Allowing 1 forced phase if DIFFERENT_ADJUDICATION is in rule
            if is_forced_phase:
                nb_forced_phases += 1
            if nb_forced_phases > max_nb_forced_phases:
                return False

    # Making sure NO_PRESS is not set
    if 'NO_PRESS' in saved_game['rules'] and nb_messages > 0:
        return False

    # The data is valid
    return True