Index: net/socket/ssl_client_socket_openssl.cc |
diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc |
index 7192d850e6a99d5efee473ffd727ca65912d5223..a6f5caf188d96dac7c1a339e732990cb05ebd6f3 100644 |
--- a/net/socket/ssl_client_socket_openssl.cc |
+++ b/net/socket/ssl_client_socket_openssl.cc |
@@ -95,15 +95,6 @@ int GetNetSSLVersion(SSL* ssl) { |
} |
} |
-// Compute a unique key string for the SSL session cache. |socket| is an |
-// input socket object. Return a string. |
-std::string GetSocketSessionCacheKey(const SSLClientSocketOpenSSL& socket) { |
- std::string result = socket.host_and_port().ToString(); |
- result.append("/"); |
- result.append(socket.ssl_session_cache_shard()); |
- return result; |
-} |
- |
void FreeX509Stack(STACK_OF(X509) * ptr) { |
sk_X509_pop_free(ptr, X509_free); |
} |
@@ -164,7 +155,7 @@ class SSLClientSocketOpenSSL::SSLContext { |
static std::string GetSessionCacheKey(const SSL* ssl) { |
SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); |
DCHECK(socket); |
- return GetSocketSessionCacheKey(*socket); |
+ return socket->GetSessionCacheKey(); |
} |
static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig; |
@@ -372,12 +363,24 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( |
next_handshake_state_(STATE_NONE), |
npn_status_(kNextProtoUnsupported), |
channel_id_xtn_negotiated_(false), |
- net_log_(transport_->socket()->NetLog()) {} |
+ net_log_(transport_->socket()->NetLog()) { |
+} |
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { |
Disconnect(); |
} |
+bool SSLClientSocketOpenSSL::InSessionCache() const { |
+ SSLContext* context = SSLContext::GetInstance(); |
+ std::string cache_key = GetSessionCacheKey(); |
+ return context->session_cache()->SSLSessionIsInCache(cache_key); |
+} |
+ |
+void SSLClientSocketOpenSSL::SetHandshakeCompletionCallback( |
+ const base::Closure& callback) { |
+ handshake_completion_callback_ = callback; |
+} |
+ |
void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( |
SSLCertRequestInfo* cert_request_info) { |
cert_request_info->host_and_port = host_and_port_; |
@@ -432,6 +435,14 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { |
return rv; |
} |
+ if (!handshake_completion_callback_.is_null()) { |
+ SSLContext* context = SSLContext::GetInstance(); |
+ context->session_cache()->SetSessionAddedCallback( |
+ ssl_, |
+ base::Bind(&SSLClientSocketOpenSSL::OnHandshakeCompletion, |
+ base::Unretained(this))); |
+ } |
+ |
// Set SSL to client mode. Handshake happens in the loop below. |
SSL_set_connect_state(ssl_); |
@@ -441,13 +452,21 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { |
user_connect_callback_ = callback; |
} else { |
net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); |
+ if (rv < OK) |
+ OnHandshakeCompletion(); |
} |
return rv > OK ? OK : rv; |
} |
void SSLClientSocketOpenSSL::Disconnect() { |
+ // If a handshake was pending (Connect() had been called), notify interested |
+ // parties that it's been aborted now. If the handshake had already |
+ // completed, this is a no-op. |
+ OnHandshakeCompletion(); |
if (ssl_) { |
+ SSLContext* context = SSLContext::GetInstance(); |
+ context->session_cache()->RemoveSessionAddedCallback(ssl_); |
// Calling SSL_shutdown prevents the session from being marked as |
// unresumable. |
SSL_shutdown(ssl_); |
@@ -625,6 +644,11 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf, |
was_ever_used_ = true; |
user_read_buf_ = NULL; |
user_read_buf_len_ = 0; |
+ if (rv <= 0) { |
+ // Failure of a read attempt may indicate a failed false start |
+ // connection. |
+ OnHandshakeCompletion(); |
+ } |
} |
return rv; |
@@ -645,6 +669,11 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, |
was_ever_used_ = true; |
user_write_buf_ = NULL; |
user_write_buf_len_ = 0; |
+ if (rv < 0) { |
+ // Failure of a write attempt may indicate a failed false start |
+ // connection. |
+ OnHandshakeCompletion(); |
+ } |
} |
return rv; |
@@ -673,7 +702,7 @@ int SSLClientSocketOpenSSL::Init() { |
return ERR_UNEXPECTED; |
trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey( |
- ssl_, GetSocketSessionCacheKey(*this)); |
+ ssl_, GetSessionCacheKey()); |
BIO* ssl_bio = NULL; |
// 0 => use default buffer sizes. |
@@ -791,6 +820,11 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) { |
was_ever_used_ = true; |
user_read_buf_ = NULL; |
user_read_buf_len_ = 0; |
+ if (rv <= 0) { |
+ // Failure of a read attempt may indicate a failed false start |
+ // connection. |
+ OnHandshakeCompletion(); |
+ } |
base::ResetAndReturn(&user_read_callback_).Run(rv); |
} |
@@ -801,9 +835,23 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { |
was_ever_used_ = true; |
user_write_buf_ = NULL; |
user_write_buf_len_ = 0; |
+ if (rv < 0) { |
+ // Failure of a write attempt may indicate a failed false start |
+ // connection. |
+ OnHandshakeCompletion(); |
+ } |
base::ResetAndReturn(&user_write_callback_).Run(rv); |
} |
+std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { |
+ return CreateSessionCacheKey(host_and_port_, ssl_session_cache_shard_); |
+} |
+ |
+void SSLClientSocketOpenSSL::OnHandshakeCompletion() { |
+ if (!handshake_completion_callback_.is_null()) |
+ base::ResetAndReturn(&handshake_completion_callback_).Run(); |
+} |
+ |
bool SSLClientSocketOpenSSL::DoTransportIO() { |
bool network_moved = false; |
int rv; |
@@ -996,6 +1044,8 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { |
} |
void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { |
+ if (rv < OK) |
+ OnHandshakeCompletion(); |
if (!user_connect_callback_.is_null()) { |
CompletionCallback c = user_connect_callback_; |
user_connect_callback_.Reset(); |
@@ -1116,6 +1166,7 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { |
rv = OK; // This causes us to stay in the loop. |
} |
} while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); |
+ |
return rv; |
} |
@@ -1217,7 +1268,6 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { |
int SSLClientSocketOpenSSL::DoPayloadWrite() { |
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); |
int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); |
- |
if (rv >= 0) { |
net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, |
user_write_buf_->data()); |