OLD | NEW |
(Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "remoting/protocol/channel_multiplexer.h" |
| 6 |
| 7 #include <string.h> |
| 8 |
| 9 #include "base/bind.h" |
| 10 #include "base/callback.h" |
| 11 #include "base/location.h" |
| 12 #include "base/stl_util.h" |
| 13 #include "net/base/net_errors.h" |
| 14 #include "net/socket/stream_socket.h" |
| 15 #include "remoting/protocol/util.h" |
| 16 |
| 17 namespace remoting { |
| 18 namespace protocol { |
| 19 |
| 20 namespace { |
| 21 const int kChannelIdUnknown = -1; |
| 22 const int kMaxPacketSize = 1024; |
| 23 |
| 24 class PendingPacket { |
| 25 public: |
| 26 PendingPacket(scoped_ptr<MultiplexPacket> packet, |
| 27 const base::Closure& done_task) |
| 28 : packet(packet.Pass()), |
| 29 done_task(done_task), |
| 30 pos(0U) { |
| 31 } |
| 32 ~PendingPacket() { |
| 33 done_task.Run(); |
| 34 } |
| 35 |
| 36 bool is_empty() { return pos >= packet->data().size(); } |
| 37 |
| 38 int Read(char* buffer, size_t size) { |
| 39 size = std::min(size, packet->data().size() - pos); |
| 40 memcpy(buffer, packet->data().data() + pos, size); |
| 41 pos += size; |
| 42 return size; |
| 43 } |
| 44 |
| 45 private: |
| 46 scoped_ptr<MultiplexPacket> packet; |
| 47 base::Closure done_task; |
| 48 size_t pos; |
| 49 |
| 50 DISALLOW_COPY_AND_ASSIGN(PendingPacket); |
| 51 }; |
| 52 |
| 53 } // namespace |
| 54 |
| 55 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; |
| 56 |
| 57 struct ChannelMultiplexer::PendingChannel { |
| 58 PendingChannel(const std::string& name, |
| 59 const StreamChannelCallback& callback) |
| 60 : name(name), callback(callback) { |
| 61 } |
| 62 std::string name; |
| 63 StreamChannelCallback callback; |
| 64 }; |
| 65 |
| 66 class ChannelMultiplexer::MuxChannel { |
| 67 public: |
| 68 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, |
| 69 int send_id); |
| 70 ~MuxChannel(); |
| 71 |
| 72 const std::string& name() { return name_; } |
| 73 int receive_id() { return receive_id_; } |
| 74 void set_receive_id(int id) { receive_id_ = id; } |
| 75 |
| 76 // Called by ChannelMultiplexer. |
| 77 scoped_ptr<net::StreamSocket> CreateSocket(); |
| 78 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
| 79 const base::Closure& done_task); |
| 80 void OnWriteFailed(); |
| 81 |
| 82 // Called by MuxSocket. |
| 83 void OnSocketDestroyed(); |
| 84 bool DoWrite(scoped_ptr<MultiplexPacket> packet, |
| 85 const base::Closure& done_task); |
| 86 int DoRead(net::IOBuffer* buffer, int buffer_len); |
| 87 |
| 88 private: |
| 89 ChannelMultiplexer* multiplexer_; |
| 90 std::string name_; |
| 91 int send_id_; |
| 92 bool id_sent_; |
| 93 int receive_id_; |
| 94 MuxSocket* socket_; |
| 95 std::list<PendingPacket*> pending_packets_; |
| 96 |
| 97 DISALLOW_COPY_AND_ASSIGN(MuxChannel); |
| 98 }; |
| 99 |
| 100 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, |
| 101 public base::NonThreadSafe, |
| 102 public base::SupportsWeakPtr<MuxSocket> { |
| 103 public: |
| 104 MuxSocket(MuxChannel* channel); |
| 105 ~MuxSocket(); |
| 106 |
| 107 void OnWriteComplete(); |
| 108 void OnWriteFailed(); |
| 109 void OnPacketReceived(); |
| 110 |
| 111 // net::StreamSocket interface. |
| 112 virtual int Read(net::IOBuffer* buffer, int buffer_len, |
| 113 const net::CompletionCallback& callback) OVERRIDE; |
| 114 virtual int Write(net::IOBuffer* buffer, int buffer_len, |
| 115 const net::CompletionCallback& callback) OVERRIDE; |
| 116 |
| 117 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
| 118 NOTIMPLEMENTED(); |
| 119 return false; |
| 120 } |
| 121 virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
| 122 NOTIMPLEMENTED(); |
| 123 return false; |
| 124 } |
| 125 |
| 126 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
| 127 NOTIMPLEMENTED(); |
| 128 return net::ERR_FAILED; |
| 129 } |
| 130 virtual void Disconnect() OVERRIDE { |
| 131 NOTIMPLEMENTED(); |
| 132 } |
| 133 virtual bool IsConnected() const OVERRIDE { |
| 134 NOTIMPLEMENTED(); |
| 135 return true; |
| 136 } |
| 137 virtual bool IsConnectedAndIdle() const OVERRIDE { |
| 138 NOTIMPLEMENTED(); |
| 139 return false; |
| 140 } |
| 141 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
| 142 NOTIMPLEMENTED(); |
| 143 return net::ERR_FAILED; |
| 144 } |
| 145 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
| 146 NOTIMPLEMENTED(); |
| 147 return net::ERR_FAILED; |
| 148 } |
| 149 virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
| 150 NOTIMPLEMENTED(); |
| 151 return net_log_; |
| 152 } |
| 153 virtual void SetSubresourceSpeculation() OVERRIDE { |
| 154 NOTIMPLEMENTED(); |
| 155 } |
| 156 virtual void SetOmniboxSpeculation() OVERRIDE { |
| 157 NOTIMPLEMENTED(); |
| 158 } |
| 159 virtual bool WasEverUsed() const OVERRIDE { |
| 160 return true; |
| 161 } |
| 162 virtual bool UsingTCPFastOpen() const OVERRIDE { |
| 163 return false; |
| 164 } |
| 165 virtual int64 NumBytesRead() const OVERRIDE { |
| 166 NOTIMPLEMENTED(); |
| 167 return 0; |
| 168 } |
| 169 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
| 170 NOTIMPLEMENTED(); |
| 171 return base::TimeDelta(); |
| 172 } |
| 173 virtual bool WasNpnNegotiated() const OVERRIDE { |
| 174 return false; |
| 175 } |
| 176 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
| 177 return net::kProtoUnknown; |
| 178 } |
| 179 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
| 180 NOTIMPLEMENTED(); |
| 181 return false; |
| 182 } |
| 183 |
| 184 private: |
| 185 MuxChannel* channel_; |
| 186 |
| 187 net::CompletionCallback read_callback_; |
| 188 scoped_refptr<net::IOBuffer> read_buffer_; |
| 189 int read_buffer_size_; |
| 190 |
| 191 bool write_pending_; |
| 192 int write_result_; |
| 193 net::CompletionCallback write_callback_; |
| 194 |
| 195 net::BoundNetLog net_log_; |
| 196 |
| 197 DISALLOW_COPY_AND_ASSIGN(MuxSocket); |
| 198 }; |
| 199 |
| 200 |
| 201 ChannelMultiplexer::MuxChannel::MuxChannel( |
| 202 ChannelMultiplexer* multiplexer, |
| 203 const std::string& name, |
| 204 int send_id) |
| 205 : multiplexer_(multiplexer), |
| 206 name_(name), |
| 207 send_id_(send_id), |
| 208 id_sent_(false), |
| 209 receive_id_(kChannelIdUnknown), |
| 210 socket_(NULL) { |
| 211 } |
| 212 |
| 213 ChannelMultiplexer::MuxChannel::~MuxChannel() { |
| 214 // Socket must be destroyed before the channel. |
| 215 DCHECK(!socket_); |
| 216 STLDeleteElements(&pending_packets_); |
| 217 } |
| 218 |
| 219 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { |
| 220 DCHECK(!socket_); // Can't create more than one socket per channel. |
| 221 scoped_ptr<MuxSocket> result(new MuxSocket(this)); |
| 222 socket_ = result.get(); |
| 223 return result.PassAs<net::StreamSocket>(); |
| 224 } |
| 225 |
| 226 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( |
| 227 scoped_ptr<MultiplexPacket> packet, |
| 228 const base::Closure& done_task) { |
| 229 DCHECK_EQ(packet->channel_id(), receive_id_); |
| 230 if (packet->data().size() > 0) { |
| 231 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); |
| 232 if (socket_) { |
| 233 // Notify the socket that we have more data. |
| 234 socket_->OnPacketReceived(); |
| 235 } |
| 236 } |
| 237 } |
| 238 |
| 239 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { |
| 240 if (socket_) |
| 241 socket_->OnWriteFailed(); |
| 242 } |
| 243 |
| 244 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { |
| 245 DCHECK(socket_); |
| 246 socket_ = NULL; |
| 247 } |
| 248 |
| 249 bool ChannelMultiplexer::MuxChannel::DoWrite( |
| 250 scoped_ptr<MultiplexPacket> packet, |
| 251 const base::Closure& done_task) { |
| 252 packet->set_channel_id(send_id_); |
| 253 if (!id_sent_) { |
| 254 packet->set_channel_name(name_); |
| 255 id_sent_ = true; |
| 256 } |
| 257 return multiplexer_->DoWrite(packet.Pass(), done_task); |
| 258 } |
| 259 |
| 260 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, |
| 261 int buffer_len) { |
| 262 int pos = 0; |
| 263 while (buffer_len > 0 && !pending_packets_.empty()) { |
| 264 DCHECK(!pending_packets_.front()->is_empty()); |
| 265 int result = pending_packets_.front()->Read( |
| 266 buffer->data() + pos, buffer_len); |
| 267 DCHECK_LE(result, buffer_len); |
| 268 pos += result; |
| 269 buffer_len -= pos; |
| 270 if (pending_packets_.front()->is_empty()) { |
| 271 delete pending_packets_.front(); |
| 272 pending_packets_.erase(pending_packets_.begin()); |
| 273 } |
| 274 } |
| 275 return pos; |
| 276 } |
| 277 |
| 278 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) |
| 279 : channel_(channel), |
| 280 read_buffer_size_(0), |
| 281 write_pending_(false), |
| 282 write_result_(0) { |
| 283 } |
| 284 |
| 285 ChannelMultiplexer::MuxSocket::~MuxSocket() { |
| 286 channel_->OnSocketDestroyed(); |
| 287 } |
| 288 |
| 289 int ChannelMultiplexer::MuxSocket::Read( |
| 290 net::IOBuffer* buffer, int buffer_len, |
| 291 const net::CompletionCallback& callback) { |
| 292 DCHECK(CalledOnValidThread()); |
| 293 DCHECK(read_callback_.is_null()); |
| 294 |
| 295 int result = channel_->DoRead(buffer, buffer_len); |
| 296 if (result == 0) { |
| 297 read_buffer_ = buffer; |
| 298 read_buffer_size_ = buffer_len; |
| 299 read_callback_ = callback; |
| 300 return net::ERR_IO_PENDING; |
| 301 } |
| 302 return result; |
| 303 } |
| 304 |
| 305 int ChannelMultiplexer::MuxSocket::Write( |
| 306 net::IOBuffer* buffer, int buffer_len, |
| 307 const net::CompletionCallback& callback) { |
| 308 DCHECK(CalledOnValidThread()); |
| 309 |
| 310 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); |
| 311 size_t size = std::min(kMaxPacketSize, buffer_len); |
| 312 packet->mutable_data()->assign(buffer->data(), size); |
| 313 |
| 314 write_pending_ = true; |
| 315 bool result = channel_->DoWrite(packet.Pass(), base::Bind( |
| 316 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); |
| 317 |
| 318 if (!result) { |
| 319 // Cannot complete the write, e.g. if the connection has been terminated. |
| 320 return net::ERR_FAILED; |
| 321 } |
| 322 |
| 323 // OnWriteComplete() might be called above synchronously. |
| 324 if (write_pending_) { |
| 325 DCHECK(write_callback_.is_null()); |
| 326 write_callback_ = callback; |
| 327 write_result_ = size; |
| 328 return net::ERR_IO_PENDING; |
| 329 } |
| 330 |
| 331 return size; |
| 332 } |
| 333 |
| 334 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { |
| 335 write_pending_ = false; |
| 336 if (!write_callback_.is_null()) { |
| 337 net::CompletionCallback cb; |
| 338 std::swap(cb, write_callback_); |
| 339 cb.Run(write_result_); |
| 340 } |
| 341 } |
| 342 |
| 343 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { |
| 344 if (!write_callback_.is_null()) { |
| 345 net::CompletionCallback cb; |
| 346 std::swap(cb, write_callback_); |
| 347 cb.Run(net::ERR_FAILED); |
| 348 } |
| 349 } |
| 350 |
| 351 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { |
| 352 if (!read_callback_.is_null()) { |
| 353 int result = channel_->DoRead(read_buffer_, read_buffer_size_); |
| 354 read_buffer_ = NULL; |
| 355 DCHECK_GT(result, 0); |
| 356 net::CompletionCallback cb; |
| 357 std::swap(cb, read_callback_); |
| 358 cb.Run(result); |
| 359 } |
| 360 } |
| 361 |
| 362 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, |
| 363 const std::string& base_channel_name) |
| 364 : base_channel_factory_(factory), |
| 365 base_channel_name_(base_channel_name), |
| 366 next_channel_id_(0), |
| 367 destroyed_flag_(NULL) { |
| 368 factory->CreateStreamChannel( |
| 369 base_channel_name, |
| 370 base::Bind(&ChannelMultiplexer::OnBaseChannelReady, |
| 371 base::Unretained(this))); |
| 372 } |
| 373 |
| 374 ChannelMultiplexer::~ChannelMultiplexer() { |
| 375 DCHECK(pending_channels_.empty()); |
| 376 STLDeleteValues(&channels_); |
| 377 |
| 378 // Cancel creation of the base channel if it hasn't finished. |
| 379 if (base_channel_factory_) |
| 380 base_channel_factory_->CancelChannelCreation(base_channel_name_); |
| 381 |
| 382 if (destroyed_flag_) |
| 383 *destroyed_flag_ = true; |
| 384 } |
| 385 |
| 386 void ChannelMultiplexer::CreateStreamChannel( |
| 387 const std::string& name, |
| 388 const StreamChannelCallback& callback) { |
| 389 if (base_channel_.get()) { |
| 390 // Already have |base_channel_|. Create new multiplexed channel |
| 391 // synchronously. |
| 392 callback.Run(GetOrCreateChannel(name)->CreateSocket()); |
| 393 } else if (!base_channel_.get() && !base_channel_factory_) { |
| 394 // Fail synchronously if we failed to create |base_channel_|. |
| 395 callback.Run(scoped_ptr<net::StreamSocket>()); |
| 396 } else { |
| 397 // Still waiting for the |base_channel_|. |
| 398 pending_channels_.push_back(PendingChannel(name, callback)); |
| 399 } |
| 400 } |
| 401 |
| 402 void ChannelMultiplexer::CreateDatagramChannel( |
| 403 const std::string& name, |
| 404 const DatagramChannelCallback& callback) { |
| 405 NOTIMPLEMENTED(); |
| 406 callback.Run(scoped_ptr<net::Socket>()); |
| 407 } |
| 408 |
| 409 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { |
| 410 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
| 411 it != pending_channels_.end(); ++it) { |
| 412 if (it->name == name) { |
| 413 pending_channels_.erase(it); |
| 414 return; |
| 415 } |
| 416 } |
| 417 } |
| 418 |
| 419 void ChannelMultiplexer::OnBaseChannelReady( |
| 420 scoped_ptr<net::StreamSocket> socket) { |
| 421 base_channel_factory_ = NULL; |
| 422 base_channel_ = socket.Pass(); |
| 423 |
| 424 if (!base_channel_.get()) { |
| 425 // Notify all callers that we can't create any channels. |
| 426 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
| 427 it != pending_channels_.end(); ++it) { |
| 428 it->callback.Run(scoped_ptr<net::StreamSocket>()); |
| 429 } |
| 430 pending_channels_.clear(); |
| 431 return; |
| 432 } |
| 433 |
| 434 // Initialize reader and writer. |
| 435 reader_.Init(base_channel_.get(), |
| 436 base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
| 437 base::Unretained(this))); |
| 438 writer_.Init(base_channel_.get(), |
| 439 base::Bind(&ChannelMultiplexer::OnWriteFailed, |
| 440 base::Unretained(this))); |
| 441 |
| 442 // Now create all pending channels. |
| 443 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
| 444 it != pending_channels_.end(); ++it) { |
| 445 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); |
| 446 } |
| 447 pending_channels_.clear(); |
| 448 } |
| 449 |
| 450 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( |
| 451 const std::string& name) { |
| 452 // Check if we already have a channel with the requested name. |
| 453 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
| 454 if (it != channels_.end()) |
| 455 return it->second; |
| 456 |
| 457 // Create a new channel if we haven't found existing one. |
| 458 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
| 459 ++next_channel_id_; |
| 460 channels_[channel->name()] = channel; |
| 461 return channel; |
| 462 } |
| 463 |
| 464 |
| 465 void ChannelMultiplexer::OnWriteFailed(int error) { |
| 466 bool destroyed = false; |
| 467 destroyed_flag_ = &destroyed; |
| 468 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
| 469 it != channels_.end(); ++it) { |
| 470 it->second->OnWriteFailed(); |
| 471 if (destroyed) |
| 472 return; |
| 473 } |
| 474 destroyed_flag_ = NULL; |
| 475 } |
| 476 |
| 477 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
| 478 const base::Closure& done_task) { |
| 479 if (!packet->has_channel_id()) { |
| 480 LOG(ERROR) << "Received packet without channel_id."; |
| 481 done_task.Run(); |
| 482 return; |
| 483 } |
| 484 |
| 485 int receive_id = packet->channel_id(); |
| 486 MuxChannel* channel = NULL; |
| 487 std::map<int, MuxChannel*>::iterator it = |
| 488 channels_by_receive_id_.find(receive_id); |
| 489 if (it != channels_by_receive_id_.end()) { |
| 490 channel = it->second; |
| 491 } else { |
| 492 // This is a new |channel_id| we haven't seen before. Look it up by name. |
| 493 if (!packet->has_channel_name()) { |
| 494 LOG(ERROR) << "Received packet with unknown channel_id and " |
| 495 "without channel_name."; |
| 496 done_task.Run(); |
| 497 return; |
| 498 } |
| 499 channel = GetOrCreateChannel(packet->channel_name()); |
| 500 channel->set_receive_id(receive_id); |
| 501 channels_by_receive_id_[receive_id] = channel; |
| 502 } |
| 503 |
| 504 channel->OnIncomingPacket(packet.Pass(), done_task); |
| 505 } |
| 506 |
| 507 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, |
| 508 const base::Closure& done_task) { |
| 509 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
| 510 } |
| 511 |
| 512 } // namespace protocol |
| 513 } // namespace remoting |
OLD | NEW |