Index: mojo/python/tests/messaging_unittest.py |
diff --git a/mojo/python/tests/messaging_unittest.py b/mojo/python/tests/messaging_unittest.py |
index c67048b750af37871be82614591a9a91d434a8db..2d08941ac6b8a60996c2153a3aa70bb8c5c0de91 100644 |
--- a/mojo/python/tests/messaging_unittest.py |
+++ b/mojo/python/tests/messaging_unittest.py |
@@ -10,16 +10,6 @@ from mojo.bindings import messaging |
from mojo import system |
-class _ForwardingMessageReceiver(messaging.MessageReceiver): |
- |
- def __init__(self, callback): |
- self._callback = callback |
- |
- def Accept(self, message): |
- self._callback(message) |
- return True |
- |
- |
class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler): |
def __init__(self, callback): |
@@ -29,7 +19,7 @@ class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler): |
self._callback(result) |
-class MessagingTest(unittest.TestCase): |
+class ConnectorTest(unittest.TestCase): |
def setUp(self): |
mojo.embedder.Init() |
@@ -38,12 +28,13 @@ class MessagingTest(unittest.TestCase): |
self.received_errors = [] |
def _OnMessage(message): |
self.received_messages.append(message) |
+ return True |
def _OnError(result): |
self.received_errors.append(result) |
handles = system.MessagePipe() |
self.connector = messaging.Connector(handles.handle1) |
self.connector.SetIncomingMessageReceiver( |
- _ForwardingMessageReceiver(_OnMessage)) |
+ messaging.ForwardingMessageReceiver(_OnMessage)) |
self.connector.SetErrorHandler( |
_ForwardingConnectionErrorHandler(_OnError)) |
self.connector.Start() |
@@ -79,3 +70,138 @@ class MessagingTest(unittest.TestCase): |
self.connector = None |
(result, _, _) = self.handle.ReadMessage() |
self.assertEquals(result, system.RESULT_FAILED_PRECONDITION) |
+ |
+ |
+class HeaderTest(unittest.TestCase): |
+ |
+ def testSimpleMessageHeader(self): |
+ header = messaging.MessageHeader(0xdeadbeaf, messaging.NO_FLAG) |
+ self.assertEqual(header.message_type, 0xdeadbeaf) |
+ self.assertFalse(header.has_request_id) |
+ self.assertFalse(header.expects_response) |
+ self.assertFalse(header.is_response) |
+ data = header.Serialize() |
+ other_header = messaging.MessageHeader.Deserialize(data) |
+ self.assertEqual(other_header.message_type, 0xdeadbeaf) |
+ self.assertFalse(other_header.has_request_id) |
+ self.assertFalse(other_header.expects_response) |
+ self.assertFalse(other_header.is_response) |
+ |
+ def testMessageHeaderWithRequestID(self): |
+ # Request message. |
+ header = messaging.MessageHeader(0xdeadbeaf, |
+ messaging.MESSAGE_EXPECTS_RESPONSE_FLAG) |
+ |
+ self.assertEqual(header.message_type, 0xdeadbeaf) |
+ self.assertTrue(header.has_request_id) |
+ self.assertTrue(header.expects_response) |
+ self.assertFalse(header.is_response) |
+ self.assertEqual(header.request_id, 0) |
+ |
+ data = header.Serialize() |
+ other_header = messaging.MessageHeader.Deserialize(data) |
+ |
+ self.assertEqual(other_header.message_type, 0xdeadbeaf) |
+ self.assertTrue(other_header.has_request_id) |
+ self.assertTrue(other_header.expects_response) |
+ self.assertFalse(other_header.is_response) |
+ self.assertEqual(other_header.request_id, 0) |
+ |
+ header.request_id = 0xdeadbeafdeadbeaf |
+ data = header.Serialize() |
+ other_header = messaging.MessageHeader.Deserialize(data) |
+ |
+ self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf) |
+ |
+ # Response message. |
+ header = messaging.MessageHeader(0xdeadbeaf, |
+ messaging.MESSAGE_IS_RESPONSE_FLAG, |
+ 0xdeadbeafdeadbeaf) |
+ |
+ self.assertEqual(header.message_type, 0xdeadbeaf) |
+ self.assertTrue(header.has_request_id) |
+ self.assertFalse(header.expects_response) |
+ self.assertTrue(header.is_response) |
+ self.assertEqual(header.request_id, 0xdeadbeafdeadbeaf) |
+ |
+ data = header.Serialize() |
+ other_header = messaging.MessageHeader.Deserialize(data) |
+ |
+ self.assertEqual(other_header.message_type, 0xdeadbeaf) |
+ self.assertTrue(other_header.has_request_id) |
+ self.assertFalse(other_header.expects_response) |
+ self.assertTrue(other_header.is_response) |
+ self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf) |
+ |
+ |
+class RouterTest(unittest.TestCase): |
+ |
+ def setUp(self): |
+ mojo.embedder.Init() |
+ self.loop = system.RunLoop() |
+ self.received_messages = [] |
+ self.received_errors = [] |
+ def _OnMessage(message): |
+ self.received_messages.append(message) |
+ return True |
+ def _OnError(result): |
+ self.received_errors.append(result) |
+ handles = system.MessagePipe() |
+ self.router = messaging.Router(handles.handle1) |
+ self.router.SetIncomingMessageReceiver( |
+ messaging.ForwardingMessageReceiver(_OnMessage)) |
+ self.router.SetErrorHandler( |
+ _ForwardingConnectionErrorHandler(_OnError)) |
+ self.router.Start() |
+ self.handle = handles.handle0 |
+ |
+ def tearDown(self): |
+ self.router = None |
+ self.handle = None |
+ self.loop = None |
+ |
+ def testSimpleMessage(self): |
+ header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize() |
+ message = messaging.Message(header_data) |
+ self.router.Accept(message) |
+ self.loop.RunUntilIdle() |
+ self.assertFalse(self.received_errors) |
+ self.assertFalse(self.received_messages) |
+ (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data))) |
+ self.assertEquals(system.RESULT_OK, res) |
+ self.assertEquals(data[0], header_data) |
+ |
+ def testSimpleReception(self): |
+ header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize() |
+ self.handle.WriteMessage(header_data) |
+ self.loop.RunUntilIdle() |
+ self.assertFalse(self.received_errors) |
+ self.assertEquals(len(self.received_messages), 1) |
+ self.assertEquals(self.received_messages[0].data, header_data) |
+ |
+ def testRequestResponse(self): |
+ header_data = messaging.MessageHeader( |
+ 0, messaging.MESSAGE_EXPECTS_RESPONSE_FLAG).Serialize() |
+ message = messaging.Message(header_data) |
+ back_messages = [] |
+ def OnBackMessage(message): |
+ back_messages.append(message) |
+ self.router.AcceptWithResponder(message, |
+ messaging.ForwardingMessageReceiver( |
+ OnBackMessage)) |
+ self.loop.RunUntilIdle() |
+ self.assertFalse(self.received_errors) |
+ self.assertFalse(self.received_messages) |
+ (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data))) |
+ self.assertEquals(system.RESULT_OK, res) |
+ message_header = messaging.MessageHeader.Deserialize(data[0]) |
+ self.assertNotEquals(message_header.request_id, 0) |
+ response_header_data = messaging.MessageHeader( |
+ 0, |
+ messaging.MESSAGE_IS_RESPONSE_FLAG, |
+ message_header.request_id).Serialize() |
+ self.handle.WriteMessage(response_header_data) |
+ self.loop.RunUntilIdle() |
+ self.assertFalse(self.received_errors) |
+ self.assertEquals(len(back_messages), 1) |
+ self.assertEquals(back_messages[0].data, response_header_data) |