Index: third_party/upload.py |
diff --git a/third_party/upload.py b/third_party/upload.py |
index aaf843fb8df5ccb8c3b9fda12b1be064e0c89010..b2a5193ad850204070059bf9b549bae497a14c7e 100755 |
--- a/third_party/upload.py |
+++ b/third_party/upload.py |
@@ -112,12 +112,19 @@ VCS_ABBREVIATIONS = { |
LOCALHOST_IP = '127.0.0.1' |
DEFAULT_OAUTH2_PORT = 8001 |
ACCESS_TOKEN_PARAM = 'access_token' |
+ERROR_PARAM = 'error' |
+OAUTH_DEFAULT_ERROR_MESSAGE = 'OAuth 2.0 error occurred.' |
OAUTH_PATH = '/get-access-token' |
OAUTH_PATH_PORT_TEMPLATE = OAUTH_PATH + '?port=%(port)d' |
AUTH_HANDLER_RESPONSE = """\ |
<html> |
<head> |
<title>Authentication Status</title> |
+ <script> |
+ window.onload = function() { |
+ window.close(); |
+ } |
+ </script> |
</head> |
<body> |
<p>The authentication flow has completed.</p> |
@@ -673,31 +680,37 @@ class ClientRedirectServer(BaseHTTPServer.HTTPServer): |
"""A server for redirects back to localhost from the associated server. |
Waits for a single request and parses the query parameters for an access token |
- and then stops serving. |
+ or an error and then stops serving. |
""" |
access_token = None |
+ error = None |
class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): |
"""A handler for redirects back to localhost from the associated server. |
Waits for a single request and parses the query parameters into the server's |
- access_token and then stops serving. |
+ access_token or error and then stops serving. |
""" |
- def SetAccessToken(self): |
- """Stores the access token from the request on the server. |
+ def SetResponseValue(self): |
+ """Stores the access token or error from the request on the server. |
Will only do this if exactly one query parameter was passed in to the |
- request and that query parameter used 'access_token' as the key. |
+ request and that query parameter used 'access_token' or 'error' as the key. |
""" |
query_string = urlparse.urlparse(self.path).query |
query_params = urlparse.parse_qs(query_string) |
if len(query_params) == 1: |
- access_token_list = query_params.get(ACCESS_TOKEN_PARAM, []) |
- if len(access_token_list) == 1: |
- self.server.access_token = access_token_list[0] |
+ if query_params.has_key(ACCESS_TOKEN_PARAM): |
+ access_token_list = query_params[ACCESS_TOKEN_PARAM] |
+ if len(access_token_list) == 1: |
+ self.server.access_token = access_token_list[0] |
+ else: |
+ error_list = query_params.get(ERROR_PARAM, []) |
+ if len(error_list) == 1: |
+ self.server.error = error_list[0] |
def do_GET(self): |
"""Handle a GET request. |
@@ -710,7 +723,7 @@ class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): |
self.send_response(200) |
self.send_header('Content-type', 'text/html') |
self.end_headers() |
- self.SetAccessToken() |
+ self.SetResponseValue() |
self.wfile.write(AUTH_HANDLER_RESPONSE) |
def log_message(self, format, *args): |
@@ -729,11 +742,23 @@ def OpenOAuth2ConsentPage(server=DEFAULT_REVIEW_SERVER, |
DEFAULT_REVIEW_SERVER. |
port: Integer, the port where the localhost server receiving the redirect |
is serving. Defaults to DEFAULT_OAUTH2_PORT. |
+ |
+ Returns: |
+ A boolean indicating whether the page opened successfully. |
""" |
path = OAUTH_PATH_PORT_TEMPLATE % {'port': port} |
- page = 'https://%s%s' % (server, path) |
- webbrowser.open(page, new=1, autoraise=True) |
- print OPEN_LOCAL_MESSAGE_TEMPLATE % (page,) |
+ parsed_url = urlparse.urlparse(server) |
+ scheme = parsed_url[0] or 'https' |
+ if scheme != 'https': |
+ ErrorExit('Using OAuth requires a review server with SSL enabled.') |
+ # If no scheme was given on command line the server address ends up in |
+ # parsed_url.path otherwise in netloc. |
+ host = parsed_url[1] or parsed_url[2] |
+ page = '%s://%s%s' % (scheme, host, path) |
+ page_opened = webbrowser.open(page, new=1, autoraise=True) |
+ if page_opened: |
+ print OPEN_LOCAL_MESSAGE_TEMPLATE % (page,) |
+ return page_opened |
def WaitForAccessToken(port=DEFAULT_OAUTH2_PORT): |
@@ -754,6 +779,8 @@ def WaitForAccessToken(port=DEFAULT_OAUTH2_PORT): |
# Wait to serve just one request before deferring control back |
# to the caller of wait_for_refresh_token |
httpd.handle_request() |
+ if httpd.access_token is None: |
+ ErrorExit(httpd.error or OAUTH_DEFAULT_ERROR_MESSAGE) |
return httpd.access_token |
@@ -776,11 +803,12 @@ def GetAccessToken(server=DEFAULT_REVIEW_SERVER, port=DEFAULT_OAUTH2_PORT, |
""" |
access_token = None |
if open_local_webbrowser: |
- OpenOAuth2ConsentPage(server=server, port=port) |
- try: |
- access_token = WaitForAccessToken(port=port) |
- except socket.error, e: |
- print 'Can\'t start local webserver. Socket Error: %s\n' % (e.strerror,) |
+ page_opened = OpenOAuth2ConsentPage(server=server, port=port) |
+ if page_opened: |
+ try: |
+ access_token = WaitForAccessToken(port=port) |
+ except socket.error, e: |
+ print 'Can\'t start local webserver. Socket Error: %s\n' % (e.strerror,) |
if access_token is None: |
# TODO(dhermes): Offer to add to clipboard using xsel, xclip, pbcopy, etc. |
@@ -1150,7 +1178,7 @@ class VersionControlSystem(object): |
mimetype = mimetypes.guess_type(filename)[0] |
if not mimetype: |
return False |
- return mimetype.startswith("image/") |
+ return mimetype.startswith("image/") and not mimetype.startswith("image/svg") |
def IsBinaryData(self, data): |
"""Returns true if data contains a null byte.""" |