aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSetepenre <pierre.delaunay.tr@gmail.com>2019-02-05 11:13:38 -0600
committerSetepenre <pierre.delaunay.tr@gmail.com>2019-04-18 11:23:06 -0400
commita1508abc80278d5004c4bf7ec6b7814a2358dbff (patch)
treed00804bd7169fe7f91a3d9005712460ccef91b50
parent0c75691479ddfde6db7f80432a1f38cfcc051eb6 (diff)
Add _unit_owner caching (#11)
-rw-r--r--diplomacy/engine/game.py52
-rw-r--r--diplomacy/tests/test_game.py18
2 files changed, 58 insertions, 12 deletions
diff --git a/diplomacy/engine/game.py b/diplomacy/engine/game.py
index 03a25af..66a4cf9 100644
--- a/diplomacy/engine/game.py
+++ b/diplomacy/engine/game.py
@@ -152,6 +152,11 @@ class Game(Jsonable):
e.g. 3
- zobrist_hash - Contains the zobrist hash representing the current state of this game
e.g. 12545212418541325
+
+ ----- Caches ----
+ - _unit_owner_cache - Contains a dictionary with (unit, coast_required) as key and owner as value
+ - Set to Note when the cache is not built
+ e.g. {('A PAR', True): <FRANCE>, ('A PAR', False): <FRANCE>), ...}
"""
# pylint: disable=too-many-instance-attributes
__slots__ = ['victory', 'no_rules', 'meta_rules', 'phase', 'note', 'map', 'powers', 'outcome', 'error', 'popped',
@@ -160,7 +165,7 @@ class Game(Jsonable):
'convoy_paths_dest', 'zobrist_hash', 'renderer', 'game_id', 'map_name', 'role', 'rules',
'message_history', 'state_history', 'result_history', 'status', 'timestamp_created', 'n_controls',
'deadline', 'registration_password', 'observer_level', 'controlled_powers', '_phase_wrapper_type',
- 'phase_abbr']
+ 'phase_abbr', '_unit_owner_cache']
zobrist_tables = {}
rule_cache = ()
model = {
@@ -228,6 +233,9 @@ class Game(Jsonable):
self.observer_level = None
self.controlled_powers = None
+ # Caches
+ self._unit_owner_cache = None # {(unit, coast_required): owner}
+
# Remove rules from kwargs (if present), as we want to add them manually using self.add_rule().
rules = kwargs.pop(strings.RULES, None)
@@ -1124,6 +1132,7 @@ class Game(Jsonable):
def clear_cache(self):
""" Clears all caches """
self.convoy_paths_possible, self.convoy_paths_dest = None, None
+ self._unit_owner_cache = None
def get_current_phase(self):
""" Returns the current phase (format 'S1901M' or 'FORMING' or 'COMPLETED' """
@@ -1253,12 +1262,19 @@ class Game(Jsonable):
self.order_history.put(previous_phase, previous_orders)
self.message_history.put(previous_phase, previous_messages)
self.state_history.put(previous_phase, previous_state)
+
return GamePhaseData(name=str(previous_phase),
state=previous_state,
orders=previous_orders,
messages=previous_messages,
results=self.result_history[previous_phase])
+ def build_caches(self):
+ """ Rebuilds the various caches """
+ self.clear_cache()
+ self._build_list_possible_convoys()
+ self._build_unit_owner_cache()
+
def rebuild_hash(self):
""" Completely recalculate the Zobrist hash
:return: The updated hash value
@@ -1457,7 +1473,7 @@ class Game(Jsonable):
# Rebuilding hash and returning
self.rebuild_hash()
- self._build_list_possible_convoys()
+ self.build_caches()
def get_all_possible_orders(self, loc=None):
""" Computes a list of all possible orders for a unit in a given location
@@ -2530,15 +2546,19 @@ class Game(Jsonable):
self._move_to_start_phase()
self.note = ''
self.win = self.victory[0]
+
# Create dummy power objects for non-loaded powers.
for power_name in self.map.powers:
if power_name not in self.powers:
self.powers[power_name] = Power(self, power_name, role=self.role)
- # Initialize all powers.
+
+ # Initialize all powers - Starter having type won't be initialized.
for starter in self.powers.values():
- # Starter having type won't be initialized.
starter.initialize(self)
+ # Build caches
+ self.build_caches()
+
def _process(self):
""" Processes the current phase of the game """
# Convert all raw movement phase "ORDER"s in a NO_CHECK game to standard orders before calling
@@ -2590,8 +2610,8 @@ class Game(Jsonable):
else:
raise Exception("FailedToAdvancePhase")
- # Rebuilding the convoy cache
- self._build_list_possible_convoys()
+ # Rebuilding the caches
+ self.build_caches()
# Returning
return []
@@ -3335,6 +3355,18 @@ class Game(Jsonable):
return 0
return 1
+ def _build_unit_owner_cache(self):
+ """ Builds the unit_owner cache """
+ if self._unit_owner_cache is not None:
+ return
+ self._unit_owner_cache = {}
+ for owner in self.powers.values():
+ for unit in owner.units:
+ self._unit_owner_cache[(unit, True)] = owner # (unit, coast_required): owner
+ self._unit_owner_cache[(unit, False)] = owner
+ if '/' in unit:
+ self._unit_owner_cache[(unit.split('/')[0], False)] = owner
+
def _unit_owner(self, unit, coast_required=1):
""" Finds the power who owns a unit
:param unit: The name of the unit to find (e.g. 'A PAR')
@@ -3345,12 +3377,8 @@ class Game(Jsonable):
# If coast_required is 0 and unit does not contain a '/'
# return the owner if we find a unit that starts with unit
# Don't count the unit if it needs to retreat (i.e. it has been dislodged)
- for owner in self.powers.values():
- if unit in owner.units:
- return owner
- if not coast_required and '/' not in unit and [1 for x in owner.units if x.find(unit) == 0]:
- return owner
- return None
+ self._build_unit_owner_cache()
+ return self._unit_owner_cache.get((unit, bool(coast_required)), None)
def _occupant(self, site, any_coast=0):
""" Finds the occupant of a site
diff --git a/diplomacy/tests/test_game.py b/diplomacy/tests/test_game.py
index 76ba146..20994c8 100644
--- a/diplomacy/tests/test_game.py
+++ b/diplomacy/tests/test_game.py
@@ -515,6 +515,7 @@ def test_set_current_phase():
power = game.get_power('FRANCE')
power.units.remove('A PAR')
game.set_current_phase('W1901A')
+ game.clear_cache()
assert game.get_current_phase() == 'W1901A'
assert game.phase_type == 'A'
assert 'A PAR B' in game.get_all_possible_orders('PAR')
@@ -648,3 +649,20 @@ def test_result_history():
phase_data = game.get_phase_from_history(short_phase_name)
assert 'bounce' in phase_data.results['A PAR']
assert 'bounce' in phase_data.results['A MAR']
+
+def test_unit_owner():
+ """ Test Unit Owner Resolver making sure the cached results are correct """
+ game = Game()
+ print(game.get_units('RUSSIA'))
+
+ assert game._unit_owner('F STP/SC', coast_required=1) is game.get_power('RUSSIA') # pylint: disable=protected-access
+ assert game._unit_owner('F STP/SC', coast_required=0) is game.get_power('RUSSIA') # pylint: disable=protected-access
+
+ assert game._unit_owner('F STP', coast_required=1) is None # pylint: disable=protected-access
+ assert game._unit_owner('F STP', coast_required=0) is game.get_power('RUSSIA') # pylint: disable=protected-access
+
+ assert game._unit_owner('A WAR', coast_required=0) is game.get_power('RUSSIA') # pylint: disable=protected-access
+ assert game._unit_owner('A WAR', coast_required=1) is game.get_power('RUSSIA') # pylint: disable=protected-access
+
+ assert game._unit_owner('F SEV', coast_required=0) is game.get_power('RUSSIA') # pylint: disable=protected-access
+ assert game._unit_owner('F SEV', coast_required=1) is game.get_power('RUSSIA') # pylint: disable=protected-access