OLD | NEW |
(Empty) | |
| 1 try: |
| 2 import unittest2 as unittest |
| 3 except ImportError: |
| 4 import unittest |
| 5 import httplib |
| 6 |
| 7 from mock import Mock |
| 8 |
| 9 |
| 10 class AWSMockServiceTestCase(unittest.TestCase): |
| 11 """Base class for mocking aws services.""" |
| 12 # This param is used by the unittest module to display a full |
| 13 # diff when assert*Equal methods produce an error message. |
| 14 maxDiff = None |
| 15 connection_class = None |
| 16 |
| 17 def setUp(self): |
| 18 self.https_connection = Mock(spec=httplib.HTTPSConnection) |
| 19 self.https_connection_factory = ( |
| 20 Mock(return_value=self.https_connection), ()) |
| 21 self.service_connection = self.create_service_connection( |
| 22 https_connection_factory=self.https_connection_factory, |
| 23 aws_access_key_id='aws_access_key_id', |
| 24 aws_secret_access_key='aws_secret_access_key') |
| 25 self.initialize_service_connection() |
| 26 |
| 27 def initialize_service_connection(self): |
| 28 self.actual_request = None |
| 29 self.original_mexe = self.service_connection._mexe |
| 30 self.service_connection._mexe = self._mexe_spy |
| 31 |
| 32 def create_service_connection(self, **kwargs): |
| 33 if self.connection_class is None: |
| 34 raise ValueError("The connection_class class attribute must be " |
| 35 "set to a non-None value.") |
| 36 return self.connection_class(**kwargs) |
| 37 |
| 38 def _mexe_spy(self, request, *args, **kwargs): |
| 39 self.actual_request = request |
| 40 return self.original_mexe(request, *args, **kwargs) |
| 41 |
| 42 def create_response(self, status_code, reason='', header=[], body=None): |
| 43 if body is None: |
| 44 body = self.default_body() |
| 45 response = Mock(spec=httplib.HTTPResponse) |
| 46 response.status = status_code |
| 47 response.read.return_value = body |
| 48 response.reason = reason |
| 49 |
| 50 response.getheaders.return_value = header |
| 51 response.msg = dict(header) |
| 52 def overwrite_header(arg, default=None): |
| 53 header_dict = dict(header) |
| 54 if header_dict.has_key(arg): |
| 55 return header_dict[arg] |
| 56 else: |
| 57 return default |
| 58 response.getheader.side_effect = overwrite_header |
| 59 |
| 60 return response |
| 61 |
| 62 def assert_request_parameters(self, params, ignore_params_values=None): |
| 63 """Verify the actual parameters sent to the service API.""" |
| 64 request_params = self.actual_request.params.copy() |
| 65 if ignore_params_values is not None: |
| 66 for param in ignore_params_values: |
| 67 # We still want to check that the ignore_params_values params |
| 68 # are in the request parameters, we just don't need to check |
| 69 # their value. |
| 70 self.assertIn(param, request_params) |
| 71 del request_params[param] |
| 72 self.assertDictEqual(request_params, params) |
| 73 |
| 74 def set_http_response(self, status_code, reason='', header=[], body=None): |
| 75 http_response = self.create_response(status_code, reason, header, body) |
| 76 self.https_connection.getresponse.return_value = http_response |
| 77 |
| 78 def default_body(self): |
| 79 return '' |
OLD | NEW |