OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "remoting/protocol/channel_multiplexer.h" | 5 #include "remoting/protocol/channel_multiplexer.h" |
6 | 6 |
7 #include <string.h> | 7 #include <string.h> |
8 | 8 |
9 #include "base/bind.h" | 9 #include "base/bind.h" |
10 #include "base/callback.h" | 10 #include "base/callback.h" |
11 #include "base/location.h" | 11 #include "base/location.h" |
| 12 #include "base/single_thread_task_runner.h" |
12 #include "base/stl_util.h" | 13 #include "base/stl_util.h" |
| 14 #include "base/thread_task_runner_handle.h" |
13 #include "net/base/net_errors.h" | 15 #include "net/base/net_errors.h" |
14 #include "net/socket/stream_socket.h" | 16 #include "net/socket/stream_socket.h" |
15 #include "remoting/protocol/util.h" | 17 #include "remoting/protocol/util.h" |
16 | 18 |
17 namespace remoting { | 19 namespace remoting { |
18 namespace protocol { | 20 namespace protocol { |
19 | 21 |
20 namespace { | 22 namespace { |
21 const int kChannelIdUnknown = -1; | 23 const int kChannelIdUnknown = -1; |
22 const int kMaxPacketSize = 1024; | 24 const int kMaxPacketSize = 1024; |
(...skipping 334 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
357 std::swap(cb, read_callback_); | 359 std::swap(cb, read_callback_); |
358 cb.Run(result); | 360 cb.Run(result); |
359 } | 361 } |
360 } | 362 } |
361 | 363 |
362 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, | 364 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, |
363 const std::string& base_channel_name) | 365 const std::string& base_channel_name) |
364 : base_channel_factory_(factory), | 366 : base_channel_factory_(factory), |
365 base_channel_name_(base_channel_name), | 367 base_channel_name_(base_channel_name), |
366 next_channel_id_(0), | 368 next_channel_id_(0), |
367 destroyed_flag_(NULL) { | 369 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
368 } | 370 } |
369 | 371 |
370 ChannelMultiplexer::~ChannelMultiplexer() { | 372 ChannelMultiplexer::~ChannelMultiplexer() { |
371 DCHECK(pending_channels_.empty()); | 373 DCHECK(pending_channels_.empty()); |
372 STLDeleteValues(&channels_); | 374 STLDeleteValues(&channels_); |
373 | 375 |
374 // Cancel creation of the base channel if it hasn't finished. | 376 // Cancel creation of the base channel if it hasn't finished. |
375 if (base_channel_factory_) | 377 if (base_channel_factory_) |
376 base_channel_factory_->CancelChannelCreation(base_channel_name_); | 378 base_channel_factory_->CancelChannelCreation(base_channel_name_); |
377 | |
378 if (destroyed_flag_) | |
379 *destroyed_flag_ = true; | |
380 } | 379 } |
381 | 380 |
382 void ChannelMultiplexer::CreateStreamChannel( | 381 void ChannelMultiplexer::CreateStreamChannel( |
383 const std::string& name, | 382 const std::string& name, |
384 const StreamChannelCallback& callback) { | 383 const StreamChannelCallback& callback) { |
385 if (base_channel_.get()) { | 384 if (base_channel_.get()) { |
386 // Already have |base_channel_|. Create new multiplexed channel | 385 // Already have |base_channel_|. Create new multiplexed channel |
387 // synchronously. | 386 // synchronously. |
388 callback.Run(GetOrCreateChannel(name)->CreateSocket()); | 387 callback.Run(GetOrCreateChannel(name)->CreateSocket()); |
389 } else if (!base_channel_.get() && !base_channel_factory_) { | 388 } else if (!base_channel_.get() && !base_channel_factory_) { |
(...skipping 28 matching lines...) Expand all Loading... |
418 return; | 417 return; |
419 } | 418 } |
420 } | 419 } |
421 } | 420 } |
422 | 421 |
423 void ChannelMultiplexer::OnBaseChannelReady( | 422 void ChannelMultiplexer::OnBaseChannelReady( |
424 scoped_ptr<net::StreamSocket> socket) { | 423 scoped_ptr<net::StreamSocket> socket) { |
425 base_channel_factory_ = NULL; | 424 base_channel_factory_ = NULL; |
426 base_channel_ = socket.Pass(); | 425 base_channel_ = socket.Pass(); |
427 | 426 |
428 if (!base_channel_.get()) { | 427 if (base_channel_.get()) { |
429 // Notify all callers that we can't create any channels. | 428 // Initialize reader and writer. |
430 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | 429 reader_.Init(base_channel_.get(), |
431 it != pending_channels_.end(); ++it) { | 430 base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
432 it->callback.Run(scoped_ptr<net::StreamSocket>()); | 431 base::Unretained(this))); |
433 } | 432 writer_.Init(base_channel_.get(), |
434 pending_channels_.clear(); | 433 base::Bind(&ChannelMultiplexer::OnWriteFailed, |
435 return; | 434 base::Unretained(this))); |
436 } | 435 } |
437 | 436 |
438 // Initialize reader and writer. | 437 DoCreatePendingChannels(); |
439 reader_.Init(base_channel_.get(), | 438 } |
440 base::Bind(&ChannelMultiplexer::OnIncomingPacket, | |
441 base::Unretained(this))); | |
442 writer_.Init(base_channel_.get(), | |
443 base::Bind(&ChannelMultiplexer::OnWriteFailed, | |
444 base::Unretained(this))); | |
445 | 439 |
446 // Now create all pending channels. | 440 void ChannelMultiplexer::DoCreatePendingChannels() { |
447 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | 441 if (pending_channels_.empty()) |
448 it != pending_channels_.end(); ++it) { | 442 return; |
449 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); | 443 |
450 } | 444 // Every time this function is called it connects a single channel and posts a |
451 pending_channels_.clear(); | 445 // separate task to connect other channels. This is necessary because the |
| 446 // callback may destroy the multiplexer or somehow else modify |
| 447 // |pending_channels_| list (e.g. call CancelChannelCreation()). |
| 448 base::ThreadTaskRunnerHandle::Get()->PostTask( |
| 449 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels, |
| 450 weak_factory_.GetWeakPtr())); |
| 451 |
| 452 PendingChannel c = pending_channels_.front(); |
| 453 pending_channels_.erase(pending_channels_.begin()); |
| 454 scoped_ptr<net::StreamSocket> socket; |
| 455 if (base_channel_.get()) |
| 456 socket = GetOrCreateChannel(c.name)->CreateSocket(); |
| 457 c.callback.Run(socket.Pass()); |
452 } | 458 } |
453 | 459 |
454 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( | 460 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( |
455 const std::string& name) { | 461 const std::string& name) { |
456 // Check if we already have a channel with the requested name. | 462 // Check if we already have a channel with the requested name. |
457 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | 463 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
458 if (it != channels_.end()) | 464 if (it != channels_.end()) |
459 return it->second; | 465 return it->second; |
460 | 466 |
461 // Create a new channel if we haven't found existing one. | 467 // Create a new channel if we haven't found existing one. |
462 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | 468 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
463 ++next_channel_id_; | 469 ++next_channel_id_; |
464 channels_[channel->name()] = channel; | 470 channels_[channel->name()] = channel; |
465 return channel; | 471 return channel; |
466 } | 472 } |
467 | 473 |
468 | 474 |
469 void ChannelMultiplexer::OnWriteFailed(int error) { | 475 void ChannelMultiplexer::OnWriteFailed(int error) { |
470 bool destroyed = false; | |
471 destroyed_flag_ = &destroyed; | |
472 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | 476 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
473 it != channels_.end(); ++it) { | 477 it != channels_.end(); ++it) { |
| 478 base::ThreadTaskRunnerHandle::Get()->PostTask( |
| 479 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, |
| 480 weak_factory_.GetWeakPtr(), it->second->name())); |
| 481 } |
| 482 } |
| 483 |
| 484 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { |
| 485 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
| 486 if (it != channels_.end()) { |
474 it->second->OnWriteFailed(); | 487 it->second->OnWriteFailed(); |
475 if (destroyed) | |
476 return; | |
477 } | 488 } |
478 destroyed_flag_ = NULL; | |
479 } | 489 } |
480 | 490 |
481 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | 491 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
482 const base::Closure& done_task) { | 492 const base::Closure& done_task) { |
483 if (!packet->has_channel_id()) { | 493 if (!packet->has_channel_id()) { |
484 LOG(ERROR) << "Received packet without channel_id."; | 494 LOG(ERROR) << "Received packet without channel_id."; |
485 done_task.Run(); | 495 done_task.Run(); |
486 return; | 496 return; |
487 } | 497 } |
488 | 498 |
(...skipping 19 matching lines...) Expand all Loading... |
508 channel->OnIncomingPacket(packet.Pass(), done_task); | 518 channel->OnIncomingPacket(packet.Pass(), done_task); |
509 } | 519 } |
510 | 520 |
511 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, | 521 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, |
512 const base::Closure& done_task) { | 522 const base::Closure& done_task) { |
513 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); | 523 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
514 } | 524 } |
515 | 525 |
516 } // namespace protocol | 526 } // namespace protocol |
517 } // namespace remoting | 527 } // namespace remoting |
OLD | NEW |