Index: mojo/public/python/mojo/bindings/messaging.py |
diff --git a/mojo/public/python/mojo/bindings/messaging.py b/mojo/public/python/mojo/bindings/messaging.py |
index a6eb57500fac77ee73fad6ce1cfb30576cbf9a7f..848c830502cfbb9e1edd666661144eba60172184 100644 |
--- a/mojo/public/python/mojo/bindings/messaging.py |
+++ b/mojo/public/python/mojo/bindings/messaging.py |
@@ -5,18 +5,140 @@ |
"""Utility classes to handle sending and receiving messages.""" |
+import struct |
import weakref |
# pylint: disable=F0401 |
+import mojo.bindings.serialization as serialization |
import mojo.system as system |
+# The flag values for a message header. |
+NO_FLAG = 0 |
+MESSAGE_EXPECTS_RESPONSE_FLAG = 1 << 0 |
+MESSAGE_IS_RESPONSE_FLAG = 1 << 1 |
+ |
+ |
+class MessageHeader(object): |
+ """The header of a mojo message.""" |
+ |
+ _SIMPLE_MESSAGE_NUM_FIELDS = 2 |
+ _SIMPLE_MESSAGE_STRUCT = struct.Struct("=IIII") |
+ |
+ _REQUEST_ID_STRUCT = struct.Struct("=Q") |
+ _REQUEST_ID_OFFSET = _SIMPLE_MESSAGE_STRUCT.size |
+ |
+ _MESSAGE_WITH_REQUEST_ID_NUM_FIELDS = 3 |
+ _MESSAGE_WITH_REQUEST_ID_SIZE = ( |
+ _SIMPLE_MESSAGE_STRUCT.size + _REQUEST_ID_STRUCT.size) |
+ |
+ def __init__(self, message_type, flags, request_id=0, data=None): |
+ self._message_type = message_type |
+ self._flags = flags |
+ self._request_id = request_id |
+ self._data = data |
+ |
+ @classmethod |
+ def Deserialize(cls, data): |
+ buf = buffer(data) |
+ if len(data) < cls._SIMPLE_MESSAGE_STRUCT.size: |
+ raise serialization.DeserializationException('Header is too short.') |
+ (size, version, message_type, flags) = ( |
+ cls._SIMPLE_MESSAGE_STRUCT.unpack_from(buf)) |
+ if (version < cls._SIMPLE_MESSAGE_NUM_FIELDS): |
+ raise serialization.DeserializationException('Incorrect version.') |
+ request_id = 0 |
+ if _HasRequestId(flags): |
+ if version < cls._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS: |
+ raise serialization.DeserializationException('Incorrect version.') |
+ if (size < cls._MESSAGE_WITH_REQUEST_ID_SIZE or |
+ len(data) < cls._MESSAGE_WITH_REQUEST_ID_SIZE): |
+ raise serialization.DeserializationException('Header is too short.') |
+ (request_id, ) = cls._REQUEST_ID_STRUCT.unpack_from( |
+ buf, cls._REQUEST_ID_OFFSET) |
+ return MessageHeader(message_type, flags, request_id, data) |
+ |
+ @property |
+ def message_type(self): |
+ return self._message_type |
+ |
+ # pylint: disable=E0202 |
+ @property |
+ def request_id(self): |
+ assert self.has_request_id |
+ return self._request_id |
+ |
+ # pylint: disable=E0202 |
+ @request_id.setter |
+ def request_id(self, request_id): |
+ assert self.has_request_id |
+ self._request_id = request_id |
+ self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET, |
+ request_id) |
+ |
+ @property |
+ def has_request_id(self): |
+ return _HasRequestId(self._flags) |
+ |
+ @property |
+ def expects_response(self): |
+ return self._HasFlag(MESSAGE_EXPECTS_RESPONSE_FLAG) |
+ |
+ @property |
+ def is_response(self): |
+ return self._HasFlag(MESSAGE_IS_RESPONSE_FLAG) |
+ |
+ @property |
+ def size(self): |
+ if self.has_request_id: |
+ return self._MESSAGE_WITH_REQUEST_ID_SIZE |
+ return self._SIMPLE_MESSAGE_STRUCT.size |
+ |
+ def Serialize(self): |
+ if not self._data: |
+ self._data = bytearray(self.size) |
+ version = self._SIMPLE_MESSAGE_NUM_FIELDS |
+ size = self._SIMPLE_MESSAGE_STRUCT.size |
+ if self.has_request_id: |
+ version = self._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS |
+ size = self._MESSAGE_WITH_REQUEST_ID_SIZE |
+ self._SIMPLE_MESSAGE_STRUCT.pack_into(self._data, 0, size, version, |
+ self._message_type, self._flags) |
+ if self.has_request_id: |
+ self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET, |
+ self._request_id) |
+ return self._data |
+ |
+ def _HasFlag(self, flag): |
+ return self._flags & flag != 0 |
+ |
+ |
class Message(object): |
"""A message for a message pipe. This contains data and handles.""" |
def __init__(self, data=None, handles=None): |
self.data = data |
self.handles = handles |
+ self._header = None |
+ self._payload = None |
+ |
+ @property |
+ def header(self): |
+ if self._header is None: |
+ self._header = MessageHeader.Deserialize(self.data) |
+ return self._header |
+ |
+ @property |
+ def payload(self): |
+ if self._payload is None: |
+ self._payload = Message(self.data[self.header.size:], self.handles) |
+ return self._payload |
+ |
+ def SetRequestId(self, request_id): |
+ header = self.header |
+ header.request_id = request_id |
+ (data, _) = header.Serialize() |
+ self.data[:header.Size] = data[:header.Size] |
class MessageReceiver(object): |
@@ -111,6 +233,12 @@ class Connector(MessageReceiver): |
result = self._handle.WriteMessage(message.data, message.handles) |
return result == system.RESULT_OK |
+ def Close(self): |
+ if self._cancellable: |
+ self._cancellable() |
+ self._cancellable = None |
+ self._handle.Close() |
+ |
def _OnAsyncWaiterResult(self, result): |
self._cancellable = None |
if result == system.RESULT_OK: |
@@ -141,6 +269,96 @@ class Connector(MessageReceiver): |
self._OnError(result) |
+class Router(MessageReceiverWithResponder): |
+ """ |
+ A Router will handle mojo message and forward those to a Connector. It deals |
+ with parsing of headers and adding of request ids in order to be able to match |
+ a response to a request. |
+ """ |
+ |
+ def __init__(self, handle): |
+ MessageReceiverWithResponder.__init__(self) |
+ self._incoming_message_receiver = None |
+ self._next_request_id = 1 |
+ self._responders = {} |
+ self._connector = Connector(handle) |
+ self._connector.SetIncomingMessageReceiver( |
+ ForwardingMessageReceiver(self._HandleIncomingMessage)) |
+ |
+ def Start(self): |
+ self._connector.Start() |
+ |
+ def SetIncomingMessageReceiver(self, message_receiver): |
+ """ |
+ Set the MessageReceiver that will receive message from the owned message |
+ pipe. |
+ """ |
+ self._incoming_message_receiver = message_receiver |
+ |
+ def SetErrorHandler(self, error_handler): |
+ """ |
+ Set the ConnectionErrorHandler that will be notified of errors on the owned |
+ message pipe. |
+ """ |
+ self._connector.SetErrorHandler(error_handler) |
+ |
+ def Accept(self, message): |
+ # A message without responder is directly forwarded to the connector. |
+ return self._connector.Accept(message) |
+ |
+ def AcceptWithResponder(self, message, responder): |
+ # The message must have a header. |
+ header = message.header |
+ assert header.expects_response |
+ request_id = self.NextRequestId() |
+ header.request_id = request_id |
+ if not self._connector.Accept(message): |
+ return False |
+ self._responders[request_id] = responder |
+ return True |
+ |
+ def Close(self): |
+ self._connector.Close() |
+ |
+ def _HandleIncomingMessage(self, message): |
+ header = message.header |
+ if header.expects_response: |
+ if self._incoming_message_receiver: |
+ return self._incoming_message_receiver.AcceptWithResponder( |
+ message, self) |
+ # If we receive a request expecting a response when the client is not |
+ # listening, then we have no choice but to tear down the pipe. |
+ self.Close() |
+ return False |
+ if header.is_response: |
+ request_id = header.request_id |
+ responder = self._responders.pop(request_id, None) |
+ if responder is not None: |
+ return False |
+ return responder.Accept(message) |
+ if self._incoming_message_receiver: |
+ return self._incoming_message_receiver.Accept(message) |
+ # Ok to drop the message |
+ return False |
+ |
+ def NextRequestId(self): |
+ request_id = self._next_request_id |
+ while request_id == 0 or request_id in self._responders: |
+ request_id = (request_id + 1) % (1 << 64) |
+ self._next_request_id = (request_id + 1) % (1 << 64) |
+ return request_id |
+ |
+class ForwardingMessageReceiver(MessageReceiver): |
+ """A MessageReceiver that forward calls to |Accept| to a callable.""" |
+ |
+ def __init__(self, callback): |
+ MessageReceiver.__init__(self) |
+ self._callback = callback |
+ |
+ def Accept(self, message): |
+ return self._callback(message) |
+ |
+ |
def _WeakCallback(callback): |
func = callback.im_func |
self = callback.im_self |
@@ -165,3 +383,5 @@ def _ReadAndDispatchMessage(handle, message_receiver): |
message_receiver.Accept(Message(data[0], data[1])) |
return result |
+def _HasRequestId(flags): |
+ return flags & (MESSAGE_EXPECTS_RESPONSE_FLAG|MESSAGE_IS_RESPONSE_FLAG) != 0 |