diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | diplomacy/maps/convoy_paths_cache.pkl | bin | 0 -> 1383835 bytes | |||
-rw-r--r-- | diplomacy/maps/tests/test_map_gen.py | 19 | ||||
-rw-r--r-- | diplomacy/utils/convoy_paths.py | 132 |
4 files changed, 109 insertions, 43 deletions
@@ -103,6 +103,7 @@ diplomacy/games # Pickle *.p *.pkl +!diplomacy/maps/convoy_paths_cache.pkl # Outputs out* diff --git a/diplomacy/maps/convoy_paths_cache.pkl b/diplomacy/maps/convoy_paths_cache.pkl Binary files differnew file mode 100644 index 0000000..d597dfc --- /dev/null +++ b/diplomacy/maps/convoy_paths_cache.pkl diff --git a/diplomacy/maps/tests/test_map_gen.py b/diplomacy/maps/tests/test_map_gen.py index 48084fb..c231d23 100644 --- a/diplomacy/maps/tests/test_map_gen.py +++ b/diplomacy/maps/tests/test_map_gen.py @@ -19,9 +19,11 @@ """ import glob import os +import pickle import sys from diplomacy.engine.map import Map +from diplomacy.utils.convoy_paths import INTERNAL_CACHE_PATH, get_file_md5 MODULE_PATH = sys.modules['diplomacy'].__path__[0] @@ -43,3 +45,20 @@ def test_map_with_full_path(): this_map = Map(current_map) assert this_map.error == [], 'Map %s should have no errors' % current_map del this_map + +def test_internal_cache(): + """ Tests that all maps with a SVG are in the internal cache """ + maps = glob.glob(os.path.join(MODULE_PATH, 'maps', '*.map')) + assert maps, 'Expected maps to be found.' + assert os.path.exists(INTERNAL_CACHE_PATH), 'Expected internal cache to exist' + + # Checking that maps with a svg are in the internal cache + with open(INTERNAL_CACHE_PATH, 'rb') as cache_file: + internal_cache = pickle.load(cache_file) + for current_map in maps: + map_name = current_map[current_map.rfind('/') + 1:].replace('.map', '') + this_map = Map(map_name) + if not this_map.svg_path: + continue + assert get_file_md5(current_map) in internal_cache, 'Map "%s" not found in internal cache' % map_name + del this_map diff --git a/diplomacy/utils/convoy_paths.py b/diplomacy/utils/convoy_paths.py index 7c21863..1a1deb3 100644 --- a/diplomacy/utils/convoy_paths.py +++ b/diplomacy/utils/convoy_paths.py @@ -35,28 +35,28 @@ if HOME_DIRECTORY == '~': raise RuntimeError('Cannot find home directory. Unable to save cache') # Constants -__VERSION__ = '20180307_0955' +__VERSION__ = '20180913_0915' +COAST_TYPES = ('COAST', 'PORT') +WATER_TYPES = ('WATER', 'PORT') +MAX_CONVOY_LENGTH = 13 # Convoys over this length are not supported, too reduce generation time -# We need to cap convoy length, otherwise the problem gets exponential -SMALL_MAPS = ['standard', 'standard_france_austria', 'standard_germany_italy', 'ancmed', 'colonial', 'modern', 'pure', - 'ancmed_age_of_empires', 'standard_age_of_empires', 'standard_age_of_empires_2', 'standard_fleet_rome'] -SMALL_MAPS_CONVOY_LENGTH = 25 -ALL_MAPS_CONVOY_LENGTH = 12 CACHE_FILE_NAME = 'convoy_paths_cache.pkl' -DISK_CACHE_PATH = os.path.join(HOME_DIRECTORY, '.cache', 'diplomacy', CACHE_FILE_NAME) +INTERNAL_CACHE_PATH = os.path.join(settings.PACKAGE_DIR, 'maps', CACHE_FILE_NAME) +EXTERNAL_CACHE_PATH = os.path.join(HOME_DIRECTORY, '.cache', 'diplomacy', CACHE_FILE_NAME) -def display_progress_bar(queue, max_loop_iters): +def _display_progress_bar(queue, max_loop_iters): """ Displays a progress bar :param queue: Multiprocessing queue to display the progress bar :param max_loop_iters: The expected maximum number of iterations """ progress_bar = tqdm.tqdm(total=max_loop_iters) - for _ in iter(queue.get, None): - progress_bar.update() + for item in iter(queue.get, None): # type: int + for _ in range(item): + progress_bar.update() progress_bar.close() -def get_convoy_paths(map_object, start_location, max_convoy_length, queue): +def _get_convoy_paths(map_object, start_location, max_convoy_length, queue): """ Returns a list of possible convoy destinations with the required units to get there Does a breadth first search from the starting location @@ -70,24 +70,35 @@ def get_convoy_paths(map_object, start_location, max_convoy_length, queue): to_check = Queue() # Items in queue have format ({fleets location}, last fleet location) dest_paths = {} # Dict with dest as key and a list of all paths from start_location to dest as value + # To measure progress + last_completed_path_length = 0 + nb_water_locs = len([loc.upper() for loc in map_object.locs if map_object.area_type(loc) in WATER_TYPES]) + # We need to start on a coast / port - if map_object.area_type(start_location) not in ('COAST', 'PORT') or '/' in start_location: + if map_object.area_type(start_location) not in COAST_TYPES or '/' in start_location: + queue.put(nb_water_locs) return [] # Queuing all adjacent water locations from start for loc in [loc.upper() for loc in map_object.abut_list(start_location, incl_no_coast=True)]: - if map_object.area_type(loc) in ['WATER', 'PORT']: + if map_object.area_type(loc) in WATER_TYPES: to_check.put(({loc}, loc)) # Checking all subsequent adjacencies until no more adjacencies are possible while not to_check.empty(): fleets_loc, last_loc = to_check.get() + new_completed_path_length = len(fleets_loc) - 1 + + # Marking path length as completed + if new_completed_path_length > last_completed_path_length: + queue.put(new_completed_path_length - last_completed_path_length) + last_completed_path_length = new_completed_path_length # Checking adjacencies for loc in [loc.upper() for loc in map_object.abut_list(last_loc, incl_no_coast=True)]: # If we find adjacent coasts, we mark them as a possible result - if map_object.area_type(loc) in ('COAST', 'PORT') and '/' not in loc and loc != start_location: + if map_object.area_type(loc) in COAST_TYPES and '/' not in loc and loc != start_location: dest_paths.setdefault(loc, []) # If we already have a working path that is a subset of the current fleets, we can skip @@ -99,7 +110,7 @@ def get_convoy_paths(map_object, start_location, max_convoy_length, queue): dest_paths[loc] += [fleets_loc] # If we find adjacent water/port, we add them to the queue - elif map_object.area_type(loc) in ('WATER', 'PORT') \ + elif map_object.area_type(loc) in WATER_TYPES \ and loc not in fleets_loc \ and len(fleets_loc) < max_convoy_length: to_check.put((fleets_loc | {loc}, loc)) @@ -117,11 +128,14 @@ def get_convoy_paths(map_object, start_location, max_convoy_length, queue): for fleets, dests in similar_paths.items(): results += [(start_location, set(fleets), dests)] + # Marking as done + if nb_water_locs > last_completed_path_length: + queue.put(nb_water_locs - last_completed_path_length) + # Returning - queue.put(1) return results -def build_convoy_paths_cache(map_object, max_convoy_length): +def _build_convoy_paths_cache(map_object, max_convoy_length): """ Builds the convoy paths cache for a map :param map_object: The instantiated map object @@ -130,20 +144,22 @@ def build_convoy_paths_cache(map_object, max_convoy_length): the value is a list of convoy paths (start loc, {fleets}, {dest}) of that length for the map :type map_object: diplomacy.Map """ - print('Generating convoy paths for {}'.format(map_object.name)) - coasts = [loc.upper() for loc in map_object.locs - if map_object.area_type(loc) in ('COAST', 'PORT') if '/' not in loc] + print('Generating convoy paths for "{}"'.format(map_object.name)) + print('This is an operation that is required the first time a map is loaded. It might take several minutes...\n') + coasts = [loc.upper() for loc in map_object.locs if map_object.area_type(loc) in COAST_TYPES and '/' not in loc] + water_locs = [loc.upper() for loc in map_object.locs if map_object.area_type(loc) in WATER_TYPES] # Starts the progress bar loop manager = multiprocessing.Manager() queue = manager.Queue() - progress_bar = threading.Thread(target=display_progress_bar, args=(queue, len(coasts))) + progress_bar = threading.Thread(target=_display_progress_bar, args=(queue, len(coasts) * len(water_locs))) progress_bar.start() - # Getting all paths for each coasts in parallel - pool = multiprocessing.Pool(multiprocessing.cpu_count()) + # Getting all paths for each coasts in parallel (except if the map is large, to avoid high memory usage) + nb_cores = multiprocessing.cpu_count() if (len(water_locs) <= 30 or max_convoy_length <= MAX_CONVOY_LENGTH) else 1 + pool = multiprocessing.Pool(nb_cores) tasks = [(map_object, coast, max_convoy_length, queue) for coast in coasts] - results = pool.starmap(get_convoy_paths, tasks) + results = pool.starmap(_get_convoy_paths, tasks) pool.close() results = [item for sublist in results for item in sublist] queue.put(None) @@ -170,24 +186,34 @@ def get_file_md5(file_path): hash_md5.update(chunk) return hash_md5.hexdigest() -def add_to_cache(map_name): +def add_to_cache(map_name, max_convoy_length=MAX_CONVOY_LENGTH): """ Lazy generates convoys paths for a map and adds it to the disk cache :param map_name: The name of the map + :param max_convoy_length: The maximum convoy length permitted :return: The convoy_paths for that map """ - disk_convoy_paths = {'__version__': __VERSION__} # Uses hash as key + convoy_paths = {'__version__': __VERSION__} # Uses hash as key + external_convoy_paths = {'__version__': __VERSION__} # Uses hash as key + + # Loading from internal cache first + if os.path.exists(INTERNAL_CACHE_PATH): + try: + cache_data = pickle.load(open(INTERNAL_CACHE_PATH, 'rb')) + if cache_data.get('__version__', '') == __VERSION__: + convoy_paths.update(cache_data) + except (pickle.UnpicklingError, EOFError): + pass - # Loading cache from disk (only if it's the correct version) - if os.path.exists(DISK_CACHE_PATH): + # Loading external cache + if os.path.exists(EXTERNAL_CACHE_PATH): try: - cache_data = pickle.load(open(DISK_CACHE_PATH, 'rb')) + cache_data = pickle.load(open(EXTERNAL_CACHE_PATH, 'rb')) if cache_data.get('__version__', '') != __VERSION__: print('Upgrading cache from "%s" to "%s"' % (cache_data.get('__version__', '<N/A>'), __VERSION__)) else: - disk_convoy_paths.update(cache_data) - - # Invalid pickle file - Rebuilding + convoy_paths.update(cache_data) + external_convoy_paths.update(cache_data) except (pickle.UnpicklingError, EOFError): pass @@ -200,28 +226,35 @@ def add_to_cache(map_name): return None map_hash = get_file_md5(map_path) - # Determining the depth of the search (small maps can have larger depth) - max_convoy_length = SMALL_MAPS_CONVOY_LENGTH if map_name in SMALL_MAPS else ALL_MAPS_CONVOY_LENGTH - # Generating and adding to alternate cache paths - if map_hash not in disk_convoy_paths: + if map_hash not in convoy_paths: map_object = Map(map_name, use_cache=False) - disk_convoy_paths[map_hash] = build_convoy_paths_cache(map_object, max_convoy_length) - os.makedirs(os.path.dirname(DISK_CACHE_PATH), exist_ok=True) - pickle.dump(disk_convoy_paths, open(DISK_CACHE_PATH, 'wb')) + convoy_paths[map_hash] = _build_convoy_paths_cache(map_object, max_convoy_length) + external_convoy_paths[map_hash] = convoy_paths[map_hash] + os.makedirs(os.path.dirname(EXTERNAL_CACHE_PATH), exist_ok=True) + pickle.dump(external_convoy_paths, open(EXTERNAL_CACHE_PATH, 'wb')) # Returning - return disk_convoy_paths[map_hash] + return convoy_paths[map_hash] def get_convoy_paths_cache(): """ Returns the current cache from disk """ disk_convoy_paths = {} # Uses hash as key cache_convoy_paths = {} # Use map name as key - # Loading cache from disk (only if it's the correct version) - if os.path.exists(DISK_CACHE_PATH): + # Loading from internal cache first + if os.path.exists(INTERNAL_CACHE_PATH): try: - cache_data = pickle.load(open(DISK_CACHE_PATH, 'rb')) + cache_data = pickle.load(open(INTERNAL_CACHE_PATH, 'rb')) + if cache_data.get('__version__', '') == __VERSION__: + disk_convoy_paths.update(cache_data) + except (pickle.UnpicklingError, EOFError): + pass + + # Loading external cache + if os.path.exists(EXTERNAL_CACHE_PATH): + try: + cache_data = pickle.load(open(EXTERNAL_CACHE_PATH, 'rb')) if cache_data.get('__version__', '') == __VERSION__: disk_convoy_paths.update(cache_data) except (pickle.UnpicklingError, EOFError): @@ -238,3 +271,16 @@ def get_convoy_paths_cache(): # Returning return cache_convoy_paths + +def rebuild_all_maps(): + """ Rebuilds all the maps in the external cache """ + if os.path.exists(EXTERNAL_CACHE_PATH): + os.remove(EXTERNAL_CACHE_PATH) + + files_path = glob.glob(settings.PACKAGE_DIR + '/maps/*.map') + for file_path in files_path: + map_name = file_path.replace(settings.PACKAGE_DIR + '/maps/', '').replace('.map', '') + map_hash = get_file_md5(file_path) + print('-' * 80) + print('Adding {} (Hash: {}) to cache\n'.format(file_path, map_hash)) + add_to_cache(map_name) |