OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/socket/ssl_client_socket.h" | 5 #include "net/socket/ssl_client_socket.h" |
6 | 6 |
| 7 #include "base/callback_helpers.h" |
7 #include "base/memory/ref_counted.h" | 8 #include "base/memory/ref_counted.h" |
8 #include "net/base/address_list.h" | 9 #include "net/base/address_list.h" |
9 #include "net/base/cert_test_util.h" | 10 #include "net/base/cert_test_util.h" |
10 #include "net/base/host_resolver.h" | 11 #include "net/base/host_resolver.h" |
11 #include "net/base/io_buffer.h" | 12 #include "net/base/io_buffer.h" |
12 #include "net/base/mock_cert_verifier.h" | 13 #include "net/base/mock_cert_verifier.h" |
13 #include "net/base/net_errors.h" | 14 #include "net/base/net_errors.h" |
14 #include "net/base/net_log.h" | 15 #include "net/base/net_log.h" |
15 #include "net/base/net_log_unittest.h" | 16 #include "net/base/net_log_unittest.h" |
16 #include "net/base/ssl_cert_request_info.h" | 17 #include "net/base/ssl_cert_request_info.h" |
17 #include "net/base/ssl_config_service.h" | 18 #include "net/base/ssl_config_service.h" |
18 #include "net/base/test_completion_callback.h" | 19 #include "net/base/test_completion_callback.h" |
19 #include "net/base/test_data_directory.h" | 20 #include "net/base/test_data_directory.h" |
20 #include "net/base/test_root_certs.h" | 21 #include "net/base/test_root_certs.h" |
21 #include "net/socket/client_socket_factory.h" | 22 #include "net/socket/client_socket_factory.h" |
22 #include "net/socket/client_socket_handle.h" | 23 #include "net/socket/client_socket_handle.h" |
23 #include "net/socket/socket_test_util.h" | 24 #include "net/socket/socket_test_util.h" |
24 #include "net/socket/tcp_client_socket.h" | 25 #include "net/socket/tcp_client_socket.h" |
25 #include "net/test/test_server.h" | 26 #include "net/test/test_server.h" |
26 #include "testing/gtest/include/gtest/gtest.h" | 27 #include "testing/gtest/include/gtest/gtest.h" |
27 #include "testing/platform_test.h" | 28 #include "testing/platform_test.h" |
28 | 29 |
29 //----------------------------------------------------------------------------- | 30 //----------------------------------------------------------------------------- |
30 | 31 |
| 32 namespace { |
| 33 |
31 const net::SSLConfig kDefaultSSLConfig; | 34 const net::SSLConfig kDefaultSSLConfig; |
32 | 35 |
| 36 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that |
| 37 // will ensure a certain amount of data is internally buffered before |
| 38 // satisfying a Read() request. It exists to mimic OS-level internal |
| 39 // buffering, but in a way to guarantee that X number of bytes will be |
| 40 // returned to callers of Read(), regardless of how quickly the OS receives |
| 41 // them from the TestServer. |
| 42 class ReadBufferingStreamSocket : public net::StreamSocket { |
| 43 public: |
| 44 explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); |
| 45 virtual ~ReadBufferingStreamSocket() {} |
| 46 |
| 47 // Sets the internal buffer to |size|. This must not be greater than |
| 48 // the largest value supplied to Read() - that is, it does not handle |
| 49 // having "leftovers" at the end of Read(). |
| 50 // Each call to Read() will be prevented from completion until at least |
| 51 // |size| data has been read. |
| 52 // Set to 0 to turn off buffering, causing Read() to transparently |
| 53 // read via the underlying transport. |
| 54 void SetBufferSize(int size); |
| 55 |
| 56 // StreamSocket implementation: |
| 57 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
| 58 return transport_->Connect(callback); |
| 59 } |
| 60 virtual void Disconnect() OVERRIDE { |
| 61 transport_->Disconnect(); |
| 62 } |
| 63 virtual bool IsConnected() const OVERRIDE { |
| 64 return transport_->IsConnected(); |
| 65 } |
| 66 virtual bool IsConnectedAndIdle() const OVERRIDE { |
| 67 return transport_->IsConnectedAndIdle(); |
| 68 } |
| 69 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
| 70 return transport_->GetPeerAddress(address); |
| 71 } |
| 72 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
| 73 return transport_->GetLocalAddress(address); |
| 74 } |
| 75 virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
| 76 return transport_->NetLog(); |
| 77 } |
| 78 virtual void SetSubresourceSpeculation() OVERRIDE { |
| 79 transport_->SetSubresourceSpeculation(); |
| 80 } |
| 81 virtual void SetOmniboxSpeculation() OVERRIDE { |
| 82 transport_->SetOmniboxSpeculation(); |
| 83 } |
| 84 virtual bool WasEverUsed() const OVERRIDE { |
| 85 return transport_->WasEverUsed(); |
| 86 } |
| 87 virtual bool UsingTCPFastOpen() const OVERRIDE { |
| 88 return transport_->UsingTCPFastOpen(); |
| 89 } |
| 90 virtual int64 NumBytesRead() const OVERRIDE { |
| 91 return transport_->NumBytesRead(); |
| 92 } |
| 93 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
| 94 return transport_->GetConnectTimeMicros(); |
| 95 } |
| 96 virtual bool WasNpnNegotiated() const OVERRIDE { |
| 97 return transport_->WasNpnNegotiated(); |
| 98 } |
| 99 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
| 100 return transport_->GetNegotiatedProtocol(); |
| 101 } |
| 102 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
| 103 return transport_->GetSSLInfo(ssl_info); |
| 104 } |
| 105 |
| 106 // Socket implementation: |
| 107 virtual int Read(net::IOBuffer* buf, int buf_len, |
| 108 const net::CompletionCallback& callback) OVERRIDE; |
| 109 virtual int Write(net::IOBuffer* buf, int buf_len, |
| 110 const net::CompletionCallback& callback) OVERRIDE { |
| 111 return transport_->Write(buf, buf_len, callback); |
| 112 } |
| 113 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
| 114 return transport_->SetReceiveBufferSize(size); |
| 115 } |
| 116 virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
| 117 return transport_->SetSendBufferSize(size); |
| 118 } |
| 119 |
| 120 private: |
| 121 enum State { |
| 122 STATE_NONE, |
| 123 STATE_READ, |
| 124 STATE_READ_COMPLETE, |
| 125 }; |
| 126 |
| 127 int DoLoop(int result); |
| 128 int DoRead(); |
| 129 int DoReadComplete(int result); |
| 130 void OnReadCompleted(int result); |
| 131 |
| 132 State state_; |
| 133 scoped_ptr<net::StreamSocket> transport_; |
| 134 scoped_refptr<net::GrowableIOBuffer> read_buffer_; |
| 135 int buffer_size_; |
| 136 |
| 137 scoped_refptr<net::IOBuffer> user_read_buf_; |
| 138 net::CompletionCallback user_read_callback_; |
| 139 }; |
| 140 |
| 141 ReadBufferingStreamSocket::ReadBufferingStreamSocket( |
| 142 scoped_ptr<net::StreamSocket> transport) |
| 143 : transport_(transport.Pass()), |
| 144 read_buffer_(new net::GrowableIOBuffer()), |
| 145 buffer_size_(0) { |
| 146 } |
| 147 |
| 148 void ReadBufferingStreamSocket::SetBufferSize(int size) { |
| 149 DCHECK(!user_read_buf_); |
| 150 buffer_size_ = size; |
| 151 read_buffer_->SetCapacity(size); |
| 152 } |
| 153 |
| 154 int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, |
| 155 int buf_len, |
| 156 const net::CompletionCallback& callback) { |
| 157 if (buffer_size_ == 0) |
| 158 return transport_->Read(buf, buf_len, callback); |
| 159 |
| 160 if (buf_len < buffer_size_) |
| 161 return net::ERR_UNEXPECTED; |
| 162 |
| 163 state_ = STATE_READ; |
| 164 user_read_buf_ = buf; |
| 165 int result = DoLoop(net::OK); |
| 166 if (result == net::ERR_IO_PENDING) |
| 167 user_read_callback_ = callback; |
| 168 else |
| 169 user_read_buf_ = NULL; |
| 170 return result; |
| 171 } |
| 172 |
| 173 int ReadBufferingStreamSocket::DoLoop(int result) { |
| 174 int rv = result; |
| 175 do { |
| 176 State current_state = state_; |
| 177 state_ = STATE_NONE; |
| 178 switch (current_state) { |
| 179 case STATE_READ: |
| 180 rv = DoRead(); |
| 181 break; |
| 182 case STATE_READ_COMPLETE: |
| 183 rv = DoReadComplete(rv); |
| 184 break; |
| 185 case STATE_NONE: |
| 186 default: |
| 187 NOTREACHED() << "Unexpected state: " << current_state; |
| 188 rv = net::ERR_UNEXPECTED; |
| 189 break; |
| 190 } |
| 191 } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); |
| 192 return rv; |
| 193 } |
| 194 |
| 195 int ReadBufferingStreamSocket::DoRead() { |
| 196 state_ = STATE_READ_COMPLETE; |
| 197 int rv = transport_->Read( |
| 198 read_buffer_, |
| 199 read_buffer_->RemainingCapacity(), |
| 200 base::Bind(&ReadBufferingStreamSocket::OnReadCompleted, |
| 201 base::Unretained(this))); |
| 202 return rv; |
| 203 } |
| 204 |
| 205 int ReadBufferingStreamSocket::DoReadComplete(int result) { |
| 206 state_ = STATE_NONE; |
| 207 if (result <= 0) |
| 208 return result; |
| 209 |
| 210 read_buffer_->set_offset(read_buffer_->offset() + result); |
| 211 if (read_buffer_->RemainingCapacity() > 0) { |
| 212 state_ = STATE_READ; |
| 213 return net::OK; |
| 214 } |
| 215 |
| 216 memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), |
| 217 read_buffer_->capacity()); |
| 218 read_buffer_->set_offset(0); |
| 219 return read_buffer_->capacity(); |
| 220 } |
| 221 |
| 222 void ReadBufferingStreamSocket::OnReadCompleted(int result) { |
| 223 result = DoLoop(result); |
| 224 if (result == net::ERR_IO_PENDING) |
| 225 return; |
| 226 |
| 227 user_read_buf_ = NULL; |
| 228 base::ResetAndReturn(&user_read_callback_).Run(result); |
| 229 } |
| 230 |
| 231 } // namespace |
| 232 |
33 class SSLClientSocketTest : public PlatformTest { | 233 class SSLClientSocketTest : public PlatformTest { |
34 public: | 234 public: |
35 SSLClientSocketTest() | 235 SSLClientSocketTest() |
36 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), | 236 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), |
37 cert_verifier_(new net::MockCertVerifier) { | 237 cert_verifier_(new net::MockCertVerifier) { |
38 cert_verifier_->set_default_result(net::OK); | 238 cert_verifier_->set_default_result(net::OK); |
39 context_.cert_verifier = cert_verifier_.get(); | 239 context_.cert_verifier = cert_verifier_.get(); |
40 } | 240 } |
41 | 241 |
42 protected: | 242 protected: |
(...skipping 449 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
492 | 692 |
493 if (rv == net::ERR_IO_PENDING) | 693 if (rv == net::ERR_IO_PENDING) |
494 rv = callback.WaitForResult(); | 694 rv = callback.WaitForResult(); |
495 | 695 |
496 EXPECT_GE(rv, 0); | 696 EXPECT_GE(rv, 0); |
497 if (rv <= 0) | 697 if (rv <= 0) |
498 break; | 698 break; |
499 } | 699 } |
500 } | 700 } |
501 | 701 |
| 702 TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { |
| 703 net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
| 704 net::TestServer::kLocalhost, |
| 705 base::FilePath()); |
| 706 ASSERT_TRUE(test_server.Start()); |
| 707 |
| 708 net::AddressList addr; |
| 709 ASSERT_TRUE(test_server.GetAddressList(&addr)); |
| 710 |
| 711 net::TestCompletionCallback callback; |
| 712 |
| 713 scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( |
| 714 addr, NULL, net::NetLog::Source())); |
| 715 ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( |
| 716 real_transport.Pass()); |
| 717 int rv = callback.GetResult(transport->Connect(callback.callback())); |
| 718 ASSERT_EQ(net::OK, rv); |
| 719 |
| 720 scoped_ptr<net::SSLClientSocket> sock( |
| 721 CreateSSLClientSocket(transport, test_server.host_port_pair(), |
| 722 kDefaultSSLConfig)); |
| 723 |
| 724 rv = callback.GetResult(sock->Connect(callback.callback())); |
| 725 ASSERT_EQ(net::OK, rv); |
| 726 ASSERT_TRUE(sock->IsConnected()); |
| 727 |
| 728 const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; |
| 729 scoped_refptr<net::IOBuffer> request_buffer( |
| 730 new net::IOBuffer(arraysize(request_text) - 1)); |
| 731 memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); |
| 732 |
| 733 rv = callback.GetResult(sock->Write( |
| 734 request_buffer, arraysize(request_text) - 1, callback.callback())); |
| 735 ASSERT_GT(rv, 0); |
| 736 ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); |
| 737 |
| 738 // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of |
| 739 // data (the max SSL record size) at a time. Ensure that at least 15K worth |
| 740 // of SSL data is buffered first. The 15K of buffered data is made up of |
| 741 // many smaller SSL records (the TestServer writes along 1350 byte |
| 742 // plaintext boundaries), although there may also be a few records that are |
| 743 // smaller or larger, due to timing and SSL False Start. |
| 744 // 15K was chosen because 15K is smaller than the 17K (max) read issued by |
| 745 // the SSLClientSocket implementation, and larger than the minimum amount |
| 746 // of ciphertext necessary to contain the 8K of plaintext requested below. |
| 747 transport->SetBufferSize(15000); |
| 748 |
| 749 scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); |
| 750 rv = callback.GetResult(sock->Read(buffer, 8192, callback.callback())); |
| 751 ASSERT_EQ(rv, 8192); |
| 752 } |
| 753 |
502 TEST_F(SSLClientSocketTest, Read_Interrupted) { | 754 TEST_F(SSLClientSocketTest, Read_Interrupted) { |
503 net::TestServer test_server(net::TestServer::TYPE_HTTPS, | 755 net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
504 net::TestServer::kLocalhost, | 756 net::TestServer::kLocalhost, |
505 base::FilePath()); | 757 base::FilePath()); |
506 ASSERT_TRUE(test_server.Start()); | 758 ASSERT_TRUE(test_server.Start()); |
507 | 759 |
508 net::AddressList addr; | 760 net::AddressList addr; |
509 ASSERT_TRUE(test_server.GetAddressList(&addr)); | 761 ASSERT_TRUE(test_server.GetAddressList(&addr)); |
510 | 762 |
511 net::TestCompletionCallback callback; | 763 net::TestCompletionCallback callback; |
(...skipping 148 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
660 scoped_ptr<net::SSLClientSocket> sock( | 912 scoped_ptr<net::SSLClientSocket> sock( |
661 CreateSSLClientSocket(transport, test_server.host_port_pair(), | 913 CreateSSLClientSocket(transport, test_server.host_port_pair(), |
662 kDefaultSSLConfig)); | 914 kDefaultSSLConfig)); |
663 | 915 |
664 rv = sock->Connect(callback.callback()); | 916 rv = sock->Connect(callback.callback()); |
665 if (rv == net::ERR_IO_PENDING) | 917 if (rv == net::ERR_IO_PENDING) |
666 rv = callback.WaitForResult(); | 918 rv = callback.WaitForResult(); |
667 EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); | 919 EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); |
668 } | 920 } |
669 | 921 |
670 // TODO(rsleevi): Not implemented for Schannel. As Schannel is only used when | |
671 // performing client authentication, it will not be tested here. | |
672 TEST_F(SSLClientSocketTest, CipherSuiteDisables) { | 922 TEST_F(SSLClientSocketTest, CipherSuiteDisables) { |
673 // Rather than exhaustively disabling every RC4 ciphersuite defined at | 923 // Rather than exhaustively disabling every RC4 ciphersuite defined at |
674 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, | 924 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, |
675 // only disabling those cipher suites that the test server actually | 925 // only disabling those cipher suites that the test server actually |
676 // implements. | 926 // implements. |
677 const uint16 kCiphersToDisable[] = { | 927 const uint16 kCiphersToDisable[] = { |
678 0x0005, // TLS_RSA_WITH_RC4_128_SHA | 928 0x0005, // TLS_RSA_WITH_RC4_128_SHA |
679 }; | 929 }; |
680 | 930 |
681 net::TestServer::SSLOptions ssl_options; | 931 net::TestServer::SSLOptions ssl_options; |
(...skipping 358 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1040 scoped_refptr<net::SSLCertRequestInfo> request_info = | 1290 scoped_refptr<net::SSLCertRequestInfo> request_info = |
1041 GetCertRequest(ssl_options); | 1291 GetCertRequest(ssl_options); |
1042 ASSERT_TRUE(request_info); | 1292 ASSERT_TRUE(request_info); |
1043 ASSERT_EQ(2u, request_info->cert_authorities.size()); | 1293 ASSERT_EQ(2u, request_info->cert_authorities.size()); |
1044 EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), | 1294 EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), |
1045 request_info->cert_authorities[0]); | 1295 request_info->cert_authorities[0]); |
1046 EXPECT_EQ( | 1296 EXPECT_EQ( |
1047 std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), | 1297 std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), |
1048 request_info->cert_authorities[1]); | 1298 request_info->cert_authorities[1]); |
1049 } | 1299 } |
OLD | NEW |