Index: third_party/gsutil/oauth2_plugin/oauth2_client_test.py |
diff --git a/third_party/gsutil/oauth2_plugin/oauth2_client_test.py b/third_party/gsutil/oauth2_plugin/oauth2_client_test.py |
new file mode 100644 |
index 0000000000000000000000000000000000000000..1d8e581962554d010bc5fa2bdb32aae115a2b3c3 |
--- /dev/null |
+++ b/third_party/gsutil/oauth2_plugin/oauth2_client_test.py |
@@ -0,0 +1,374 @@ |
+# Copyright 2010 Google Inc. All Rights Reserved. |
+# |
+# Licensed under the Apache License, Version 2.0 (the "License"); |
+# you may not use this file except in compliance with the License. |
+# You may obtain a copy of the License at |
+# |
+# http://www.apache.org/licenses/LICENSE-2.0 |
+# |
+# Unless required by applicable law or agreed to in writing, software |
+# distributed under the License is distributed on an "AS IS" BASIS, |
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+# See the License for the specific language governing permissions and |
+# limitations under the License. |
+ |
+"""Unit tests for oauth2_client.""" |
+ |
+import datetime |
+import logging |
+import os |
+import sys |
+import unittest |
+import urllib2 |
+import urlparse |
+from stat import S_IMODE |
+from StringIO import StringIO |
+ |
+test_bin_dir = os.path.dirname(os.path.realpath(sys.argv[0])) |
+ |
+lib_dir = os.path.join(test_bin_dir, '..') |
+sys.path.insert(0, lib_dir) |
+ |
+# Needed for boto.cacerts |
+boto_lib_dir = os.path.join(test_bin_dir, '..', 'boto') |
+sys.path.insert(0, boto_lib_dir) |
+ |
+import oauth2_client |
+ |
+LOG = logging.getLogger('oauth2_client_test') |
+ |
+class MockOpener: |
+ def __init__(self): |
+ self.reset() |
+ |
+ def reset(self): |
+ self.open_error = None |
+ self.open_result = None |
+ self.open_capture_url = None |
+ self.open_capture_data = None |
+ |
+ def open(self, req, data=None): |
+ self.open_capture_url = req.get_full_url() |
+ self.open_capture_data = req.get_data() |
+ if self.open_error is not None: |
+ raise self.open_error |
+ else: |
+ return StringIO(self.open_result) |
+ |
+ |
+class MockDateTime: |
+ def __init__(self): |
+ self.mock_now = None |
+ |
+ def utcnow(self): |
+ return self.mock_now |
+ |
+ |
+class OAuth2ClientTest(unittest.TestCase): |
+ def setUp(self): |
+ self.opener = MockOpener() |
+ self.mock_datetime = MockDateTime() |
+ self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
+ self.mock_datetime.mock_now = self.start_time |
+ self.client = oauth2_client.OAuth2Client( |
+ oauth2_client.OAuth2Provider( |
+ 'Sample OAuth Provider', |
+ 'https://provider.example.com/oauth/provider?mode=authorize', |
+ 'https://provider.example.com/oauth/provider?mode=token'), |
+ 'clid', 'clsecret', |
+ url_opener=self.opener, datetime_strategy=self.mock_datetime) |
+ |
+ def testFetchAccessToken(self): |
+ refresh_token = '1/ZaBrxdPl77Bi4jbsO7x-NmATiaQZnWPB51nTvo8n9Sw' |
+ access_token = '1/aalskfja-asjwerwj' |
+ self.opener.open_result = ( |
+ '{"access_token":"%s","expires_in":3600}' % access_token) |
+ cred = oauth2_client.RefreshToken(self.client, refresh_token) |
+ token = self.client.FetchAccessToken(cred) |
+ |
+ self.assertEquals( |
+ self.opener.open_capture_url, |
+ 'https://provider.example.com/oauth/provider?mode=token') |
+ self.assertEquals({ |
+ 'grant_type': ['refresh_token'], |
+ 'client_id': ['clid'], |
+ 'client_secret': ['clsecret'], |
+ 'refresh_token': [refresh_token]}, |
+ urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
+ strict_parsing=True)) |
+ self.assertEquals(access_token, token.token) |
+ self.assertEquals( |
+ datetime.datetime(2011, 3, 1, 11, 25, 13, 300826), |
+ token.expiry) |
+ |
+ def testFetchAccessTokenFailsForBadJsonResponse(self): |
+ self.opener.open_result = 'blah' |
+ cred = oauth2_client.RefreshToken(self.client, 'abc123') |
+ self.assertRaises( |
+ oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cred) |
+ |
+ def testFetchAccessTokenFailsForErrorResponse(self): |
+ self.opener.open_error = urllib2.HTTPError( |
+ None, 400, 'Bad Request', None, StringIO('{"error": "invalid token"}')) |
+ cred = oauth2_client.RefreshToken(self.client, 'abc123') |
+ self.assertRaises( |
+ oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cred) |
+ |
+ def testFetchAccessTokenFailsForHttpError(self): |
+ self.opener.open_result = urllib2.HTTPError( |
+ 'foo', 400, 'Bad Request', None, None) |
+ cred = oauth2_client.RefreshToken(self.client, 'abc123') |
+ self.assertRaises( |
+ oauth2_client.AccessTokenRefreshError, self.client.FetchAccessToken, cred) |
+ |
+ def testGetAccessToken(self): |
+ refresh_token = 'ref_token' |
+ access_token_1 = 'abc123' |
+ self.opener.open_result = ( |
+ '{"access_token":"%s",' '"expires_in":3600}' % access_token_1) |
+ cred = oauth2_client.RefreshToken(self.client, refresh_token) |
+ |
+ token_1 = self.client.GetAccessToken(cred) |
+ |
+ # There's no access token in the cache; verify that we fetched a fresh |
+ # token. |
+ self.assertEquals({ |
+ 'grant_type': ['refresh_token'], |
+ 'client_id': ['clid'], |
+ 'client_secret': ['clsecret'], |
+ 'refresh_token': [refresh_token]}, |
+ urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
+ strict_parsing=True)) |
+ self.assertEquals(access_token_1, token_1.token) |
+ self.assertEquals(self.start_time + datetime.timedelta(minutes=60), |
+ token_1.expiry) |
+ |
+ # Advance time by less than expiry time, and fetch another token. |
+ self.opener.reset() |
+ self.mock_datetime.mock_now = ( |
+ self.start_time + datetime.timedelta(minutes=55)) |
+ token_2 = self.client.GetAccessToken(cred) |
+ |
+ # Since the access token wasn't expired, we get the cache token, and there |
+ # was no refresh request. |
+ self.assertEquals(token_1, token_2) |
+ self.assertEquals(access_token_1, token_2.token) |
+ self.assertEquals(None, self.opener.open_capture_url) |
+ self.assertEquals(None, self.opener.open_capture_data) |
+ |
+ # Advance time past expiry time, and fetch another token. |
+ self.opener.reset() |
+ self.mock_datetime.mock_now = ( |
+ self.start_time + datetime.timedelta(minutes=55, seconds=1)) |
+ access_token_2 = 'zyx456' |
+ self.opener.open_result = ( |
+ '{"access_token":"%s",' '"expires_in":3600}' % access_token_2) |
+ token_3 = self.client.GetAccessToken(cred) |
+ |
+ # This should have resulted in a refresh request and a fresh access token. |
+ self.assertEquals({ |
+ 'grant_type': ['refresh_token'], |
+ 'client_id': ['clid'], |
+ 'client_secret': ['clsecret'], |
+ 'refresh_token': [refresh_token]}, |
+ urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
+ strict_parsing=True)) |
+ self.assertEquals(access_token_2, token_3.token) |
+ self.assertEquals(self.mock_datetime.mock_now + datetime.timedelta(minutes=60), |
+ token_3.expiry) |
+ |
+ def testGetAuthorizationUri(self): |
+ authn_uri = self.client.GetAuthorizationUri( |
+ 'https://www.example.com/oauth/redir?mode=approve%20me', |
+ ('scope_foo', 'scope_bar'), |
+ {'state': 'this and that & sundry'}) |
+ |
+ uri_parts = urlparse.urlsplit(authn_uri) |
+ self.assertEquals(('https', 'provider.example.com', '/oauth/provider'), |
+ uri_parts[:3]) |
+ |
+ self.assertEquals({ |
+ 'response_type': ['code'], |
+ 'client_id': ['clid'], |
+ 'redirect_uri': |
+ ['https://www.example.com/oauth/redir?mode=approve%20me'], |
+ 'scope': ['scope_foo scope_bar'], |
+ 'state': ['this and that & sundry'], |
+ 'mode': ['authorize']}, |
+ urlparse.parse_qs(uri_parts[3])) |
+ |
+ def testExchangeAuthorizationCode(self): |
+ code = 'codeABQ1234' |
+ exp_refresh_token = 'ref_token42' |
+ exp_access_token = 'access_tokenXY123' |
+ self.opener.open_result = ( |
+ '{"access_token":"%s","expires_in":3600,"refresh_token":"%s"}' |
+ % (exp_access_token, exp_refresh_token)) |
+ |
+ refresh_token, access_token = self.client.ExchangeAuthorizationCode( |
+ code, 'urn:ietf:wg:oauth:2.0:oob', ('scope1', 'scope2')) |
+ |
+ self.assertEquals({ |
+ 'grant_type': ['authorization_code'], |
+ 'client_id': ['clid'], |
+ 'client_secret': ['clsecret'], |
+ 'code': [code], |
+ 'redirect_uri': ['urn:ietf:wg:oauth:2.0:oob'], |
+ 'scope': ['scope1 scope2'] }, |
+ urlparse.parse_qs(self.opener.open_capture_data, keep_blank_values=True, |
+ strict_parsing=True)) |
+ self.assertEquals(exp_access_token, access_token.token) |
+ self.assertEquals(self.start_time + datetime.timedelta(minutes=60), |
+ access_token.expiry) |
+ |
+ self.assertEquals(self.client, refresh_token.oauth2_client) |
+ self.assertEquals(exp_refresh_token, refresh_token.refresh_token) |
+ |
+ # Check that the access token was put in the cache. |
+ cached_token = self.client.access_token_cache.GetToken( |
+ refresh_token.CacheKey()) |
+ self.assertEquals(access_token, cached_token) |
+ |
+ |
+class AccessTokenTest(unittest.TestCase): |
+ |
+ def testShouldRefresh(self): |
+ mock_datetime = MockDateTime() |
+ start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
+ expiry = start + datetime.timedelta(minutes=60) |
+ token = oauth2_client.AccessToken( |
+ 'foo', expiry, datetime_strategy=mock_datetime) |
+ |
+ mock_datetime.mock_now = start |
+ self.assertFalse(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta(minutes=54) |
+ self.assertFalse(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta(minutes=55) |
+ self.assertFalse(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta( |
+ minutes=55, seconds=1) |
+ self.assertTrue(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta( |
+ minutes=61) |
+ self.assertTrue(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta(minutes=58) |
+ self.assertFalse(token.ShouldRefresh(time_delta=120)) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta( |
+ minutes=58, seconds=1) |
+ self.assertTrue(token.ShouldRefresh(time_delta=120)) |
+ |
+ def testShouldRefreshNoExpiry(self): |
+ mock_datetime = MockDateTime() |
+ start = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
+ token = oauth2_client.AccessToken( |
+ 'foo', None, datetime_strategy=mock_datetime) |
+ |
+ mock_datetime.mock_now = start |
+ self.assertFalse(token.ShouldRefresh()) |
+ |
+ mock_datetime.mock_now = start + datetime.timedelta( |
+ minutes=472) |
+ self.assertFalse(token.ShouldRefresh()) |
+ |
+ def testSerialization(self): |
+ expiry = datetime.datetime(2011, 3, 1, 11, 25, 13, 300826) |
+ token = oauth2_client.AccessToken('foo', expiry) |
+ serialized_token = token.Serialize() |
+ LOG.debug('testSerialization: serialized_token=%s' % serialized_token) |
+ |
+ token2 = oauth2_client.AccessToken.UnSerialize(serialized_token) |
+ self.assertEquals(token, token2) |
+ |
+ |
+class RefreshTokenTest(unittest.TestCase): |
+ def setUp(self): |
+ self.opener = MockOpener() |
+ self.mock_datetime = MockDateTime() |
+ self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
+ self.mock_datetime.mock_now = self.start_time |
+ self.client = oauth2_client.OAuth2Client( |
+ oauth2_client.OAuth2Provider( |
+ 'Sample OAuth Provider', |
+ 'https://provider.example.com/oauth/provider?mode=authorize', |
+ 'https://provider.example.com/oauth/provider?mode=token'), |
+ 'clid', 'clsecret', |
+ url_opener=self.opener, datetime_strategy=self.mock_datetime) |
+ |
+ self.cred = oauth2_client.RefreshToken(self.client, 'ref_token_abc123') |
+ |
+ def testUniqeId(self): |
+ cred_id = self.cred.CacheKey() |
+ self.assertEquals('0720afed6871f12761fbea3271f451e6ba184bf5', cred_id) |
+ |
+ def testGetAuthorizationHeader(self): |
+ access_token = 'access_123' |
+ self.opener.open_result = ( |
+ '{"access_token":"%s","expires_in":3600}' % access_token) |
+ |
+ self.assertEquals('Bearer %s' % access_token, |
+ self.cred.GetAuthorizationHeader()) |
+ |
+ |
+class FileSystemTokenCacheTest(unittest.TestCase): |
+ |
+ def setUp(self): |
+ self.cache = oauth2_client.FileSystemTokenCache() |
+ self.start_time = datetime.datetime(2011, 3, 1, 10, 25, 13, 300826) |
+ self.token_1 = oauth2_client.AccessToken('token1', self.start_time) |
+ self.token_2 = oauth2_client.AccessToken( |
+ 'token2', self.start_time + datetime.timedelta(seconds=492)) |
+ self.key = 'token1key' |
+ |
+ def tearDown(self): |
+ try: |
+ os.unlink(self.cache.CacheFileName(self.key)) |
+ except: |
+ pass |
+ |
+ def testPut(self): |
+ self.cache.PutToken(self.key, self.token_1) |
+ # Assert that the cache file exists and has correct permissions. |
+ self.assertEquals( |
+ 0600, S_IMODE(os.stat(self.cache.CacheFileName(self.key)).st_mode)) |
+ |
+ def testPutGet(self): |
+ # No cache file present. |
+ self.assertEquals(None, self.cache.GetToken(self.key)) |
+ |
+ # Put a token |
+ self.cache.PutToken(self.key, self.token_1) |
+ cached_token = self.cache.GetToken(self.key) |
+ self.assertEquals(self.token_1, cached_token) |
+ |
+ # Put a different token |
+ self.cache.PutToken(self.key, self.token_2) |
+ cached_token = self.cache.GetToken(self.key) |
+ self.assertEquals(self.token_2, cached_token) |
+ |
+ def testGetBadFile(self): |
+ f = open(self.cache.CacheFileName(self.key), 'w') |
+ f.write('blah') |
+ f.close() |
+ self.assertEquals(None, self.cache.GetToken(self.key)) |
+ |
+ def testCacheFileName(self): |
+ cache = oauth2_client.FileSystemTokenCache( |
+ path_pattern='/var/run/ccache/token.%(uid)s.%(key)s') |
+ self.assertEquals('/var/run/ccache/token.%d.abc123' % os.getuid(), |
+ cache.CacheFileName('abc123')) |
+ |
+ cache = oauth2_client.FileSystemTokenCache( |
+ path_pattern='/var/run/ccache/token.%(key)s') |
+ self.assertEquals('/var/run/ccache/token.abc123', |
+ cache.CacheFileName('abc123')) |
+ |
+ |
+if __name__ == '__main__': |
+ logging.basicConfig(level=logging.DEBUG) |
+ unittest.main() |