Index: remoting/protocol/channel_multiplexer.cc |
diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc |
new file mode 100644 |
index 0000000000000000000000000000000000000000..71647bfe890952c34f845a0e6ff4bd6ae1c164b3 |
--- /dev/null |
+++ b/remoting/protocol/channel_multiplexer.cc |
@@ -0,0 +1,513 @@ |
+// Copyright (c) 2012 The Chromium Authors. All rights reserved. |
+// Use of this source code is governed by a BSD-style license that can be |
+// found in the LICENSE file. |
+ |
+#include "remoting/protocol/channel_multiplexer.h" |
+ |
+#include <string.h> |
+ |
+#include "base/bind.h" |
+#include "base/callback.h" |
+#include "base/location.h" |
+#include "base/stl_util.h" |
+#include "net/base/net_errors.h" |
+#include "net/socket/stream_socket.h" |
+#include "remoting/protocol/util.h" |
+ |
+namespace remoting { |
+namespace protocol { |
+ |
+namespace { |
+const int kChannelIdUnknown = -1; |
+const int kMaxPacketSize = 1024; |
+ |
+class PendingPacket { |
+ public: |
+ PendingPacket(scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task) |
+ : packet(packet.Pass()), |
+ done_task(done_task), |
+ pos(0U) { |
+ } |
+ ~PendingPacket() { |
+ done_task.Run(); |
+ } |
+ |
+ bool is_empty() { return pos >= packet->data().size(); } |
+ |
+ int Read(char* buffer, size_t size) { |
+ size = std::min(size, packet->data().size() - pos); |
+ memcpy(buffer, packet->data().data() + pos, size); |
+ pos += size; |
+ return size; |
+ } |
+ |
+ private: |
+ scoped_ptr<MultiplexPacket> packet; |
+ base::Closure done_task; |
+ size_t pos; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(PendingPacket); |
+}; |
+ |
+} // namespace |
+ |
+const char ChannelMultiplexer::kMuxChannelName[] = "mux"; |
+ |
+struct ChannelMultiplexer::PendingChannel { |
+ PendingChannel(const std::string& name, |
+ const StreamChannelCallback& callback) |
+ : name(name), callback(callback) { |
+ } |
+ std::string name; |
+ StreamChannelCallback callback; |
+}; |
+ |
+class ChannelMultiplexer::MuxChannel { |
+ public: |
+ MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, |
+ int send_id); |
+ ~MuxChannel(); |
+ |
+ const std::string& name() { return name_; } |
+ int receive_id() { return receive_id_; } |
+ void set_receive_id(int id) { receive_id_ = id; } |
+ |
+ // Called by ChannelMultiplexer. |
+ scoped_ptr<net::StreamSocket> CreateSocket(); |
+ void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task); |
+ void OnWriteFailed(); |
+ |
+ // Called by MuxSocket. |
+ void OnSocketDestroyed(); |
+ bool DoWrite(scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task); |
+ int DoRead(net::IOBuffer* buffer, int buffer_len); |
+ |
+ private: |
+ ChannelMultiplexer* multiplexer_; |
+ std::string name_; |
+ int send_id_; |
+ bool id_sent_; |
+ int receive_id_; |
+ MuxSocket* socket_; |
+ std::list<PendingPacket*> pending_packets_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(MuxChannel); |
+}; |
+ |
+class ChannelMultiplexer::MuxSocket : public net::StreamSocket, |
+ public base::NonThreadSafe, |
+ public base::SupportsWeakPtr<MuxSocket> { |
+ public: |
+ MuxSocket(MuxChannel* channel); |
+ ~MuxSocket(); |
+ |
+ void OnWriteComplete(); |
+ void OnWriteFailed(); |
+ void OnPacketReceived(); |
+ |
+ // net::StreamSocket interface. |
+ virtual int Read(net::IOBuffer* buffer, int buffer_len, |
+ const net::CompletionCallback& callback) OVERRIDE; |
+ virtual int Write(net::IOBuffer* buffer, int buffer_len, |
+ const net::CompletionCallback& callback) OVERRIDE; |
+ |
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return false; |
+ } |
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return false; |
+ } |
+ |
+ virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return net::ERR_FAILED; |
+ } |
+ virtual void Disconnect() OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ } |
+ virtual bool IsConnected() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return true; |
+ } |
+ virtual bool IsConnectedAndIdle() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return false; |
+ } |
+ virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return net::ERR_FAILED; |
+ } |
+ virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return net::ERR_FAILED; |
+ } |
+ virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return net_log_; |
+ } |
+ virtual void SetSubresourceSpeculation() OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ } |
+ virtual void SetOmniboxSpeculation() OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ } |
+ virtual bool WasEverUsed() const OVERRIDE { |
+ return true; |
+ } |
+ virtual bool UsingTCPFastOpen() const OVERRIDE { |
+ return false; |
+ } |
+ virtual int64 NumBytesRead() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return 0; |
+ } |
+ virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return base::TimeDelta(); |
+ } |
+ virtual bool WasNpnNegotiated() const OVERRIDE { |
+ return false; |
+ } |
+ virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
+ return net::kProtoUnknown; |
+ } |
+ virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return false; |
+ } |
+ |
+ private: |
+ MuxChannel* channel_; |
+ |
+ net::CompletionCallback read_callback_; |
+ scoped_refptr<net::IOBuffer> read_buffer_; |
+ int read_buffer_size_; |
+ |
+ bool write_pending_; |
+ int write_result_; |
+ net::CompletionCallback write_callback_; |
+ |
+ net::BoundNetLog net_log_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(MuxSocket); |
+}; |
+ |
+ |
+ChannelMultiplexer::MuxChannel::MuxChannel( |
+ ChannelMultiplexer* multiplexer, |
+ const std::string& name, |
+ int send_id) |
+ : multiplexer_(multiplexer), |
+ name_(name), |
+ send_id_(send_id), |
+ id_sent_(false), |
+ receive_id_(kChannelIdUnknown), |
+ socket_(NULL) { |
+} |
+ |
+ChannelMultiplexer::MuxChannel::~MuxChannel() { |
+ // Socket must be destroyed before the channel. |
+ DCHECK(!socket_); |
+ STLDeleteElements(&pending_packets_); |
+} |
+ |
+scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { |
+ DCHECK(!socket_); // Can't create more than one socket per channel. |
+ scoped_ptr<MuxSocket> result(new MuxSocket(this)); |
+ socket_ = result.get(); |
+ return result.PassAs<net::StreamSocket>(); |
+} |
+ |
+void ChannelMultiplexer::MuxChannel::OnIncomingPacket( |
+ scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task) { |
+ DCHECK_EQ(packet->channel_id(), receive_id_); |
+ if (packet->data().size() > 0) { |
+ pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); |
+ if (socket_) { |
+ // Notify the socket that we have more data. |
+ socket_->OnPacketReceived(); |
+ } |
+ } |
+} |
+ |
+void ChannelMultiplexer::MuxChannel::OnWriteFailed() { |
+ if (socket_) |
+ socket_->OnWriteFailed(); |
+} |
+ |
+void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { |
+ DCHECK(socket_); |
+ socket_ = NULL; |
+} |
+ |
+bool ChannelMultiplexer::MuxChannel::DoWrite( |
+ scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task) { |
+ packet->set_channel_id(send_id_); |
+ if (!id_sent_) { |
+ packet->set_channel_name(name_); |
+ id_sent_ = true; |
+ } |
+ return multiplexer_->DoWrite(packet.Pass(), done_task); |
+} |
+ |
+int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, |
+ int buffer_len) { |
+ int pos = 0; |
+ while (buffer_len > 0 && !pending_packets_.empty()) { |
+ DCHECK(!pending_packets_.front()->is_empty()); |
+ int result = pending_packets_.front()->Read( |
+ buffer->data() + pos, buffer_len); |
+ DCHECK_LE(result, buffer_len); |
+ pos += result; |
+ buffer_len -= pos; |
+ if (pending_packets_.front()->is_empty()) { |
+ delete pending_packets_.front(); |
+ pending_packets_.erase(pending_packets_.begin()); |
+ } |
+ } |
+ return pos; |
+} |
+ |
+ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) |
+ : channel_(channel), |
+ read_buffer_size_(0), |
+ write_pending_(false), |
+ write_result_(0) { |
+} |
+ |
+ChannelMultiplexer::MuxSocket::~MuxSocket() { |
+ channel_->OnSocketDestroyed(); |
+} |
+ |
+int ChannelMultiplexer::MuxSocket::Read( |
+ net::IOBuffer* buffer, int buffer_len, |
+ const net::CompletionCallback& callback) { |
+ DCHECK(CalledOnValidThread()); |
+ DCHECK(read_callback_.is_null()); |
+ |
+ int result = channel_->DoRead(buffer, buffer_len); |
+ if (result == 0) { |
+ read_buffer_ = buffer; |
+ read_buffer_size_ = buffer_len; |
+ read_callback_ = callback; |
+ return net::ERR_IO_PENDING; |
+ } |
+ return result; |
+} |
+ |
+int ChannelMultiplexer::MuxSocket::Write( |
+ net::IOBuffer* buffer, int buffer_len, |
+ const net::CompletionCallback& callback) { |
+ DCHECK(CalledOnValidThread()); |
+ |
+ scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); |
+ size_t size = std::min(kMaxPacketSize, buffer_len); |
+ packet->mutable_data()->assign(buffer->data(), size); |
+ |
+ write_pending_ = true; |
+ bool result = channel_->DoWrite(packet.Pass(), base::Bind( |
+ &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); |
+ |
+ if (!result) { |
+ // Cannot complete the write, e.g. if the connection has been terminated. |
+ return net::ERR_FAILED; |
+ } |
+ |
+ // OnWriteComplete() might be called above synchronously. |
+ if (write_pending_) { |
+ DCHECK(write_callback_.is_null()); |
+ write_callback_ = callback; |
+ write_result_ = size; |
+ return net::ERR_IO_PENDING; |
+ } |
+ |
+ return size; |
+} |
+ |
+void ChannelMultiplexer::MuxSocket::OnWriteComplete() { |
+ write_pending_ = false; |
+ if (!write_callback_.is_null()) { |
+ net::CompletionCallback cb; |
+ std::swap(cb, write_callback_); |
+ cb.Run(write_result_); |
+ } |
+} |
+ |
+void ChannelMultiplexer::MuxSocket::OnWriteFailed() { |
+ if (!write_callback_.is_null()) { |
+ net::CompletionCallback cb; |
+ std::swap(cb, write_callback_); |
+ cb.Run(net::ERR_FAILED); |
+ } |
+} |
+ |
+void ChannelMultiplexer::MuxSocket::OnPacketReceived() { |
+ if (!read_callback_.is_null()) { |
+ int result = channel_->DoRead(read_buffer_, read_buffer_size_); |
+ read_buffer_ = NULL; |
+ DCHECK_GT(result, 0); |
+ net::CompletionCallback cb; |
+ std::swap(cb, read_callback_); |
+ cb.Run(result); |
+ } |
+} |
+ |
+ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, |
+ const std::string& base_channel_name) |
+ : base_channel_factory_(factory), |
+ base_channel_name_(base_channel_name), |
+ next_channel_id_(0), |
+ destroyed_flag_(NULL) { |
+ factory->CreateStreamChannel( |
+ base_channel_name, |
+ base::Bind(&ChannelMultiplexer::OnBaseChannelReady, |
+ base::Unretained(this))); |
+} |
+ |
+ChannelMultiplexer::~ChannelMultiplexer() { |
+ DCHECK(pending_channels_.empty()); |
+ STLDeleteValues(&channels_); |
+ |
+ // Cancel creation of the base channel if it hasn't finished. |
+ if (base_channel_factory_) |
+ base_channel_factory_->CancelChannelCreation(base_channel_name_); |
+ |
+ if (destroyed_flag_) |
+ *destroyed_flag_ = true; |
+} |
+ |
+void ChannelMultiplexer::CreateStreamChannel( |
+ const std::string& name, |
+ const StreamChannelCallback& callback) { |
+ if (base_channel_.get()) { |
+ // Already have |base_channel_|. Create new multiplexed channel |
+ // synchronously. |
+ callback.Run(GetOrCreateChannel(name)->CreateSocket()); |
+ } else if (!base_channel_.get() && !base_channel_factory_) { |
+ // Fail synchronously if we failed to create |base_channel_|. |
+ callback.Run(scoped_ptr<net::StreamSocket>()); |
+ } else { |
+ // Still waiting for the |base_channel_|. |
+ pending_channels_.push_back(PendingChannel(name, callback)); |
+ } |
+} |
+ |
+void ChannelMultiplexer::CreateDatagramChannel( |
+ const std::string& name, |
+ const DatagramChannelCallback& callback) { |
+ NOTIMPLEMENTED(); |
+ callback.Run(scoped_ptr<net::Socket>()); |
+} |
+ |
+void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { |
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
+ it != pending_channels_.end(); ++it) { |
+ if (it->name == name) { |
+ pending_channels_.erase(it); |
+ return; |
+ } |
+ } |
+} |
+ |
+void ChannelMultiplexer::OnBaseChannelReady( |
+ scoped_ptr<net::StreamSocket> socket) { |
+ base_channel_factory_ = NULL; |
+ base_channel_ = socket.Pass(); |
+ |
+ if (!base_channel_.get()) { |
+ // Notify all callers that we can't create any channels. |
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
+ it != pending_channels_.end(); ++it) { |
+ it->callback.Run(scoped_ptr<net::StreamSocket>()); |
+ } |
+ pending_channels_.clear(); |
+ return; |
+ } |
+ |
+ // Initialize reader and writer. |
+ reader_.Init(base_channel_.get(), |
+ base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
+ base::Unretained(this))); |
+ writer_.Init(base_channel_.get(), |
+ base::Bind(&ChannelMultiplexer::OnWriteFailed, |
+ base::Unretained(this))); |
+ |
+ // Now create all pending channels. |
+ for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
+ it != pending_channels_.end(); ++it) { |
+ it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); |
+ } |
+ pending_channels_.clear(); |
+} |
+ |
+ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( |
+ const std::string& name) { |
+ // Check if we already have a channel with the requested name. |
+ std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
+ if (it != channels_.end()) |
+ return it->second; |
+ |
+ // Create a new channel if we haven't found existing one. |
+ MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
+ ++next_channel_id_; |
+ channels_[channel->name()] = channel; |
+ return channel; |
+} |
+ |
+ |
+void ChannelMultiplexer::OnWriteFailed(int error) { |
+ bool destroyed = false; |
+ destroyed_flag_ = &destroyed; |
+ for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
+ it != channels_.end(); ++it) { |
+ it->second->OnWriteFailed(); |
+ if (destroyed) |
+ return; |
+ } |
+ destroyed_flag_ = NULL; |
+} |
+ |
+void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task) { |
+ if (!packet->has_channel_id()) { |
+ LOG(ERROR) << "Received packet without channel_id."; |
+ done_task.Run(); |
+ return; |
+ } |
+ |
+ int receive_id = packet->channel_id(); |
+ MuxChannel* channel = NULL; |
+ std::map<int, MuxChannel*>::iterator it = |
+ channels_by_receive_id_.find(receive_id); |
+ if (it != channels_by_receive_id_.end()) { |
+ channel = it->second; |
+ } else { |
+ // This is a new |channel_id| we haven't seen before. Look it up by name. |
+ if (!packet->has_channel_name()) { |
+ LOG(ERROR) << "Received packet with unknown channel_id and " |
+ "without channel_name."; |
+ done_task.Run(); |
+ return; |
+ } |
+ channel = GetOrCreateChannel(packet->channel_name()); |
+ channel->set_receive_id(receive_id); |
+ channels_by_receive_id_[receive_id] = channel; |
+ } |
+ |
+ channel->OnIncomingPacket(packet.Pass(), done_task); |
+} |
+ |
+bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, |
+ const base::Closure& done_task) { |
+ return writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
+} |
+ |
+} // namespace protocol |
+} // namespace remoting |