# ==============================================================================
# Copyright (C) 2019 - Philip Paquette, Steven Bocco
#
#  This program is free software: you can redistribute it and/or modify it under
#  the terms of the GNU Affero General Public License as published by the Free
#  Software Foundation, either version 3 of the License, or (at your option) any
#  later version.
#
#  This program is distributed in the hope that it will be useful, but WITHOUT
#  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more
#  details.
#
#  You should have received a copy of the GNU Affero General Public License along
#  with this program.  If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
""" Helper class to provide a dict with sorted keys. """
from diplomacy.utils.common import is_dictionary
from diplomacy.utils.sorted_set import SortedSet


class SortedDict:
    """ Dict with sorted keys.

        Keys are kept in a SortedSet (which defines the iteration order) and values are kept
        in a regular dict indexed by key. Values are type-checked against val_type on insertion.
    """
    __slots__ = ['__val_type', '__keys', '__couples']

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # state it explicitly since this is a mutable container.
    __hash__ = None

    def __init__(self, key_type, val_type, kwargs=None):
        """ Initialize a typed SortedDict.

            :param key_type: expected type for keys.
            :param val_type: expected type for values.
            :param kwargs: (optional) dictionary-like object: initial values for sorted dict.
        """
        self.__val_type = val_type
        self.__keys = SortedSet(key_type)
        self.__couples = {}
        if kwargs is not None:
            assert is_dictionary(kwargs)
            for key, value in kwargs.items():
                self.put(key, value)

    @staticmethod
    def builder(key_type, val_type):
        """ Return a function to build sorted dicts from a dictionary-like object.
            Returned function expects a dictionary parameter (an object with method items()).

            .. code-block:: python

                builder_fn = SortedDict.builder(str, int)
                my_sorted_dict = builder_fn({'a': 1, 'b': 2})

            :param key_type: expected type for keys.
            :param val_type: expected type for values.
            :return: callable
        """
        return lambda dictionary: SortedDict(key_type, val_type, dictionary)

    @property
    def key_type(self):
        """ Get key type. """
        return self.__keys.element_type

    @property
    def val_type(self):
        """ Get value type. """
        return self.__val_type

    def __str__(self):
        return 'SortedDict{%s}' % ', '.join('%s:%s' % (k, self.__couples[k]) for k in self.__keys)

    def __bool__(self):
        return bool(self.__keys)

    def __len__(self):
        return len(self.__keys)

    def __eq__(self, other):
        """ Return True if self and other are equal.
            Note that self and other must also have same key and value types.

            Comparing with an object that is not a SortedDict yields NotImplemented, so that
            ``sorted_dict == anything_else`` evaluates to False instead of raising.
        """
        if not isinstance(other, SortedDict):
            return NotImplemented
        return (self.key_type is other.key_type
                and self.val_type is other.val_type
                and len(self) == len(other)
                and all(key in other and self[key] == other[key] for key in self.__keys))

    def __getitem__(self, key):
        return self.__couples[key]

    def __setitem__(self, key, value):
        self.put(key, value)

    def __delitem__(self, key):
        self.remove(key)

    def __iter__(self):
        return self.__keys.__iter__()

    def __contains__(self, key):
        return key in self.__couples

    def get(self, key, default=None):
        """ Return value associated with key, or default value if key not found. """
        return self.__couples.get(key, default)

    def put(self, key, value):
        """ Add a key with a value to the dict.

            :param key: key to add (or to update if already in dict).
            :param value: value to associate to key. Must be an instance of val_type.
            :raise TypeError: if value is not an instance of val_type.
        """
        if not isinstance(value, self.__val_type):
            raise TypeError('Expected value type %s, got %s' % (self.__val_type, type(value)))
        # Register the key before storing the value, so that both containers stay in sync.
        if key not in self.__keys:
            self.__keys.add(key)
        self.__couples[key] = value

    def remove(self, key):
        """ Pop (remove and return) value associated with given key, or None if key not found. """
        if key in self.__couples:
            self.__keys.remove(key)
        return self.__couples.pop(key, None)

    def first_key(self):
        """ Get the lowest key from the dict. """
        return self.__keys[0]

    def first_value(self):
        """ Get the value associated to lowest key in the dict. """
        return self.__couples[self.__keys[0]]

    def last_key(self):
        """ Get the highest key from the dict. """
        return self.__keys[-1]

    def last_value(self):
        """ Get the value associated to highest key in the dict. """
        return self.__couples[self.__keys[-1]]

    def last_item(self):
        """ Get the item (key-value pair) for the highest key in the dict. """
        return self.__keys[-1], self.__couples[self.__keys[-1]]

    def keys(self):
        """ Get an iterator to the keys in the dict. """
        return iter(self.__keys)

    def values(self):
        """ Get an iterator to the values in the dict. """
        return (self.__couples[k] for k in self.__keys)

    def reversed_values(self):
        """ Get an iterator to the values in the dict in reversed order of keys. """
        return (self.__couples[k] for k in reversed(self.__keys))

    def items(self):
        """ Get an iterator to the items in the dict. """
        return ((k, self.__couples[k]) for k in self.__keys)

    def reversed_items(self):
        """ Get an iterator to the items in the dict in reversed order of keys. """
        return ((k, self.__couples[k]) for k in reversed(self.__keys))

    def sub_keys(self, key_from=None, key_to=None):
        """ Return list of keys between key_from and key_to (both bounds included).
            See sub() doc about key_from and key_to.
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return self.__keys[position_from:(position_to + 1)]

    def sub(self, key_from=None, key_to=None):
        """ Return a list of values associated to keys between key_from and key_to
            (both bounds included).

            - If key_from is None, lowest key in dict is used.
            - If key_to is None, greatest key in dict is used.
            - If key_from is not in dict, lowest key in dict greater than key_from is used.
            - If key_to is not in dict, greatest key in dict less than key_to is used.

            - If dict is empty, return empty list.
            - If no key of the dict lies in [key_from; key_to], return empty list.
            - With keys (None, None) return a copy of all values.
            - With keys (None, key_to), return values from first to the one associated to key_to.
            - With keys (key_from, None), return values from the one associated to key_from
              to the last value.

            :param key_from: start key
            :param key_to: end key
            :return: list: values in closed keys interval [key_from; key_to]
            :raise IndexError: if the bounds are in the wrong order (key_from > key_to).
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return [self.__couples[k] for k in self.__keys[position_from:(position_to + 1)]]

    def remove_sub(self, key_from=None, key_to=None):
        """ Remove values associated to keys between key_from and key_to (both bounds included).

            See sub() doc about key_from and key_to.

            :param key_from: start key
            :param key_to: end key
            :return: nothing
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        keys_to_remove = self.__keys[position_from:(position_to + 1)]
        for key in keys_to_remove:
            self.remove(key)

    def key_from_index(self, index):
        """ Return key matching given position in sorted dict, or None for invalid position. """
        return self.__keys[index] if -len(self.__keys) <= index < len(self.__keys) else None

    def get_previous_key(self, key):
        """ Return greatest key lower than given key, or None if not exists. """
        return self.__keys.get_previous_value(key)

    def get_next_key(self, key):
        """ Return smallest key greater than given key, or None if not exists. """
        return self.__keys.get_next_value(key)

    def _get_keys_interval(self, key_from, key_to):
        """ Get a couple of internal key positions (index of key_from, index of key_to) allowing
            to easily retrieve values in closed interval [index of key_from; index of key_to]
            corresponding to Python slice [index of key_from : (index of key_to + 1)]

            - If dict is empty, return (0, -1), so that python slice [0 : -1 + 1] corresponds
              to empty interval.
            - If key_from is None, lowest key in dict is used.
            - If key_to is None, greatest key in dict is used.
            - If key_from is not in dict, lowest key in dict greater than key_from is used.
            - If key_to is not in dict, greatest key in dict less than key_to is used.
            - If the requested bounds are correctly ordered but no key of the dict lies between
              them, return (0, -1) (empty interval).

            Thus:

            - With keys (None, None), we get interval of all values.
            - With keys (key_from, None), we get interval for values from key_from to the last key.
            - With keys (None, key_to), we get interval for values from the first key to key_to.

            :param key_from: start key
            :param key_to: end key
            :return: (int, int): couple of integers: (index of key_from, index of key_to).
            :raise IndexError: if the bounds are in the wrong order (key_from > key_to).
        """
        if not self:
            return 0, -1

        # Remember what the caller asked for: snapping below may make the bounds cross
        # even though the request itself was correctly ordered.
        requested_from, requested_to = key_from, key_to

        # Snap a missing lower bound up to the next existing key.
        if key_from is not None and key_from not in self.__couples:
            key_from = self.__keys.get_next_value(key_from)
            if key_from is None:
                return 0, -1

        # Snap a missing upper bound down to the previous existing key.
        if key_to is not None and key_to not in self.__couples:
            key_to = self.__keys.get_previous_value(key_to)
            if key_to is None:
                return 0, -1

        # Unspecified bounds default to the extremities of the dict.
        if key_from is None:
            key_from = self.first_key()
        if key_to is None:
            key_to = self.last_key()

        if key_from > key_to:
            # A well-ordered request that falls entirely between two consecutive keys
            # (e.g. keys [1, 5] and request [2; 3]) designates an empty interval.
            if (requested_from is not None
                    and requested_to is not None
                    and requested_from <= requested_to):
                return 0, -1
            raise IndexError('expected key_from <= key_to (%s vs %s)' % (key_from, key_to))

        position_from = self.__keys.index(key_from)
        position_to = self.__keys.index(key_to)
        assert position_from is not None and position_to is not None
        return position_from, position_to

    def clear(self):
        """ Remove all items from dict. """
        self.__couples.clear()
        self.__keys.clear()

    def fill(self, dct):
        """ Add given dict to this sorted dict.

            :param dct: dictionary-like object (with method items()). Ignored if empty or None.
        """
        if dct:
            assert is_dictionary(dct)
            for key, value in dct.items():
                self.put(key, value)

    def copy(self):
        """ Return a copy of this sorted dict (same key and value types, same items). """
        return SortedDict(self.__keys.element_type, self.__val_type, self.__couples)