| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" | 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" |
| 6 | 6 |
| 7 #include "base/base64.h" | 7 #include "base/base64.h" |
| 8 #include "base/compiler_specific.h" | 8 #include "base/compiler_specific.h" |
| 9 #include "base/message_loop.h" | 9 #include "base/message_loop.h" |
| 10 #include "jingle/glue/utils.h" | 10 #include "jingle/glue/utils.h" |
| 11 #include "net/base/address_list.h" | 11 #include "net/base/address_list.h" |
| 12 #include "net/base/cert_verifier.h" | 12 #include "net/base/cert_verifier.h" |
| 13 #include "net/base/host_port_pair.h" | 13 #include "net/base/host_port_pair.h" |
| 14 #include "net/base/net_errors.h" | 14 #include "net/base/net_errors.h" |
| 15 #include "net/base/ssl_config_service.h" | 15 #include "net/base/ssl_config_service.h" |
| 16 #include "net/socket/client_socket_factory.h" | 16 #include "net/socket/client_socket_factory.h" |
| 17 #include "net/url_request/url_request_context.h" | 17 #include "net/url_request/url_request_context.h" |
| 18 | 18 |
| 19 namespace remoting { | 19 namespace remoting { |
| 20 | 20 |
| 21 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { | 21 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { |
| 22 return new SSLSocketAdapter(socket); | 22 return new SSLSocketAdapter(socket); |
| 23 } | 23 } |
| 24 | 24 |
| 25 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) | 25 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) |
| 26 : SSLAdapter(socket), | 26 : SSLAdapter(socket), |
| 27 ignore_bad_cert_(false), | 27 ignore_bad_cert_(false), |
| 28 cert_verifier_(net::CertVerifier::CreateDefault()), | 28 cert_verifier_(net::CertVerifier::CreateDefault()), |
| 29 ssl_state_(SSLSTATE_NONE), | 29 ssl_state_(SSLSTATE_NONE), |
| 30 read_state_(IOSTATE_NONE), | 30 read_pending_(false), |
| 31 write_state_(IOSTATE_NONE), | 31 write_pending_(false) { |
| 32 data_transferred_(0) { | |
| 33 transport_socket_ = new TransportSocket(socket, this); | 32 transport_socket_ = new TransportSocket(socket, this); |
| 34 } | 33 } |
| 35 | 34 |
| 36 SSLSocketAdapter::~SSLSocketAdapter() { | 35 SSLSocketAdapter::~SSLSocketAdapter() { |
| 37 } | 36 } |
| 38 | 37 |
| 39 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { | 38 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { |
| 40 DCHECK(!restartable); | 39 DCHECK(!restartable); |
| 41 hostname_ = hostname; | 40 hostname_ = hostname; |
| 42 | 41 |
| (...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 76 | 75 |
| 77 if (result == net::ERR_IO_PENDING || result == net::OK) { | 76 if (result == net::ERR_IO_PENDING || result == net::OK) { |
| 78 return 0; | 77 return 0; |
| 79 } else { | 78 } else { |
| 80 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); | 79 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); |
| 81 return result; | 80 return result; |
| 82 } | 81 } |
| 83 } | 82 } |
| 84 | 83 |
| 85 int SSLSocketAdapter::Send(const void* buf, size_t len) { | 84 int SSLSocketAdapter::Send(const void* buf, size_t len) { |
| 86 if (ssl_state_ != SSLSTATE_CONNECTED) { | 85 if (ssl_state_ == SSLSTATE_ERROR) { |
| 86 SetError(EINVAL); |
| 87 return -1; |
| 88 } |
| 89 |
| 90 if (ssl_state_ == SSLSTATE_NONE) { |
| 91 // Propagate the call to underlying socket if SSL is not connected |
| 92 // yet (connection is not encrypted until StartSSL() is called). |
| 87 return AsyncSocketAdapter::Send(buf, len); | 93 return AsyncSocketAdapter::Send(buf, len); |
| 88 } else { | 94 } |
| 89 scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); | |
| 90 memcpy(transport_buf->data(), buf, len); | |
| 91 | 95 |
| 92 int result = ssl_socket_->Write(transport_buf, len, | 96 if (write_pending_) { |
| 93 net::CompletionCallback()); | 97 SetError(EWOULDBLOCK); |
| 94 if (result == net::ERR_IO_PENDING) { | 98 return -1; |
| 95 SetError(EWOULDBLOCK); | |
| 96 } | |
| 97 transport_buf = NULL; | |
| 98 return result; | |
| 99 } | 99 } |
| 100 |
| 101 write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len); |
| 102 memcpy(write_buffer_->data(), buf, len); |
| 103 |
| 104 DoWrite(); |
| 105 |
| 106 return len; |
| 100 } | 107 } |
| 101 | 108 |
| 102 int SSLSocketAdapter::Recv(void* buf, size_t len) { | 109 int SSLSocketAdapter::Recv(void* buf, size_t len) { |
| 103 switch (ssl_state_) { | 110 switch (ssl_state_) { |
| 104 case SSLSTATE_NONE: | 111 case SSLSTATE_NONE: { |
| 105 return AsyncSocketAdapter::Recv(buf, len); | 112 return AsyncSocketAdapter::Recv(buf, len); |
| 113 } |
| 106 | 114 |
| 107 case SSLSTATE_WAIT: | 115 case SSLSTATE_WAIT: { |
| 108 SetError(EWOULDBLOCK); | 116 SetError(EWOULDBLOCK); |
| 109 return -1; | 117 return -1; |
| 118 } |
| 110 | 119 |
| 111 case SSLSTATE_CONNECTED: | 120 case SSLSTATE_CONNECTED: { |
| 112 switch (read_state_) { | 121 if (read_pending_) { |
| 113 case IOSTATE_NONE: { | 122 SetError(EWOULDBLOCK); |
| 114 transport_buf_ = new net::IOBuffer(len); | 123 return -1; |
| 115 int result = ssl_socket_->Read( | 124 } |
| 116 transport_buf_, len, | |
| 117 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); | |
| 118 if (result >= 0) { | |
| 119 memcpy(buf, transport_buf_->data(), len); | |
| 120 } | |
| 121 | 125 |
| 122 if (result == net::ERR_IO_PENDING) { | 126 int bytes_read = 0; |
| 123 read_state_ = IOSTATE_PENDING; | 127 |
| 124 SetError(EWOULDBLOCK); | 128 // Process any data we have left from the previous read. |
| 125 } else { | 129 if (read_buffer_) { |
| 126 if (result < 0) { | 130 int size = std::min(read_buffer_->RemainingCapacity(), |
| 127 SetError(result); | 131 static_cast<int>(len)); |
| 128 VLOG(1) << "Socket error " << result; | 132 memcpy(buf, read_buffer_->data(), size); |
| 129 } | 133 read_buffer_->set_offset(read_buffer_->offset() + size); |
| 130 transport_buf_ = NULL; | 134 if (!read_buffer_->RemainingCapacity()) |
| 131 } | 135 read_buffer_ = NULL; |
| 132 return result; | 136 |
| 133 } | 137 if (size == static_cast<int>(len)) |
| 134 case IOSTATE_PENDING: | 138 return size; |
| 139 |
| 140 // If we didn't fill the caller's buffer then dispatch a new |
| 141 // Read() in case there's more data ready. |
| 142 buf = reinterpret_cast<char*>(buf) + size; |
| 143 len -= size; |
| 144 bytes_read = size; |
| 145 DCHECK(!read_buffer_); |
| 146 } |
| 147 |
| 148 // Dispatch a Read() request to the SSL layer. |
| 149 read_buffer_ = new net::GrowableIOBuffer(); |
| 150 read_buffer_->SetCapacity(len); |
| 151 int result = ssl_socket_->Read( |
| 152 read_buffer_, len, |
| 153 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); |
| 154 if (result >= 0) |
| 155 memcpy(buf, read_buffer_->data(), len); |
| 156 |
| 157 if (result == net::ERR_IO_PENDING) { |
| 158 read_pending_ = true; |
| 159 if (bytes_read) { |
| 160 return bytes_read; |
| 161 } else { |
| 135 SetError(EWOULDBLOCK); | 162 SetError(EWOULDBLOCK); |
| 136 return -1; | 163 return -1; |
| 164 } |
| 165 } |
| 137 | 166 |
| 138 case IOSTATE_COMPLETE: | 167 if (result < 0) { |
| 139 memcpy(buf, transport_buf_->data(), len); | 168 SetError(EINVAL); |
| 140 transport_buf_ = NULL; | 169 ssl_state_ = SSLSTATE_ERROR; |
| 141 read_state_ = IOSTATE_NONE; | 170 LOG(ERROR) << "Error reading from SSL socket " << result; |
| 142 return data_transferred_; | 171 return -1; |
| 143 } | 172 } |
| 173 read_buffer_ = NULL; |
| 174 return result + bytes_read; |
| 175 } |
| 176 |
| 177 case SSLSTATE_ERROR: { |
| 178 SetError(EINVAL); |
| 179 return -1; |
| 180 } |
| 144 } | 181 } |
| 145 | 182 |
| 146 NOTREACHED(); | 183 NOTREACHED(); |
| 147 return -1; | 184 return -1; |
| 148 } | 185 } |
| 149 | 186 |
| 150 void SSLSocketAdapter::OnConnected(int result) { | 187 void SSLSocketAdapter::OnConnected(int result) { |
| 151 if (result == net::OK) { | 188 if (result == net::OK) { |
| 152 ssl_state_ = SSLSTATE_CONNECTED; | 189 ssl_state_ = SSLSTATE_CONNECTED; |
| 153 OnConnectEvent(this); | 190 OnConnectEvent(this); |
| 154 } else { | 191 } else { |
| 155 LOG(WARNING) << "OnConnected failed with error " << result; | 192 LOG(WARNING) << "OnConnected failed with error " << result; |
| 156 } | 193 } |
| 157 } | 194 } |
| 158 | 195 |
| 159 void SSLSocketAdapter::OnRead(int result) { | 196 void SSLSocketAdapter::OnRead(int result) { |
| 160 DCHECK(read_state_ == IOSTATE_PENDING); | 197 DCHECK(read_pending_); |
| 161 read_state_ = IOSTATE_COMPLETE; | 198 read_pending_ = false; |
| 162 data_transferred_ = result; | 199 if (result > 0) { |
| 200 DCHECK_GE(read_buffer_->capacity(), result); |
| 201 read_buffer_->SetCapacity(result); |
| 202 } else { |
| 203 if (result < 0) |
| 204 ssl_state_ = SSLSTATE_ERROR; |
| 205 } |
| 163 AsyncSocketAdapter::OnReadEvent(this); | 206 AsyncSocketAdapter::OnReadEvent(this); |
| 164 } | 207 } |
| 165 | 208 |
| 166 void SSLSocketAdapter::OnWrite(int result) { | 209 void SSLSocketAdapter::OnWritten(int result) { |
| 167 DCHECK(write_state_ == IOSTATE_PENDING); | 210 DCHECK(write_pending_); |
| 168 write_state_ = IOSTATE_COMPLETE; | 211 write_pending_ = false; |
| 169 data_transferred_ = result; | 212 if (result >= 0) { |
| 213 write_buffer_->DidConsume(result); |
| 214 if (!write_buffer_->BytesRemaining()) { |
| 215 write_buffer_ = NULL; |
| 216 } else { |
| 217 DoWrite(); |
| 218 } |
| 219 } else { |
| 220 ssl_state_ = SSLSTATE_ERROR; |
| 221 } |
| 170 AsyncSocketAdapter::OnWriteEvent(this); | 222 AsyncSocketAdapter::OnWriteEvent(this); |
| 171 } | 223 } |
| 172 | 224 |
| 225 void SSLSocketAdapter::DoWrite() { |
| 226 DCHECK_GT(write_buffer_->BytesRemaining(), 0); |
| 227 DCHECK(!write_pending_); |
| 228 |
| 229 while (true) { |
| 230 int result = ssl_socket_->Write( |
| 231 write_buffer_, write_buffer_->BytesRemaining(), |
| 232 base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this))); |
| 233 |
| 234 if (result > 0) { |
| 235 write_buffer_->DidConsume(result); |
| 236 if (!write_buffer_->BytesRemaining()) { |
| 237 write_buffer_ = NULL; |
| 238 return; |
| 239 } |
| 240 continue; |
| 241 } |
| 242 |
| 243 if (result == net::ERR_IO_PENDING) { |
| 244 write_pending_ = true; |
| 245 } else { |
| 246 SetError(EINVAL); |
| 247 ssl_state_ = SSLSTATE_ERROR; |
| 248 } |
| 249 return; |
| 250 } |
| 251 } |
| 252 |
| 173 void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { | 253 void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { |
| 174 if (ssl_state_ != SSLSTATE_WAIT) { | 254 if (ssl_state_ != SSLSTATE_WAIT) { |
| 175 AsyncSocketAdapter::OnConnectEvent(socket); | 255 AsyncSocketAdapter::OnConnectEvent(socket); |
| 176 } else { | 256 } else { |
| 177 ssl_state_ = SSLSTATE_NONE; | 257 ssl_state_ = SSLSTATE_NONE; |
| 178 int result = BeginSSL(); | 258 int result = BeginSSL(); |
| 179 if (0 != result) { | 259 if (0 != result) { |
| 180 // TODO(zork): Handle this case gracefully. | 260 // TODO(zork): Handle this case gracefully. |
| 181 LOG(WARNING) << "BeginSSL() failed with " << result; | 261 LOG(WARNING) << "BeginSSL() failed with " << result; |
| 182 } | 262 } |
| (...skipping 183 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 366 write_buffer_len_ = buffer_len; | 446 write_buffer_len_ = buffer_len; |
| 367 return; | 447 return; |
| 368 } | 448 } |
| 369 } | 449 } |
| 370 was_used_to_convey_data_ = true; | 450 was_used_to_convey_data_ = true; |
| 371 callback.Run(result); | 451 callback.Run(result); |
| 372 } | 452 } |
| 373 } | 453 } |
| 374 | 454 |
| 375 } // namespace remoting | 455 } // namespace remoting |
| OLD | NEW |