From a1508abc80278d5004c4bf7ec6b7814a2358dbff Mon Sep 17 00:00:00 2001
From: Setepenre <pierre.delaunay.tr@gmail.com>
Date: Tue, 5 Feb 2019 11:13:38 -0600
Subject: Add _unit_owner caching (#11)

---
 diplomacy/engine/game.py     | 52 ++++++++++++++++++++++++++++++++++----------
 diplomacy/tests/test_game.py | 18 +++++++++++++++
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/diplomacy/engine/game.py b/diplomacy/engine/game.py
index 03a25af..66a4cf9 100644
--- a/diplomacy/engine/game.py
+++ b/diplomacy/engine/game.py
@@ -152,6 +152,11 @@ class Game(Jsonable):
                 e.g. 3
         - zobrist_hash - Contains the zobrist hash representing the current state of this game
                 e.g. 12545212418541325
+
+        ----- Caches ----
+        - _unit_owner_cache - Contains a dictionary with (unit, coast_required) as key and owner as value
+                            - Set to Note when the cache is not built
+                e.g. {('A PAR', True): <FRANCE>, ('A PAR', False): <FRANCE>), ...}
     """
     # pylint: disable=too-many-instance-attributes
     __slots__ = ['victory', 'no_rules', 'meta_rules', 'phase', 'note', 'map', 'powers', 'outcome', 'error', 'popped',
@@ -160,7 +165,7 @@ class Game(Jsonable):
                  'convoy_paths_dest', 'zobrist_hash', 'renderer', 'game_id', 'map_name', 'role', 'rules',
                  'message_history', 'state_history', 'result_history', 'status', 'timestamp_created', 'n_controls',
                  'deadline', 'registration_password', 'observer_level', 'controlled_powers', '_phase_wrapper_type',
-                 'phase_abbr']
+                 'phase_abbr', '_unit_owner_cache']
     zobrist_tables = {}
     rule_cache = ()
     model = {
@@ -228,6 +233,9 @@ class Game(Jsonable):
         self.observer_level = None
         self.controlled_powers = None
 
+        # Caches
+        self._unit_owner_cache = None               # {(unit, coast_required): owner}
+
         # Remove rules from kwargs (if present), as we want to add them manually using self.add_rule().
         rules = kwargs.pop(strings.RULES, None)
 
@@ -1124,6 +1132,7 @@ class Game(Jsonable):
     def clear_cache(self):
         """ Clears all caches """
         self.convoy_paths_possible, self.convoy_paths_dest = None, None
+        self._unit_owner_cache = None
 
     def get_current_phase(self):
         """ Returns the current phase (format 'S1901M' or 'FORMING' or 'COMPLETED' """
@@ -1253,12 +1262,19 @@ class Game(Jsonable):
         self.order_history.put(previous_phase, previous_orders)
         self.message_history.put(previous_phase, previous_messages)
         self.state_history.put(previous_phase, previous_state)
+
         return GamePhaseData(name=str(previous_phase),
                              state=previous_state,
                              orders=previous_orders,
                              messages=previous_messages,
                              results=self.result_history[previous_phase])
 
+    def build_caches(self):
+        """ Rebuilds the various caches """
+        self.clear_cache()
+        self._build_list_possible_convoys()
+        self._build_unit_owner_cache()
+
     def rebuild_hash(self):
         """ Completely recalculate the Zobrist hash
             :return: The updated hash value
@@ -1457,7 +1473,7 @@ class Game(Jsonable):
 
         # Rebuilding hash and returning
         self.rebuild_hash()
-        self._build_list_possible_convoys()
+        self.build_caches()
 
     def get_all_possible_orders(self, loc=None):
         """ Computes a list of all possible orders for a unit in a given location
@@ -2530,15 +2546,19 @@ class Game(Jsonable):
         self._move_to_start_phase()
         self.note = ''
         self.win = self.victory[0]
+
         # Create dummy power objects for non-loaded powers.
         for power_name in self.map.powers:
             if power_name not in self.powers:
                 self.powers[power_name] = Power(self, power_name, role=self.role)
-        # Initialize all powers.
+
+        # Initialize all powers - Starter having type won't be initialized.
         for starter in self.powers.values():
-            # Starter having type won't be initialized.
             starter.initialize(self)
 
+        # Build caches
+        self.build_caches()
+
     def _process(self):
         """ Processes the current phase of the game """
         # Convert all raw movement phase "ORDER"s in a NO_CHECK game to standard orders before calling
@@ -2590,8 +2610,8 @@ class Game(Jsonable):
         else:
             raise Exception("FailedToAdvancePhase")
 
-        # Rebuilding the convoy cache
-        self._build_list_possible_convoys()
+        # Rebuilding the caches
+        self.build_caches()
 
         # Returning
         return []
@@ -3335,6 +3355,18 @@ class Game(Jsonable):
             return 0
         return 1
 
+    def _build_unit_owner_cache(self):
+        """ Builds the unit_owner cache """
+        if self._unit_owner_cache is not None:
+            return
+        self._unit_owner_cache = {}
+        for owner in self.powers.values():
+            for unit in owner.units:
+                self._unit_owner_cache[(unit, True)] = owner                    # (unit, coast_required): owner
+                self._unit_owner_cache[(unit, False)] = owner
+                if '/' in unit:
+                    self._unit_owner_cache[(unit.split('/')[0], False)] = owner
+
     def _unit_owner(self, unit, coast_required=1):
         """ Finds the power who owns a unit
             :param unit: The name of the unit to find (e.g. 'A PAR')
@@ -3345,12 +3377,8 @@ class Game(Jsonable):
         # If coast_required is 0 and unit does not contain a '/'
         # return the owner if we find a unit that starts with unit
         # Don't count the unit if it needs to retreat (i.e. it has been dislodged)
-        for owner in self.powers.values():
-            if unit in owner.units:
-                return owner
-            if not coast_required and '/' not in unit and [1 for x in owner.units if x.find(unit) == 0]:
-                return owner
-        return None
+        self._build_unit_owner_cache()
+        return self._unit_owner_cache.get((unit, bool(coast_required)), None)
 
     def _occupant(self, site, any_coast=0):
         """ Finds the occupant of a site
diff --git a/diplomacy/tests/test_game.py b/diplomacy/tests/test_game.py
index 76ba146..20994c8 100644
--- a/diplomacy/tests/test_game.py
+++ b/diplomacy/tests/test_game.py
@@ -515,6 +515,7 @@ def test_set_current_phase():
     power = game.get_power('FRANCE')
     power.units.remove('A PAR')
     game.set_current_phase('W1901A')
+    game.clear_cache()
     assert game.get_current_phase() == 'W1901A'
     assert game.phase_type == 'A'
     assert 'A PAR B' in game.get_all_possible_orders('PAR')
@@ -648,3 +649,20 @@ def test_result_history():
     phase_data = game.get_phase_from_history(short_phase_name)
     assert 'bounce' in phase_data.results['A PAR']
     assert 'bounce' in phase_data.results['A MAR']
+
+def test_unit_owner():
+    """ Test Unit Owner Resolver making sure the cached results are correct """
+    game = Game()
+    print(game.get_units('RUSSIA'))
+
+    assert game._unit_owner('F STP/SC', coast_required=1) is game.get_power('RUSSIA')                                   # pylint: disable=protected-access
+    assert game._unit_owner('F STP/SC', coast_required=0) is game.get_power('RUSSIA')                                   # pylint: disable=protected-access
+
+    assert game._unit_owner('F STP', coast_required=1) is None                                                          # pylint: disable=protected-access
+    assert game._unit_owner('F STP', coast_required=0) is game.get_power('RUSSIA')                                      # pylint: disable=protected-access
+
+    assert game._unit_owner('A WAR', coast_required=0) is game.get_power('RUSSIA')                                      # pylint: disable=protected-access
+    assert game._unit_owner('A WAR', coast_required=1) is game.get_power('RUSSIA')                                      # pylint: disable=protected-access
+
+    assert game._unit_owner('F SEV', coast_required=0) is game.get_power('RUSSIA')                                      # pylint: disable=protected-access
+    assert game._unit_owner('F SEV', coast_required=1) is game.get_power('RUSSIA')                                      # pylint: disable=protected-access
-- 
cgit v1.2.3