Index: content/browser/renderer_host/websocket_host.cc |
diff --git a/content/browser/renderer_host/websocket_host.cc b/content/browser/renderer_host/websocket_host.cc |
index d57814512e595d57696089aa6ffdc991b6bf20b1..88c333a77e98e779266cff49dfd21ff695b12bd3 100644 |
--- a/content/browser/renderer_host/websocket_host.cc |
+++ b/content/browser/renderer_host/websocket_host.cc |
@@ -4,20 +4,29 @@ |
#include "content/browser/renderer_host/websocket_host.h" |
+#include <inttypes.h> |
#include <utility> |
+#include "base/bind.h" |
+#include "base/bind_helpers.h" |
#include "base/location.h" |
+#include "base/logging.h" |
#include "base/macros.h" |
-#include "base/memory/weak_ptr.h" |
#include "base/single_thread_task_runner.h" |
#include "base/strings/string_util.h" |
+#include "base/strings/stringprintf.h" |
#include "base/thread_task_runner_handle.h" |
+#include "content/browser/bad_message.h" |
+#include "content/browser/renderer_host/websocket_blob_sender.h" |
#include "content/browser/renderer_host/websocket_dispatcher_host.h" |
#include "content/browser/ssl/ssl_error_handler.h" |
#include "content/browser/ssl/ssl_manager.h" |
#include "content/common/websocket_messages.h" |
+#include "content/public/browser/browser_thread.h" |
#include "content/public/browser/render_frame_host.h" |
+#include "content/public/browser/storage_partition.h" |
#include "ipc/ipc_message_macros.h" |
+#include "net/base/net_errors.h" |
#include "net/http/http_request_headers.h" |
#include "net/http/http_response_headers.h" |
#include "net/http/http_util.h" |
@@ -86,12 +95,41 @@ ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { |
return static_cast<ChannelState>(host_state); |
} |
+// Implementation of WebSocketBlobSender::Channel |
+class SendChannelImpl final : public WebSocketBlobSender::Channel { |
+ public: |
+ explicit SendChannelImpl(net::WebSocketChannel* channel) |
+ : channel_(channel) {} |
+ |
+ // Implementation of WebSocketBlobSender::Channel |
+ size_t GetSendQuota() const override { |
+ return static_cast<size_t>(channel_->current_send_quota()); |
+ } |
+ |
+ ChannelState SendFrame(bool fin, const std::vector<char>& data) override { |
+ int opcode = first_frame_ ? net::WebSocketFrameHeader::kOpCodeBinary |
+ : net::WebSocketFrameHeader::kOpCodeContinuation; |
+ first_frame_ = false; |
+ return channel_->SendFrame(fin, opcode, data); |
+ } |
+ |
+ private: |
+ net::WebSocketChannel* channel_; |
+ bool first_frame_ = true; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(SendChannelImpl); |
+}; |
+ |
+} // namespace |
+ |
// Implementation of net::WebSocketEventInterface. Receives events from our |
// WebSocketChannel object. Each event is translated to an IPC and sent to the |
// renderer or child process via WebSocketDispatcherHost. |
-class WebSocketEventHandler : public net::WebSocketEventInterface { |
+class WebSocketHost::WebSocketEventHandler final |
+ : public net::WebSocketEventInterface { |
public: |
WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, |
+ WebSocketHost* host, |
int routing_id, |
int render_frame_id); |
~WebSocketEventHandler() override; |
@@ -120,7 +158,7 @@ class WebSocketEventHandler : public net::WebSocketEventInterface { |
bool fatal) override; |
private: |
- class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate { |
+ class SSLErrorHandlerDelegate final : public SSLErrorHandler::Delegate { |
public: |
SSLErrorHandlerDelegate( |
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks); |
@@ -140,6 +178,7 @@ class WebSocketEventHandler : public net::WebSocketEventInterface { |
}; |
WebSocketDispatcherHost* const dispatcher_; |
+ WebSocketHost* const host_; |
const int routing_id_; |
const int render_frame_id_; |
scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; |
@@ -147,20 +186,21 @@ class WebSocketEventHandler : public net::WebSocketEventInterface { |
DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); |
}; |
-WebSocketEventHandler::WebSocketEventHandler( |
+WebSocketHost::WebSocketEventHandler::WebSocketEventHandler( |
WebSocketDispatcherHost* dispatcher, |
+ WebSocketHost* host, |
int routing_id, |
int render_frame_id) |
: dispatcher_(dispatcher), |
+ host_(host), |
routing_id_(routing_id), |
- render_frame_id_(render_frame_id) { |
-} |
+ render_frame_id_(render_frame_id) {} |
-WebSocketEventHandler::~WebSocketEventHandler() { |
+WebSocketHost::WebSocketEventHandler::~WebSocketEventHandler() { |
DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; |
} |
-ChannelState WebSocketEventHandler::OnAddChannelResponse( |
+ChannelState WebSocketHost::WebSocketEventHandler::OnAddChannelResponse( |
const std::string& selected_protocol, |
const std::string& extensions) { |
DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse" |
@@ -172,7 +212,7 @@ ChannelState WebSocketEventHandler::OnAddChannelResponse( |
routing_id_, selected_protocol, extensions)); |
} |
-ChannelState WebSocketEventHandler::OnDataFrame( |
+ChannelState WebSocketHost::WebSocketEventHandler::OnDataFrame( |
bool fin, |
net::WebSocketFrameHeader::OpCode type, |
const std::vector<char>& data) { |
@@ -180,27 +220,31 @@ ChannelState WebSocketEventHandler::OnDataFrame( |
<< " routing_id=" << routing_id_ << " fin=" << fin |
<< " type=" << type << " data is " << data.size() << " bytes"; |
- return StateCast(dispatcher_->SendFrame( |
- routing_id_, fin, OpCodeToMessageType(type), data)); |
+ return StateCast(dispatcher_->SendFrame(routing_id_, fin, |
+ OpCodeToMessageType(type), data)); |
} |
-ChannelState WebSocketEventHandler::OnClosingHandshake() { |
+ChannelState WebSocketHost::WebSocketEventHandler::OnClosingHandshake() { |
DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake" |
<< " routing_id=" << routing_id_; |
return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_)); |
} |
-ChannelState WebSocketEventHandler::OnFlowControl(int64_t quota) { |
+ChannelState WebSocketHost::WebSocketEventHandler::OnFlowControl( |
+ int64_t quota) { |
DVLOG(3) << "WebSocketEventHandler::OnFlowControl" |
<< " routing_id=" << routing_id_ << " quota=" << quota; |
+ if (host_->blob_sender_) |
+ host_->blob_sender_->OnNewSendQuota(); |
return StateCast(dispatcher_->SendFlowControl(routing_id_, quota)); |
} |
-ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean, |
- uint16_t code, |
- const std::string& reason) { |
+ChannelState WebSocketHost::WebSocketEventHandler::OnDropChannel( |
+ bool was_clean, |
+ uint16_t code, |
+ const std::string& reason) { |
DVLOG(3) << "WebSocketEventHandler::OnDropChannel" |
<< " routing_id=" << routing_id_ << " was_clean=" << was_clean |
<< " code=" << code << " reason=\"" << reason << "\""; |
@@ -209,15 +253,15 @@ ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean, |
dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason)); |
} |
-ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) { |
+ChannelState WebSocketHost::WebSocketEventHandler::OnFailChannel( |
+ const std::string& message) { |
DVLOG(3) << "WebSocketEventHandler::OnFailChannel" |
- << " routing_id=" << routing_id_ |
- << " message=\"" << message << "\""; |
+ << " routing_id=" << routing_id_ << " message=\"" << message << "\""; |
return StateCast(dispatcher_->NotifyFailure(routing_id_, message)); |
} |
-ChannelState WebSocketEventHandler::OnStartOpeningHandshake( |
+ChannelState WebSocketHost::WebSocketEventHandler::OnStartOpeningHandshake( |
scoped_ptr<net::WebSocketHandshakeRequestInfo> request) { |
bool should_send = dispatcher_->CanReadRawCookies(); |
DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake " |
@@ -237,11 +281,11 @@ ChannelState WebSocketEventHandler::OnStartOpeningHandshake( |
request->headers.ToString(); |
request_to_pass.request_time = request->request_time; |
- return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_, |
- request_to_pass)); |
+ return StateCast( |
+ dispatcher_->NotifyStartOpeningHandshake(routing_id_, request_to_pass)); |
} |
-ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( |
+ChannelState WebSocketHost::WebSocketEventHandler::OnFinishOpeningHandshake( |
scoped_ptr<net::WebSocketHandshakeResponseInfo> response) { |
bool should_send = dispatcher_->CanReadRawCookies(); |
DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " |
@@ -263,11 +307,11 @@ ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( |
response->headers->raw_headers()); |
response_to_pass.response_time = response->response_time; |
- return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_, |
- response_to_pass)); |
+ return StateCast( |
+ dispatcher_->NotifyFinishOpeningHandshake(routing_id_, response_to_pass)); |
} |
-ChannelState WebSocketEventHandler::OnSSLCertificateError( |
+ChannelState WebSocketHost::WebSocketEventHandler::OnSSLCertificateError( |
scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, |
const GURL& url, |
const net::SSLInfo& ssl_info, |
@@ -284,20 +328,21 @@ ChannelState WebSocketEventHandler::OnSSLCertificateError( |
return WebSocketEventInterface::CHANNEL_ALIVE; |
} |
-WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate( |
- scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) |
+WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
+ SSLErrorHandlerDelegate( |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) |
: callbacks_(std::move(callbacks)), weak_ptr_factory_(this) {} |
-WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {} |
+WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
+ ~SSLErrorHandlerDelegate() {} |
base::WeakPtr<SSLErrorHandler::Delegate> |
-WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { |
+WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { |
return weak_ptr_factory_.GetWeakPtr(); |
} |
-void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest( |
- int error, |
- const net::SSLInfo* ssl_info) { |
+void WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
+ CancelSSLRequest(int error, const net::SSLInfo* ssl_info) { |
DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" |
<< " error=" << error |
<< " cert_status=" << (ssl_info ? ssl_info->cert_status |
@@ -305,13 +350,12 @@ void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest( |
callbacks_->CancelSSLRequest(error, ssl_info); |
} |
-void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest() { |
+void WebSocketHost::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
+ ContinueSSLRequest() { |
DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; |
callbacks_->ContinueSSLRequest(); |
} |
-} // namespace |
- |
WebSocketHost::WebSocketHost(int routing_id, |
WebSocketDispatcherHost* dispatcher, |
net::URLRequestContext* url_request_context, |
@@ -337,6 +381,7 @@ bool WebSocketHost::OnMessageReceived(const IPC::Message& message) { |
bool handled = true; |
IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message) |
IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest) |
+ IPC_MESSAGE_HANDLER(WebSocketHostMsg_SendBlob, OnSendBlob) |
IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame) |
IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl) |
IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel) |
@@ -383,7 +428,8 @@ void WebSocketHost::AddChannel( |
DCHECK(!channel_); |
scoped_ptr<net::WebSocketEventInterface> event_interface( |
- new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id)); |
+ new WebSocketEventHandler(dispatcher_, this, routing_id_, |
+ render_frame_id)); |
channel_.reset(new net::WebSocketChannel(std::move(event_interface), |
url_request_context_)); |
@@ -404,6 +450,41 @@ void WebSocketHost::AddChannel( |
// |this| may have been deleted here. |
} |
+void WebSocketHost::OnSendBlob(const std::string& uuid, |
+ uint64_t expected_size) { |
+ DVLOG(3) << "WebSocketHost::OnSendBlob" |
+ << " routing_id=" << routing_id_ << " uuid=" << uuid |
+ << " expected_size=" << expected_size; |
+ |
+ DCHECK(channel_); |
+ if (blob_sender_) { |
+ bad_message::ReceivedBadMessage( |
+ dispatcher_, bad_message::WSH_SEND_BLOB_DURING_BLOB_SEND); |
+ return; |
+ } |
+ blob_sender_.reset(new WebSocketBlobSender( |
+ make_scoped_ptr(new SendChannelImpl(channel_.get())))); |
+ StoragePartition* partition = dispatcher_->storage_partition(); |
+ storage::FileSystemContext* file_system_context = |
+ partition->GetFileSystemContext(); |
+ |
+ net::WebSocketEventInterface::ChannelState channel_state = |
+ net::WebSocketEventInterface::CHANNEL_ALIVE; |
+ |
+ // This use of base::Unretained is safe because the WebSocketBlobSender object |
+ // is owned by this object and will not call it back after destruction. |
+ int rv = blob_sender_->Start( |
+ uuid, expected_size, dispatcher_->blob_storage_context(), |
+ file_system_context, |
+ BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE).get(), |
+ &channel_state, |
+ base::Bind(&WebSocketHost::BlobSendComplete, base::Unretained(this))); |
+ if (channel_state == net::WebSocketEventInterface::CHANNEL_ALIVE && |
+ rv != net::ERR_IO_PENDING) |
+ BlobSendComplete(rv); |
+ // |this| may be destroyed here. |
+} |
+ |
void WebSocketHost::OnSendFrame(bool fin, |
WebSocketMessageType type, |
const std::vector<char>& data) { |
@@ -412,6 +493,11 @@ void WebSocketHost::OnSendFrame(bool fin, |
<< " type=" << type << " data is " << data.size() << " bytes"; |
DCHECK(channel_); |
+ if (blob_sender_) { |
+ bad_message::ReceivedBadMessage( |
+ dispatcher_, bad_message::WSH_SEND_FRAME_DURING_BLOB_SEND); |
+ return; |
+ } |
channel_->SendFrame(fin, MessageTypeToOpCode(type), data); |
} |
@@ -441,16 +527,52 @@ void WebSocketHost::OnDropChannel(bool was_clean, |
// WebSocketChannel is not yet created due to the delay introduced by |
// per-renderer WebSocket throttling. |
WebSocketDispatcherHost::WebSocketHostState result = |
- dispatcher_->DoDropChannel(routing_id_, |
- false, |
- net::kWebSocketErrorAbnormalClosure, |
- ""); |
+ dispatcher_->DoDropChannel(routing_id_, false, |
+ net::kWebSocketErrorAbnormalClosure, ""); |
DCHECK_EQ(WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED, result); |
return; |
} |
+ blob_sender_.reset(); |
// TODO(yhirano): Handle |was_clean| appropriately. |
channel_->StartClosingHandshake(code, reason); |
} |
+void WebSocketHost::BlobSendComplete(int result) { |
+ DVLOG(3) << "WebSocketHost::BlobSendComplete" |
+ << " routing_id=" << routing_id_ |
+ << " result=" << net::ErrorToString(result); |
+ |
+ // All paths through this method must reset blob_sender_, so take ownership |
+ // at the beginning. |
+ scoped_ptr<WebSocketBlobSender> blob_sender(std::move(blob_sender_)); |
+ switch (result) { |
+ case net::OK: |
+ ignore_result(dispatcher_->BlobSendComplete(routing_id_)); |
+ // |this| may be destroyed here. |
+ return; |
+ |
+ case net::ERR_UPLOAD_FILE_CHANGED: { |
+ uint64_t expected_size = blob_sender->expected_size(); |
+ uint64_t actual_size = blob_sender->ActualSize(); |
+ if (expected_size != actual_size) { |
+ ignore_result(dispatcher_->NotifyFailure( |
+ routing_id_, |
+ base::StringPrintf("Blob size mismatch; renderer size = %" PRIu64 |
+ ", browser size = %" PRIu64, |
+ expected_size, actual_size))); |
+ // |this| is destroyed here. |
+ return; |
+ } // else fallthrough |
+ } |
+ |
+ default: |
+ ignore_result(dispatcher_->NotifyFailure( |
+ routing_id_, |
+ "Failed to load Blob: error code = " + net::ErrorToString(result))); |
+ // |this| is destroyed here. |
+ return; |
+ } |
+} |
+ |
} // namespace content |