aboutsummaryrefslogtreecommitdiff
path: root/diplomacy/utils/sorted_dict.py
blob: f5ac6d7d0079525bc39f765fe050f171b846b7fe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# ==============================================================================
# Copyright (C) 2019 - Philip Paquette, Steven Bocco
#
#  This program is free software: you can redistribute it and/or modify it under
#  the terms of the GNU Affero General Public License as published by the Free
#  Software Foundation, either version 3 of the License, or (at your option) any
#  later version.
#
#  This program is distributed in the hope that it will be useful, but WITHOUT
#  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more
#  details.
#
#  You should have received a copy of the GNU Affero General Public License along
#  with this program.  If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
""" Helper class to provide a dict with sorted keys. """
from diplomacy.utils.common import is_dictionary
from diplomacy.utils.sorted_set import SortedSet

class SortedDict:
    """ Dict with sorted keys.

        Keys are stored in a SortedSet and values in a plain dict, so iteration
        (``__iter__``, keys(), values(), items()) follows the SortedSet order, from lowest
        to highest key. Values are type-checked against ``val_type`` on every insertion.
    """
    # __val_type: expected type of values (checked in put()).
    # __keys: SortedSet holding the keys in sorted order.
    # __couples: plain dict mapping each key to its value.
    __slots__ = ['__val_type', '__keys', '__couples']

    def __init__(self, key_type, val_type, kwargs=None):
        """ Initialize a typed SortedDict.

            :param key_type: expected type for keys (passed to the underlying SortedSet).
            :param val_type: expected type for values.
            :param kwargs: (optional) dictionary-like object: initial values for sorted dict.
        """
        self.__val_type = val_type
        self.__keys = SortedSet(key_type)
        self.__couples = {}
        if kwargs is not None:
            assert is_dictionary(kwargs)
            for key, value in kwargs.items():
                self.put(key, value)

    @staticmethod
    def builder(key_type, val_type):
        """ Return a function to build sorted dicts from a dictionary-like object.
            Returned function expects a dictionary parameter (an object with method items()).

            .. code-block:: python

                builder_fn = SortedDict.builder(str, int)
                my_sorted_dict = builder_fn({'a': 1, 'b': 2})

            :param key_type: expected type for keys.
            :param val_type: expected type for values.
            :return: callable
        """
        return lambda dictionary: SortedDict(key_type, val_type, dictionary)

    @property
    def key_type(self):
        """ Get key type. """
        return self.__keys.element_type

    @property
    def val_type(self):
        """ Get value type. """
        return self.__val_type

    def __str__(self):
        return 'SortedDict{%s}' % ', '.join('%s:%s' % (k, self.__couples[k]) for k in self.__keys)

    def __bool__(self):
        return bool(self.__keys)

    def __len__(self):
        return len(self.__keys)

    def __eq__(self, other):
        """ Return True if self and other are equal.
            Note that self and other must also have same key and value types.

            If other is not a SortedDict, return NotImplemented. Python then tries the
            reflected comparison and otherwise falls back to identity, so e.g.
            ``sorted_dict == None`` is False instead of raising.
        """
        if not isinstance(other, SortedDict):
            return NotImplemented
        return (self.key_type is other.key_type
                and self.val_type is other.val_type
                and len(self) == len(other)
                and all(key in other and self[key] == other[key] for key in self.__keys))

    def __getitem__(self, key):
        return self.__couples[key]

    def __setitem__(self, key, value):
        self.put(key, value)

    def __delitem__(self, key):
        self.remove(key)

    def __iter__(self):
        return self.__keys.__iter__()

    def __contains__(self, key):
        return key in self.__couples

    def get(self, key, default=None):
        """ Return value associated with key, or default value if key not found. """
        return self.__couples.get(key, default)

    def put(self, key, value):
        """ Add a key with a value to the dict (replacing any previous value for that key).

            :raise TypeError: if value is not an instance of the dict value type.
        """
        # Check the value first so a rejected value leaves the dict untouched.
        if not isinstance(value, self.__val_type):
            raise TypeError('Expected value type %s, got %s' % (self.__val_type, type(value)))
        if key not in self.__keys:
            self.__keys.add(key)
        self.__couples[key] = value

    def remove(self, key):
        """ Pop (remove and return) value associated with given key, or None if key not found. """
        if key in self.__couples:
            self.__keys.remove(key)
        return self.__couples.pop(key, None)

    def first_key(self):
        """ Get the lowest key from the dict. """
        return self.__keys[0]

    def first_value(self):
        """ Get the value associated to lowest key in the dict. """
        return self.__couples[self.__keys[0]]

    def last_key(self):
        """ Get the highest key from the dict. """
        return self.__keys[-1]

    def last_value(self):
        """ Get the value associated to highest key in the dict. """
        return self.__couples[self.__keys[-1]]

    def last_item(self):
        """ Get the item (key-value pair) for the highest key in the dict. """
        return self.__keys[-1], self.__couples[self.__keys[-1]]

    def keys(self):
        """ Get an iterator to the keys in the dict. """
        return iter(self.__keys)

    def values(self):
        """ Get an iterator to the values in the dict. """
        return (self.__couples[k] for k in self.__keys)

    def reversed_values(self):
        """ Get an iterator to the values in the dict in reversed order of keys. """
        return (self.__couples[k] for k in reversed(self.__keys))

    def items(self):
        """ Get an iterator to the items in the dict. """
        return ((k, self.__couples[k]) for k in self.__keys)

    def reversed_items(self):
        """ Get an iterator to the items in the dict in reversed order of keys. """
        return ((k, self.__couples[k]) for k in reversed(self.__keys))

    def sub_keys(self, key_from=None, key_to=None):
        """ Return list of keys between key_from and key_to (both bounds included).

            See sub() doc about key_from and key_to.
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return self.__keys[position_from:(position_to + 1)]

    def sub(self, key_from=None, key_to=None):
        """ Return a list of values associated to keys between key_from and key_to
            (both bounds included).

            - If key_from is None, lowest key in dict is used.
            - If key_to is None, greatest key in dict is used.
            - If key_from is not in dict, lowest key in dict greater than key_from is used.
            - If key_to is not in dict, greatest key in dict less than key_to is used.

            - If dict is empty, return empty list.
            - If no key of the dict lies in [key_from; key_to], return empty list.
            - With keys (None, None) return a copy of all values.
            - With keys (None, key_to), return values from first to the one associated to key_to.
            - With keys (key_from, None), return values from the one associated to key_from to the last value.

            :param key_from: start key
            :param key_to: end key
            :return: list: values in closed keys interval [key_from; key_to]
            :raise IndexError: if key_from > key_to (see _get_keys_interval()).
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        return [self.__couples[k] for k in self.__keys[position_from:(position_to + 1)]]

    def remove_sub(self, key_from=None, key_to=None):
        """ Remove values associated to keys between key_from and key_to (both bounds included).

            See sub() doc about key_from and key_to.

            :param key_from: start key
            :param key_to: end key
            :return: nothing
        """
        position_from, position_to = self._get_keys_interval(key_from, key_to)
        # Materialize the keys to remove before mutating the underlying containers.
        keys_to_remove = self.__keys[position_from:(position_to + 1)]
        for key in keys_to_remove:
            self.remove(key)

    def key_from_index(self, index):
        """ Return key matching given position in sorted dict, or None for invalid position.
            Negative positions are accepted, as for Python sequences.
        """
        return self.__keys[index] if -len(self.__keys) <= index < len(self.__keys) else None

    def get_previous_key(self, key):
        """ Return greatest key lower than given key, or None if not exists. """
        return self.__keys.get_previous_value(key)

    def get_next_key(self, key):
        """ Return smallest key greater than given key, or None if not exists. """
        return self.__keys.get_next_value(key)

    def _get_keys_interval(self, key_from, key_to):
        """ Get a couple of internal key positions (index of key_from, index of key_to) allowing
            to easily retrieve values in closed interval [index of key_from; index of key_to]
            corresponding to Python slice [index of key_from : (index of key_to + 1)]

            - If dict is empty, return (0, -1), so that python slice [0 : -1 + 1] corresponds to empty interval.
            - If no key of the dict lies in [key_from; key_to], also return (0, -1).
            - If key_from is None, lowest key in dict is used.
            - If key_to is None, greatest key in dict is used.
            - If key_from is not in dict, lowest key in dict greater than key_from is used.
            - If key_to is not in dict, greatest key in dict less than key_to is used.

            - With keys (None, None), we get interval of all values.
            - With keys (key_from, None), we get interval for values from key_from to the last key.
            - With keys (None, key_to), we get interval for values from the first key to key_to.

            :param key_from: start key
            :param key_to: end key
            :return: (int, int): couple of integers: (index of key_from, index of key_to).
            :raise IndexError: if the requested bounds are inverted (key_from > key_to).
        """
        if not self:
            return 0, -1
        # Keep the caller's bounds: the ones below may be snapped to neighbouring keys.
        requested_from, requested_to = key_from, key_to
        if key_from is not None and key_from not in self.__couples:
            key_from = self.__keys.get_next_value(key_from)
            if key_from is None:
                return 0, -1
        if key_to is not None and key_to not in self.__couples:
            key_to = self.__keys.get_previous_value(key_to)
            if key_to is None:
                return 0, -1
        if key_from is None and key_to is None:
            key_from = self.first_key()
            key_to = self.last_key()
        elif key_from is not None and key_to is None:
            key_to = self.last_key()
        elif key_from is None and key_to is not None:
            key_from = self.first_key()
        if key_from > key_to:
            # Snapped keys can only end up inverted when both bounds were given.
            # If the caller's own bounds are ordered, the requested interval simply falls
            # between two consecutive keys and contains no key: return an empty interval.
            if not requested_from > requested_to:
                return 0, -1
            raise IndexError('expected key_from <= key_to (%s vs %s)' % (key_from, key_to))
        position_from = self.__keys.index(key_from)
        position_to = self.__keys.index(key_to)
        assert position_from is not None and position_to is not None
        return position_from, position_to

    def clear(self):
        """ Remove all items from dict. """
        self.__couples.clear()
        self.__keys.clear()

    def fill(self, dct):
        """ Add given dict to this sorted dict (existing keys get their values replaced). """
        if dct:
            assert is_dictionary(dct)
            for key, value in dct.items():
                self.put(key, value)

    def copy(self):
        """ Return a shallow copy of this sorted dict (values themselves are not copied). """
        return SortedDict(self.__keys.element_type, self.__val_type, self.__couples)