Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(1318)

Unified Diff: remoting/protocol/channel_multiplexer.cc

Issue 10830046: Implement ChannelMultiplexer. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 8 years, 4 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
« no previous file with comments | « remoting/protocol/channel_multiplexer.h ('k') | remoting/protocol/channel_multiplexer_unittest.cc » ('j') | no next file with comments »
Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
Index: remoting/protocol/channel_multiplexer.cc
diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..71647bfe890952c34f845a0e6ff4bd6ae1c164b3
--- /dev/null
+++ b/remoting/protocol/channel_multiplexer.cc
@@ -0,0 +1,513 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "remoting/protocol/channel_multiplexer.h"
+
+#include <string.h>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/location.h"
+#include "base/stl_util.h"
+#include "net/base/net_errors.h"
+#include "net/socket/stream_socket.h"
+#include "remoting/protocol/util.h"
+
+namespace remoting {
+namespace protocol {
+
+namespace {
+const int kChannelIdUnknown = -1;
+const int kMaxPacketSize = 1024;
+
+class PendingPacket {
+ public:
+ PendingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task)
+ : packet(packet.Pass()),
+ done_task(done_task),
+ pos(0U) {
+ }
+ ~PendingPacket() {
+ done_task.Run();
+ }
+
+ bool is_empty() { return pos >= packet->data().size(); }
+
+ int Read(char* buffer, size_t size) {
+ size = std::min(size, packet->data().size() - pos);
+ memcpy(buffer, packet->data().data() + pos, size);
+ pos += size;
+ return size;
+ }
+
+ private:
+ scoped_ptr<MultiplexPacket> packet;
+ base::Closure done_task;
+ size_t pos;
+
+ DISALLOW_COPY_AND_ASSIGN(PendingPacket);
+};
+
+} // namespace
+
+const char ChannelMultiplexer::kMuxChannelName[] = "mux";
+
+struct ChannelMultiplexer::PendingChannel {
+ PendingChannel(const std::string& name,
+ const StreamChannelCallback& callback)
+ : name(name), callback(callback) {
+ }
+ std::string name;
+ StreamChannelCallback callback;
+};
+
+class ChannelMultiplexer::MuxChannel {
+ public:
+ MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
+ int send_id);
+ ~MuxChannel();
+
+ const std::string& name() { return name_; }
+ int receive_id() { return receive_id_; }
+ void set_receive_id(int id) { receive_id_ = id; }
+
+ // Called by ChannelMultiplexer.
+ scoped_ptr<net::StreamSocket> CreateSocket();
+ void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+ void OnWriteFailed();
+
+ // Called by MuxSocket.
+ void OnSocketDestroyed();
+ bool DoWrite(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task);
+ int DoRead(net::IOBuffer* buffer, int buffer_len);
+
+ private:
+ ChannelMultiplexer* multiplexer_;
+ std::string name_;
+ int send_id_;
+ bool id_sent_;
+ int receive_id_;
+ MuxSocket* socket_;
+ std::list<PendingPacket*> pending_packets_;
+
+ DISALLOW_COPY_AND_ASSIGN(MuxChannel);
+};
+
+class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
+ public base::NonThreadSafe,
+ public base::SupportsWeakPtr<MuxSocket> {
+ public:
+ MuxSocket(MuxChannel* channel);
+ ~MuxSocket();
+
+ void OnWriteComplete();
+ void OnWriteFailed();
+ void OnPacketReceived();
+
+ // net::StreamSocket interface.
+ virtual int Read(net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) OVERRIDE;
+ virtual int Write(net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) OVERRIDE;
+
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+
+ virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual void Disconnect() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual bool IsConnected() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return true;
+ }
+ virtual bool IsConnectedAndIdle() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+ virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net::ERR_FAILED;
+ }
+ virtual const net::BoundNetLog& NetLog() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return net_log_;
+ }
+ virtual void SetSubresourceSpeculation() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual void SetOmniboxSpeculation() OVERRIDE {
+ NOTIMPLEMENTED();
+ }
+ virtual bool WasEverUsed() const OVERRIDE {
+ return true;
+ }
+ virtual bool UsingTCPFastOpen() const OVERRIDE {
+ return false;
+ }
+ virtual int64 NumBytesRead() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return 0;
+ }
+ virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE {
+ NOTIMPLEMENTED();
+ return base::TimeDelta();
+ }
+ virtual bool WasNpnNegotiated() const OVERRIDE {
+ return false;
+ }
+ virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
+ return net::kProtoUnknown;
+ }
+ virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
+ NOTIMPLEMENTED();
+ return false;
+ }
+
+ private:
+ MuxChannel* channel_;
+
+ net::CompletionCallback read_callback_;
+ scoped_refptr<net::IOBuffer> read_buffer_;
+ int read_buffer_size_;
+
+ bool write_pending_;
+ int write_result_;
+ net::CompletionCallback write_callback_;
+
+ net::BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(MuxSocket);
+};
+
+
+ChannelMultiplexer::MuxChannel::MuxChannel(
+ ChannelMultiplexer* multiplexer,
+ const std::string& name,
+ int send_id)
+ : multiplexer_(multiplexer),
+ name_(name),
+ send_id_(send_id),
+ id_sent_(false),
+ receive_id_(kChannelIdUnknown),
+ socket_(NULL) {
+}
+
+ChannelMultiplexer::MuxChannel::~MuxChannel() {
+ // Socket must be destroyed before the channel.
+ DCHECK(!socket_);
+ STLDeleteElements(&pending_packets_);
+}
+
+scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
+ DCHECK(!socket_); // Can't create more than one socket per channel.
+ scoped_ptr<MuxSocket> result(new MuxSocket(this));
+ socket_ = result.get();
+ return result.PassAs<net::StreamSocket>();
+}
+
+void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
+ scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ DCHECK_EQ(packet->channel_id(), receive_id_);
+ if (packet->data().size() > 0) {
+ pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
+ if (socket_) {
+ // Notify the socket that we have more data.
+ socket_->OnPacketReceived();
+ }
+ }
+}
+
+void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
+ if (socket_)
+ socket_->OnWriteFailed();
+}
+
+void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
+ DCHECK(socket_);
+ socket_ = NULL;
+}
+
+bool ChannelMultiplexer::MuxChannel::DoWrite(
+ scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ packet->set_channel_id(send_id_);
+ if (!id_sent_) {
+ packet->set_channel_name(name_);
+ id_sent_ = true;
+ }
+ return multiplexer_->DoWrite(packet.Pass(), done_task);
+}
+
+int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
+ int buffer_len) {
+ int pos = 0;
+ while (buffer_len > 0 && !pending_packets_.empty()) {
+ DCHECK(!pending_packets_.front()->is_empty());
+ int result = pending_packets_.front()->Read(
+ buffer->data() + pos, buffer_len);
+ DCHECK_LE(result, buffer_len);
+ pos += result;
+ buffer_len -= pos;
+ if (pending_packets_.front()->is_empty()) {
+ delete pending_packets_.front();
+ pending_packets_.erase(pending_packets_.begin());
+ }
+ }
+ return pos;
+}
+
+ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
+ : channel_(channel),
+ read_buffer_size_(0),
+ write_pending_(false),
+ write_result_(0) {
+}
+
+ChannelMultiplexer::MuxSocket::~MuxSocket() {
+ channel_->OnSocketDestroyed();
+}
+
+int ChannelMultiplexer::MuxSocket::Read(
+ net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+ DCHECK(read_callback_.is_null());
+
+ int result = channel_->DoRead(buffer, buffer_len);
+ if (result == 0) {
+ read_buffer_ = buffer;
+ read_buffer_size_ = buffer_len;
+ read_callback_ = callback;
+ return net::ERR_IO_PENDING;
+ }
+ return result;
+}
+
+int ChannelMultiplexer::MuxSocket::Write(
+ net::IOBuffer* buffer, int buffer_len,
+ const net::CompletionCallback& callback) {
+ DCHECK(CalledOnValidThread());
+
+ scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
+ size_t size = std::min(kMaxPacketSize, buffer_len);
+ packet->mutable_data()->assign(buffer->data(), size);
+
+ write_pending_ = true;
+ bool result = channel_->DoWrite(packet.Pass(), base::Bind(
+ &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
+
+ if (!result) {
+ // Cannot complete the write, e.g. if the connection has been terminated.
+ return net::ERR_FAILED;
+ }
+
+ // OnWriteComplete() might be called above synchronously.
+ if (write_pending_) {
+ DCHECK(write_callback_.is_null());
+ write_callback_ = callback;
+ write_result_ = size;
+ return net::ERR_IO_PENDING;
+ }
+
+ return size;
+}
+
+void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
+ write_pending_ = false;
+ if (!write_callback_.is_null()) {
+ net::CompletionCallback cb;
+ std::swap(cb, write_callback_);
+ cb.Run(write_result_);
+ }
+}
+
+void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
+ if (!write_callback_.is_null()) {
+ net::CompletionCallback cb;
+ std::swap(cb, write_callback_);
+ cb.Run(net::ERR_FAILED);
+ }
+}
+
+void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
+ if (!read_callback_.is_null()) {
+ int result = channel_->DoRead(read_buffer_, read_buffer_size_);
+ read_buffer_ = NULL;
+ DCHECK_GT(result, 0);
+ net::CompletionCallback cb;
+ std::swap(cb, read_callback_);
+ cb.Run(result);
+ }
+}
+
+ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
+ const std::string& base_channel_name)
+ : base_channel_factory_(factory),
+ base_channel_name_(base_channel_name),
+ next_channel_id_(0),
+ destroyed_flag_(NULL) {
+ factory->CreateStreamChannel(
+ base_channel_name,
+ base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
+ base::Unretained(this)));
+}
+
+ChannelMultiplexer::~ChannelMultiplexer() {
+ DCHECK(pending_channels_.empty());
+ STLDeleteValues(&channels_);
+
+ // Cancel creation of the base channel if it hasn't finished.
+ if (base_channel_factory_)
+ base_channel_factory_->CancelChannelCreation(base_channel_name_);
+
+ if (destroyed_flag_)
+ *destroyed_flag_ = true;
+}
+
+void ChannelMultiplexer::CreateStreamChannel(
+ const std::string& name,
+ const StreamChannelCallback& callback) {
+ if (base_channel_.get()) {
+ // Already have |base_channel_|. Create new multiplexed channel
+ // synchronously.
+ callback.Run(GetOrCreateChannel(name)->CreateSocket());
+ } else if (!base_channel_.get() && !base_channel_factory_) {
+ // Fail synchronously if we failed to create |base_channel_|.
+ callback.Run(scoped_ptr<net::StreamSocket>());
+ } else {
+ // Still waiting for the |base_channel_|.
+ pending_channels_.push_back(PendingChannel(name, callback));
+ }
+}
+
+void ChannelMultiplexer::CreateDatagramChannel(
+ const std::string& name,
+ const DatagramChannelCallback& callback) {
+ NOTIMPLEMENTED();
+ callback.Run(scoped_ptr<net::Socket>());
+}
+
+void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ if (it->name == name) {
+ pending_channels_.erase(it);
+ return;
+ }
+ }
+}
+
+void ChannelMultiplexer::OnBaseChannelReady(
+ scoped_ptr<net::StreamSocket> socket) {
+ base_channel_factory_ = NULL;
+ base_channel_ = socket.Pass();
+
+ if (!base_channel_.get()) {
+ // Notify all callers that we can't create any channels.
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ it->callback.Run(scoped_ptr<net::StreamSocket>());
+ }
+ pending_channels_.clear();
+ return;
+ }
+
+ // Initialize reader and writer.
+ reader_.Init(base_channel_.get(),
+ base::Bind(&ChannelMultiplexer::OnIncomingPacket,
+ base::Unretained(this)));
+ writer_.Init(base_channel_.get(),
+ base::Bind(&ChannelMultiplexer::OnWriteFailed,
+ base::Unretained(this)));
+
+ // Now create all pending channels.
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
+ it != pending_channels_.end(); ++it) {
+ it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket());
+ }
+ pending_channels_.clear();
+}
+
+ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
+ const std::string& name) {
+ // Check if we already have a channel with the requested name.
+ std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
+ if (it != channels_.end())
+ return it->second;
+
+ // Create a new channel if we haven't found existing one.
+ MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
+ ++next_channel_id_;
+ channels_[channel->name()] = channel;
+ return channel;
+}
+
+
+void ChannelMultiplexer::OnWriteFailed(int error) {
+ bool destroyed = false;
+ destroyed_flag_ = &destroyed;
+ for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
+ it != channels_.end(); ++it) {
+ it->second->OnWriteFailed();
+ if (destroyed)
+ return;
+ }
+ destroyed_flag_ = NULL;
+}
+
+void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ if (!packet->has_channel_id()) {
+ LOG(ERROR) << "Received packet without channel_id.";
+ done_task.Run();
+ return;
+ }
+
+ int receive_id = packet->channel_id();
+ MuxChannel* channel = NULL;
+ std::map<int, MuxChannel*>::iterator it =
+ channels_by_receive_id_.find(receive_id);
+ if (it != channels_by_receive_id_.end()) {
+ channel = it->second;
+ } else {
+ // This is a new |channel_id| we haven't seen before. Look it up by name.
+ if (!packet->has_channel_name()) {
+ LOG(ERROR) << "Received packet with unknown channel_id and "
+ "without channel_name.";
+ done_task.Run();
+ return;
+ }
+ channel = GetOrCreateChannel(packet->channel_name());
+ channel->set_receive_id(receive_id);
+ channels_by_receive_id_[receive_id] = channel;
+ }
+
+ channel->OnIncomingPacket(packet.Pass(), done_task);
+}
+
+bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
+ const base::Closure& done_task) {
+ return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
+}
+
+} // namespace protocol
+} // namespace remoting
« no previous file with comments | « remoting/protocol/channel_multiplexer.h ('k') | remoting/protocol/channel_multiplexer_unittest.cc » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698