aboutsummaryrefslogtreecommitdiff
path: root/diplomacy/daide/connection_handler.py
diff options
context:
space:
mode:
authorSatya Ortiz-Gagne <satya.ortiz-gagne@mila.quebec>2019-06-10 10:20:39 -0400
committerPhilip Paquette <pcpaquette@gmail.com>2019-06-14 15:08:29 -0400
commit08b9469e2c71e06fdd70d607f281686746755073 (patch)
tree72066b7920e1c6aa0c8a0b390418c25a1afec1ad /diplomacy/daide/connection_handler.py
parent9628cdadbf5a6380098168846dc51e4feadef6ad (diff)
DAIDE - Added connection_handler and server
- Ability to open and close port when DAIDE games are started and stopped - Can get the DAIDE port using a request
Diffstat (limited to 'diplomacy/daide/connection_handler.py')
-rw-r--r--diplomacy/daide/connection_handler.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/diplomacy/daide/connection_handler.py b/diplomacy/daide/connection_handler.py
new file mode 100644
index 0000000..0b606bf
--- /dev/null
+++ b/diplomacy/daide/connection_handler.py
@@ -0,0 +1,200 @@
+# ==============================================================================
+# Copyright (C) 2019 - Philip Paquette
+#
+# This program is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Affero General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option) any
+# later version.
+#
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Affero General Public License along
+# with this program. If not, see <https://www.gnu.org/licenses/>.
+# ==============================================================================
+""" Tornado stream wrapper, used internally to abstract a DAIDE stream connection from a WebSocketConnection. """
+import logging
+from tornado import gen
+from tornado.concurrent import Future
+from tornado.iostream import StreamClosedError
+from diplomacy.daide import notifications, request_managers, responses
+from diplomacy.daide.messages import DiplomacyMessage, DaideMessage, ErrorMessage, RepresentationMessage, MessageType
+from diplomacy.daide.notification_managers import translate_notification
+from diplomacy.daide.requests import RequestBuilder
+from diplomacy.daide.utils import bytes_to_str
+from diplomacy.utils import exceptions
+
+# Constants
+LOGGER = logging.getLogger(__name__)
+
+class ConnectionHandler():
+ """ ConnectionHandler class. Properties:
+ - server: server object representing running server.
+ """
+ _NAME_VARIANT_PREFIX = 'DAIDE'
+ _NAME_VARIANTS_POOL = []
+ _USED_NAME_VARIANTS = []
+
+ def __init__(self):
+ self.stream = None
+ self.server = None
+ self.game_id = None
+ self.token = None
+ self._name_variant = None
+ self._socket_no = None
+ self._local_addr = ('::1', 0, 0, 0)
+ self._remote_addr = ('::1', 0, 0, 0)
+
+ self.message_mapping = {MessageType.INITIAL: self._on_initial_message,
+ MessageType.DIPLOMACY: self._on_diplomacy_message,
+ MessageType.FINAL: self._on_final_message,
+ MessageType.ERROR: self._on_error_message}
+
+ def initialize(self, stream, server, game_id):
+ """ Initialize the connection handler.
+ :param server: a Server object.
+ :type server: diplomacy.Server
+ """
+ self.stream = stream
+ self.server = server
+ self.game_id = game_id
+ stream.set_close_callback(self.on_connection_close)
+ self._socket_no = self.stream.socket.fileno()
+ self._local_addr = stream.socket.getsockname()
+ self._remote_addr = stream.socket.getpeername()
+
+ @property
+ def local_addr(self):
+ """ Return the address of the local endpoint """
+ return self._local_addr
+
+ @property
+ def remote_addr(self):
+ """ Return the address of the remote endpoint """
+ return self._remote_addr
+
+ def get_name_variant(self):
+ """ Return the address of the remote endpoint """
+ if self._name_variant is None:
+ self._name_variant = self._NAME_VARIANTS_POOL.pop(0) if self._NAME_VARIANTS_POOL \
+ else len(self._USED_NAME_VARIANTS)
+ self._USED_NAME_VARIANTS.append(self._name_variant)
+ return self._NAME_VARIANT_PREFIX + str(self._name_variant)
+
+ def release_name_variant(self):
+ """ Return the next available user name variant """
+ self._USED_NAME_VARIANTS.remove(self._name_variant)
+ self._NAME_VARIANTS_POOL.append(self._name_variant)
+ self._name_variant = None
+
+ @gen.coroutine
+ def close_connection(self):
+ """ Close the connection with the client """
+ try:
+ message = DiplomacyMessage()
+ message.content = bytes(responses.TurnOffResponse())
+ yield self.write_message(message)
+ self.stream.close()
+ except StreamClosedError:
+ LOGGER.error('Stream is closed.')
+
+ def on_connection_close(self):
+ """ Invoked when the socket is closed (see parent method).
+ Detach this connection handler from server users.
+ """
+ self.release_name_variant()
+ self.server.users.remove_connection(self, remove_tokens=False)
+ LOGGER.info('Removed connection. Remaining %d connection(s).', self.server.users.count_connections())
+
+ @gen.coroutine
+ def read_stream(self):
+ """ Read the next message from the stream """
+ messages = []
+ in_message = yield DaideMessage.from_stream(self.stream)
+
+ if in_message and in_message.is_valid:
+ message_handler = self.message_mapping.get(in_message.message_type, None)
+ if not message_handler:
+ raise RuntimeError('Unrecognized DAIDE message type [{}]'.format(in_message.message_type))
+
+ if gen.is_coroutine_function(message_handler):
+ messages = yield message_handler(in_message)
+ else:
+ messages = message_handler(in_message)
+ elif in_message:
+ err_message = ErrorMessage()
+ err_message.error_code = in_message.error_code
+ messages = [err_message]
+
+ for message in messages:
+ yield self.write_message(message)
+
+ # Added for compatibility with WebSocketHandler interface
+ def write_message(self, message, binary=True):
+ """ Write a message into the stream """
+ if binary and isinstance(message, bytes):
+ future = self.stream.write(message)
+ else:
+ if isinstance(message, notifications.DaideNotification):
+ LOGGER.info('[%d] notification:[%s]', self._socket_no, bytes_to_str(bytes(message)))
+ notification = message
+ message = DiplomacyMessage()
+ message.content = bytes(notification)
+
+ if isinstance(message, DaideMessage):
+ future = self.stream.write(bytes(message))
+ else:
+ future = Future()
+ future.set_result(None)
+ return future
+
+ def translate_notification(self, notification):
+ """ Translate a notification to a DAIDE notification.
+ :param notification: a notification object to pass to handler function.
+ See diplomacy.communication.notifications for possible notifications.
+ :return: either None or an array of daide notifications.
+ See module diplomacy.daide.notifications for possible daide notifications.
+ """
+ return translate_notification(self.server, notification, self)
+
+ def _on_initial_message(self, _):
+ """ Handle an initial message """
+ LOGGER.info('[%d] initial message', self._socket_no)
+ return [RepresentationMessage()]
+
+ @gen.coroutine
+ def _on_diplomacy_message(self, in_message):
+ """ Handle a diplomacy message """
+ messages = []
+ request = RequestBuilder.from_bytes(in_message.content)
+
+ try:
+ LOGGER.info('[%d] request:[%s]', self._socket_no, bytes_to_str(in_message.content))
+ request.game_id = self.game_id
+ message_responses = yield request_managers.handle_request(self.server, request, self)
+ except exceptions.ResponseException:
+ message_responses = [responses.REJ(bytes(request))]
+
+ if message_responses:
+ for response in message_responses:
+ response_bytes = bytes(response)
+ LOGGER.info('[%d] response:[%s]', self._socket_no, bytes_to_str(response_bytes) \
+ if response_bytes else None)
+ message = DiplomacyMessage()
+ message.content = response_bytes
+ messages.append(message)
+
+ return messages
+
+ def _on_final_message(self, _):
+ """ Handle a final message """
+ LOGGER.info('[%d] final message', self._socket_no)
+ self.stream.close()
+ return []
+
+ def _on_error_message(self, in_message):
+ """ Handle an error message """
+ LOGGER.error('[%d] error [%d]', self._socket_no, in_message.error_code)
+ return []