Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(256)

Side by Side Diff: net/socket/web_socket_server_socket.cc

Issue 10894004: Remove WebSocketServerSocket. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Remove LIMIT_VIOLATION as well. Created 8 years, 3 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
« no previous file with comments | « net/socket/web_socket_server_socket.h ('k') | net/socket/web_socket_server_socket_unittest.cc » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
(Empty)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/web_socket_server_socket.h"
6
7 #include <algorithm>
8 #include <deque>
9 #include <limits>
10 #include <map>
11 #include <vector>
12
13 #include "base/basictypes.h"
14 #include "base/bind.h"
15 #include "base/bind_helpers.h"
16 #include "base/logging.h"
17 #include "base/md5.h"
18 #include "base/memory/ref_counted.h"
19 #include "base/memory/scoped_ptr.h"
20 #include "base/memory/weak_ptr.h"
21 #include "base/message_loop.h"
22 #include "base/string_util.h"
23 #include "base/sys_byteorder.h"
24 #include "googleurl/src/gurl.h"
25 #include "net/base/completion_callback.h"
26 #include "net/base/io_buffer.h"
27 #include "net/base/net_errors.h"
28
29 namespace {
30
31 const size_t kHandshakeLimitBytes = 1 << 14;
32
33 const char kCrOctet = '\r';
34 COMPILE_ASSERT(kCrOctet == '\x0d', ASCII);
35 const char kLfOctet = '\n';
36 COMPILE_ASSERT(kLfOctet == '\x0a', ASCII);
37 const char kSpaceOctet = ' ';
38 COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII);
39 const char kCommaOctet = ',';
40 COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII);
41
42 const char kCRLF[] = { kCrOctet, kLfOctet, 0 };
43 const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 };
44
45 const char kPlainHostFieldName[] = "Host";
46 const char kPlainOriginFieldName[] = "Origin";
47 const char kOriginFieldName[] = "Sec-WebSocket-Origin";
48 const char kProtocolFieldName[] = "Sec-WebSocket-Protocol";
49 const char kVersionFieldName[] = "Sec-WebSocket-Version";
50 const char kLocationFieldName[] = "Sec-WebSocket-Location";
51 const char kKey1FieldName[] = "Sec-WebSocket-Key1";
52 const char kKey2FieldName[] = "Sec-WebSocket-Key2";
53
54 int CountSpaces(const std::string& s) {
55 return std::count(s.begin(), s.end(), kSpaceOctet);
56 }
57
58 // Returns true on success.
59 bool FetchDecimalDigits(const std::string& s, uint32* result) {
60 *result = 0;
61 bool got_something = false;
62 for (size_t i = 0; i < s.size(); ++i) {
63 if (IsAsciiDigit(s[i])) {
64 got_something = true;
65 if (*result > std::numeric_limits<uint32>::max() / 10)
66 return false;
67 *result *= 10;
68 int digit = s[i] - '0';
69 if (*result > std::numeric_limits<uint32>::max() - digit)
70 return false;
71 *result += digit;
72 }
73 }
74 return got_something;
75 }
76
77 // Returns number of fetched subprotocols or negative error code.
78 int FetchSubprotocolList(
79 const std::string& s, std::vector<std::string>* subprotocol_list) {
80 subprotocol_list->clear();
81 subprotocol_list->push_back(std::string());
82 for (size_t i = 0; i < s.size(); ++i) {
83 if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet)
84 subprotocol_list->back() += s[i];
85 else if (!subprotocol_list->back().empty()) {
86 if (subprotocol_list->size() < 16)
87 subprotocol_list->push_back(std::string());
88 else
89 return net::ERR_LIMIT_VIOLATION;
90 }
91 }
92 if (subprotocol_list->back().empty())
93 subprotocol_list->pop_back();
94 if (subprotocol_list->empty())
95 return net::ERR_WS_PROTOCOL_ERROR;
96
97 {
98 std::vector<std::string> tmp(*subprotocol_list);
99 std::sort(tmp.begin(), tmp.end());
100 if (tmp.end() != std::unique(tmp.begin(), tmp.end()))
101 return net::ERR_WS_PROTOCOL_ERROR;
102 }
103 return subprotocol_list->size();
104 }
105
106 class WebSocketServerSocketImpl : public net::WebSocketServerSocket {
107 public:
108 WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate)
109 : phase_(PHASE_NYMPH),
110 frame_bytes_remaining_(0),
111 transport_socket_(transport_socket),
112 delegate_(delegate),
113 handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)),
114 fill_handshake_buf_(new net::DrainableIOBuffer(
115 handshake_buf_, kHandshakeLimitBytes)),
116 process_handshake_buf_(new net::DrainableIOBuffer(
117 handshake_buf_, kHandshakeLimitBytes)),
118 is_transport_read_pending_(false),
119 is_transport_write_pending_(false),
120 weak_factory_(this) {
121 DCHECK(transport_socket);
122 DCHECK(delegate);
123 }
124
125 virtual ~WebSocketServerSocketImpl() {
126 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
127 if (it != pending_reqs_.end() &&
128 it->type == PendingReq::TYPE_READ &&
129 it->io_buf != NULL &&
130 it->io_buf->data() != NULL &&
131 !it->callback.is_null()) {
132 it->callback.Run(0); // Report EOF.
133 }
134 }
135
136 private:
137 enum Phase {
138 // Before Accept() is called.
139 PHASE_NYMPH,
140
141 // After Accept() is called and until handshake success/fail.
142 PHASE_HANDSHAKE,
143
144 // Processing data stream.
145 PHASE_FRAME_OUTSIDE, // Outside data frame.
146 PHASE_FRAME_INSIDE, // Inside text frame.
147 PHASE_FRAME_LENGTH, // Reading length of binary frame.
148 PHASE_FRAME_SKIP, // Skipping binary frame.
149
150 // After termination.
151 PHASE_SHUT
152 };
153
154 struct PendingReq {
155 enum Type {
156 // Frame delimiters or handshake (as opposed to user data).
157 TYPE_METADATA = 1 << 0,
158 // Read request.
159 TYPE_READ = 1 << 1,
160 // Write request.
161 TYPE_WRITE = 1 << 2,
162
163 TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA,
164 TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA
165 };
166
167 PendingReq(Type type, net::DrainableIOBuffer* io_buf,
168 const net::CompletionCallback& callback)
169 : type(type),
170 io_buf(io_buf),
171 callback(callback) {
172 switch (type) {
173 case PendingReq::TYPE_READ:
174 case PendingReq::TYPE_WRITE:
175 case PendingReq::TYPE_READ_METADATA:
176 case PendingReq::TYPE_WRITE_METADATA: {
177 DCHECK(io_buf);
178 break;
179 }
180 default: {
181 NOTREACHED();
182 break;
183 }
184 }
185 }
186
187 Type type;
188 scoped_refptr<net::DrainableIOBuffer> io_buf;
189 net::CompletionCallback callback;
190 };
191
192 // Socket implementation.
193 virtual int Read(net::IOBuffer* buf, int buf_len,
194 const net::CompletionCallback& callback) OVERRIDE {
195 if (buf_len == 0)
196 return 0;
197 if (buf == NULL || buf_len < 0) {
198 NOTREACHED();
199 return net::ERR_INVALID_ARGUMENT;
200 }
201 while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() -
202 process_handshake_buf_->BytesConsumed()) {
203 DCHECK(!is_transport_read_pending_);
204 DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end());
205 switch (phase_) {
206 case PHASE_FRAME_OUTSIDE:
207 case PHASE_FRAME_INSIDE:
208 case PHASE_FRAME_LENGTH:
209 case PHASE_FRAME_SKIP: {
210 int n = std::min(bytes_remaining, buf_len);
211 int rv = ProcessDataFrames(
212 process_handshake_buf_->data(), n, buf->data(), buf_len);
213 process_handshake_buf_->DidConsume(n);
214 if (rv == 0) {
215 // ProcessDataFrames may return zero for non-empty buffer if it
216 // contains only frame delimiters without real data. In this case:
217 // try again and do not just return zero (zero stands for EOF).
218 continue;
219 }
220 return rv;
221 }
222 case PHASE_SHUT: {
223 return 0;
224 }
225 case PHASE_NYMPH:
226 case PHASE_HANDSHAKE:
227 default: {
228 NOTREACHED();
229 return net::ERR_UNEXPECTED;
230 }
231 }
232 }
233 switch (phase_) {
234 case PHASE_FRAME_OUTSIDE:
235 case PHASE_FRAME_INSIDE:
236 case PHASE_FRAME_LENGTH:
237 case PHASE_FRAME_SKIP: {
238 pending_reqs_.push_back(PendingReq(
239 PendingReq::TYPE_READ,
240 new net::DrainableIOBuffer(buf, buf_len),
241 callback));
242 ConsiderTransportRead();
243 break;
244 }
245 case PHASE_SHUT: {
246 return 0;
247 }
248 case PHASE_NYMPH:
249 case PHASE_HANDSHAKE:
250 default: {
251 NOTREACHED();
252 return net::ERR_UNEXPECTED;
253 }
254 }
255 return net::ERR_IO_PENDING;
256 }
257
258 virtual int Write(net::IOBuffer* buf, int buf_len,
259 const net::CompletionCallback& callback) OVERRIDE {
260 if (buf_len == 0)
261 return 0;
262 if (buf == NULL || buf_len < 0) {
263 NOTREACHED();
264 return net::ERR_INVALID_ARGUMENT;
265 }
266 DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'),
267 buf->data() + buf_len);
268 switch (phase_) {
269 case PHASE_FRAME_OUTSIDE:
270 case PHASE_FRAME_INSIDE:
271 case PHASE_FRAME_LENGTH:
272 case PHASE_FRAME_SKIP: {
273 break;
274 }
275 case PHASE_SHUT: {
276 return net::ERR_SOCKET_NOT_CONNECTED;
277 }
278 case PHASE_NYMPH:
279 case PHASE_HANDSHAKE:
280 default: {
281 NOTREACHED();
282 return net::ERR_UNEXPECTED;
283 }
284 }
285
286 net::IOBuffer* frame_start = new net::IOBuffer(1);
287 frame_start->data()[0] = '\x00';
288 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA,
289 new net::DrainableIOBuffer(frame_start, 1),
290 net::CompletionCallback()));
291
292 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE,
293 new net::DrainableIOBuffer(buf, buf_len),
294 callback));
295
296 net::IOBuffer* frame_end = new net::IOBuffer(1);
297 frame_end->data()[0] = '\xff';
298 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA,
299 new net::DrainableIOBuffer(frame_end, 1),
300 net::CompletionCallback()));
301
302 ConsiderTransportWrite();
303 return net::ERR_IO_PENDING;
304 }
305
306 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
307 return transport_socket_->SetReceiveBufferSize(size);
308 }
309
310 virtual bool SetSendBufferSize(int32 size) OVERRIDE {
311 return transport_socket_->SetSendBufferSize(size);
312 }
313
314 // WebSocketServerSocket implementation.
315 virtual int Accept(const net::CompletionCallback& callback) OVERRIDE {
316 if (phase_ != PHASE_NYMPH)
317 return net::ERR_UNEXPECTED;
318 phase_ = PHASE_HANDSHAKE;
319 pending_reqs_.push_front(PendingReq(
320 PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback));
321 ConsiderTransportRead();
322 return net::ERR_IO_PENDING;
323 }
324
325 std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) {
326 for (std::deque<PendingReq>::iterator it = pending_reqs_.begin();
327 it != pending_reqs_.end(); ++it) {
328 if (it->type & type)
329 return it;
330 }
331 return pending_reqs_.end();
332 }
333
334 void ConsiderTransportRead() {
335 if (pending_reqs_.empty())
336 return;
337 if (is_transport_read_pending_)
338 return;
339 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
340 if (it == pending_reqs_.end())
341 return;
342 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) {
343 NOTREACHED();
344 return;
345 }
346 is_transport_read_pending_ = true;
347 int rv = transport_socket_->Read(
348 it->io_buf.get(), it->io_buf->BytesRemaining(),
349 base::Bind(&WebSocketServerSocketImpl::OnRead,
350 base::Unretained(this)));
351 if (rv != net::ERR_IO_PENDING) {
352 // PostTask rather than direct call in order to:
353 // (1) guarantee calling callback after returning from Read();
354 // (2) avoid potential stack overflow;
355 MessageLoop::current()->PostTask(
356 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnRead,
357 weak_factory_.GetWeakPtr(), rv));
358 }
359 }
360
361 void ConsiderTransportWrite() {
362 if (is_transport_write_pending_)
363 return;
364 if (pending_reqs_.empty())
365 return;
366 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE);
367 if (it == pending_reqs_.end())
368 return;
369 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) {
370 NOTREACHED();
371 Shut(net::ERR_UNEXPECTED);
372 return;
373 }
374 is_transport_write_pending_ = true;
375 int rv = transport_socket_->Write(
376 it->io_buf.get(), it->io_buf->BytesRemaining(),
377 base::Bind(&WebSocketServerSocketImpl::OnWrite,
378 base::Unretained(this)));
379 if (rv != net::ERR_IO_PENDING) {
380 // PostTask rather than direct call in order to:
381 // (1) guarantee calling callback after returning from Read();
382 // (2) avoid potential stack overflow;
383 MessageLoop::current()->PostTask(
384 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnWrite,
385 weak_factory_.GetWeakPtr(), rv));
386 }
387 }
388
389 void Shut(int result) {
390 if (result > 0 || result == net::ERR_IO_PENDING)
391 result = net::ERR_UNEXPECTED;
392 if (result != 0) {
393 while (!pending_reqs_.empty()) {
394 PendingReq& req = pending_reqs_.front();
395 if (!req.callback.is_null())
396 req.callback.Run(result);
397 pending_reqs_.pop_front();
398 }
399 transport_socket_.reset(); // terminate underlying connection.
400 }
401 phase_ = PHASE_SHUT;
402 }
403
404 // Callbacks for transport socket.
405 void OnRead(int result) {
406 if (!is_transport_read_pending_) {
407 NOTREACHED();
408 Shut(net::ERR_UNEXPECTED);
409 return;
410 }
411 is_transport_read_pending_ = false;
412
413 if (result <= 0) {
414 Shut(result);
415 return;
416 }
417
418 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ);
419 if (it == pending_reqs_.end() ||
420 it->io_buf == NULL ||
421 it->io_buf->data() == NULL) {
422 NOTREACHED();
423 Shut(net::ERR_UNEXPECTED);
424 return;
425 }
426 if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) {
427 NOTREACHED();
428 Shut(net::ERR_UNEXPECTED);
429 return;
430 }
431
432 switch (phase_) {
433 case PHASE_HANDSHAKE: {
434 if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) {
435 NOTREACHED();
436 Shut(net::ERR_UNEXPECTED);
437 return;
438 }
439 fill_handshake_buf_->DidConsume(result);
440 // ProcessHandshake invalidates iterators for |pending_reqs_|
441 int rv = ProcessHandshake();
442 if (rv > 0) {
443 process_handshake_buf_->DidConsume(rv);
444 phase_ = PHASE_FRAME_OUTSIDE;
445 net::CompletionCallback cb = pending_reqs_.front().callback;
446 pending_reqs_.pop_front();
447 ConsiderTransportWrite(); // Schedule answer handshake.
448 if (!cb.is_null())
449 cb.Run(0);
450 } else if (rv == net::ERR_IO_PENDING) {
451 if (fill_handshake_buf_->BytesRemaining() < 1)
452 Shut(net::ERR_LIMIT_VIOLATION);
453 } else if (rv < 0) {
454 Shut(rv);
455 } else {
456 Shut(net::ERR_UNEXPECTED);
457 }
458 break;
459 }
460 case PHASE_FRAME_OUTSIDE:
461 case PHASE_FRAME_INSIDE:
462 case PHASE_FRAME_LENGTH:
463 case PHASE_FRAME_SKIP: {
464 int rv = ProcessDataFrames(
465 it->io_buf->data(), result,
466 it->io_buf->data(), it->io_buf->BytesRemaining());
467 if (rv < 0) {
468 Shut(rv);
469 return;
470 }
471 if (rv > 0 || phase_ == PHASE_SHUT) {
472 net::CompletionCallback cb = it->callback;
473 pending_reqs_.erase(it);
474 if (!cb.is_null())
475 cb.Run(rv);
476 }
477 break;
478 }
479 case PHASE_NYMPH:
480 default: {
481 NOTREACHED();
482 Shut(net::ERR_UNEXPECTED);
483 break;
484 }
485 }
486 ConsiderTransportRead();
487 }
488
489 void OnWrite(int result) {
490 if (!is_transport_write_pending_) {
491 NOTREACHED();
492 Shut(net::ERR_UNEXPECTED);
493 return;
494 }
495 is_transport_write_pending_ = false;
496
497 if (result < 0) {
498 Shut(result);
499 return;
500 }
501
502 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE);
503 if (it == pending_reqs_.end() ||
504 it->io_buf == NULL ||
505 it->io_buf->data() == NULL) {
506 NOTREACHED();
507 Shut(net::ERR_UNEXPECTED);
508 return;
509 }
510 DCHECK_LE(result, it->io_buf->BytesRemaining());
511 it->io_buf->DidConsume(result);
512 if (it->io_buf->BytesRemaining() == 0) {
513 net::CompletionCallback cb = it->callback;
514 int bytes_written = it->io_buf->BytesConsumed();
515 DCHECK_GT(bytes_written, 0);
516 pending_reqs_.erase(it);
517 if (!cb.is_null())
518 cb.Run(bytes_written);
519 }
520 ConsiderTransportWrite();
521 }
522
523 // Returns (positive) number of consumed bytes on success.
524 // Returns ERR_IO_PENDING in case of incomplete input.
525 // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to
526 // reasonably parse input.
527 int ProcessHandshake() {
528 static const char kGetPrefix[] = "GET ";
529 static const char kKeyValueDelimiter[] = ": ";
530
531 class Fields {
532 public:
533 bool Has(const std::string& name) {
534 return map_.find(StringToLowerASCII(name)) != map_.end();
535 }
536
537 std::string Get(const std::string& name) {
538 return Has(name) ? map_[StringToLowerASCII(name)] : std::string();
539 }
540
541 void Set(const std::string& name, const std::string& value) {
542 map_[StringToLowerASCII(name)] = StringToLowerASCII(value);
543 }
544
545 private:
546 std::map<std::string, std::string> map_;
547 } fields;
548
549 char* buf = process_handshake_buf_->data();
550 size_t buf_size = fill_handshake_buf_->BytesConsumed();
551
552 if (buf_size < 1)
553 return net::ERR_IO_PENDING;
554 if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)),
555 kGetPrefix)) {
556 // Data head does not match what is expected.
557 return net::ERR_WS_PROTOCOL_ERROR;
558 }
559 if (buf_size >= kHandshakeLimitBytes)
560 return net::ERR_LIMIT_VIOLATION;
561 char* buf_end = buf + buf_size;
562
563 if (buf_size < strlen(kGetPrefix))
564 return net::ERR_IO_PENDING;
565 char* resource_begin = buf + strlen(kGetPrefix);
566 char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet);
567 if (resource_end == buf_end)
568 return net::ERR_IO_PENDING;
569 std::string resource(resource_begin, resource_end);
570 if (!IsStringUTF8(resource) ||
571 resource.find_first_of(kCRLF) != std::string::npos) {
572 return net::ERR_WS_PROTOCOL_ERROR;
573 }
574 char* term_pos = std::search(
575 buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF));
576 char key3[8]; // Notation (key3) matches websocket RFC.
577 size_t message_len = buf_end - term_pos;
578 if (message_len < sizeof(key3) + strlen(kCRLFCRLF))
579 return net::ERR_IO_PENDING;
580 term_pos += strlen(kCRLFCRLF);
581 memcpy(key3, term_pos, sizeof(key3));
582 term_pos += sizeof(key3);
583 // First line is "GET resource" line, so skip it.
584 char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF));
585 if (pos == term_pos)
586 return net::ERR_WS_PROTOCOL_ERROR;
587 for (;;) {
588 pos += strlen(kCRLF);
589 if (term_pos - pos <
590 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) {
591 return net::ERR_WS_PROTOCOL_ERROR;
592 }
593 if (term_pos - pos ==
594 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) {
595 break;
596 }
597 char* next_pos = std::search(
598 pos, term_pos, kKeyValueDelimiter,
599 kKeyValueDelimiter + strlen(kKeyValueDelimiter));
600 if (next_pos == term_pos)
601 return net::ERR_WS_PROTOCOL_ERROR;
602 std::string key(pos, next_pos);
603 if (!IsStringASCII(key) ||
604 key.find_first_of(kCRLF) != std::string::npos) {
605 return net::ERR_WS_PROTOCOL_ERROR;
606 }
607 pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos,
608 kCRLF, kCRLF + strlen(kCRLF));
609 if (pos == term_pos)
610 return net::ERR_WS_PROTOCOL_ERROR;
611 if (!key.empty()) {
612 std::string value(next_pos, pos);
613 if (!IsStringASCII(value) ||
614 value.find_first_of(kCRLF) != std::string::npos) {
615 return net::ERR_WS_PROTOCOL_ERROR;
616 }
617 fields.Set(key, value);
618 }
619 }
620
621 // Values of Upgrade and Connection fields are hardcoded in the protocol.
622 if (fields.Get("Upgrade") != "websocket" ||
623 fields.Get("Connection") != "upgrade") {
624 return net::ERR_WS_PROTOCOL_ERROR;
625 }
626 if (fields.Has(kVersionFieldName)) {
627 NOTIMPLEMENTED(); // new protocol.
628 return net::ERR_NOT_IMPLEMENTED;
629 }
630
631 if (!fields.Has(kPlainOriginFieldName))
632 return net::ERR_CONNECTION_REFUSED;
633 // Normalize (e.g. w.r.t. leading slashes) origin.
634 GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin();
635 if (!origin.is_valid())
636 return net::ERR_WS_PROTOCOL_ERROR;
637 std::string normalized_origin = origin.spec();
638
639 if (!fields.Has(kPlainHostFieldName))
640 return net::ERR_CONNECTION_REFUSED;
641
642 std::vector<std::string> subprotocol_list;
643 if (fields.Has(kProtocolFieldName)) {
644 int rv = FetchSubprotocolList(
645 fields.Get(kProtocolFieldName), &subprotocol_list);
646 if (rv < 0)
647 return rv;
648 DCHECK(subprotocol_list.end() == std::find(
649 subprotocol_list.begin(), subprotocol_list.end(), ""));
650 }
651
652 std::string location;
653 std::string subprotocol;
654 if (!delegate_->ValidateWebSocket(resource,
655 normalized_origin,
656 fields.Get(kPlainHostFieldName),
657 subprotocol_list,
658 &location,
659 &subprotocol)) {
660 return net::ERR_CONNECTION_REFUSED;
661 }
662 if (subprotocol_list.empty()) {
663 DCHECK(subprotocol.empty());
664 } else {
665 if (!subprotocol.empty()) {
666 if (subprotocol_list.end() == std::find(
667 subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) {
668 NOTREACHED() << "delegate must pick subprotocol from given list";
669 return net::ERR_UNEXPECTED;
670 }
671 }
672 }
673
674 uint32 key_number1 = 0;
675 uint32 key_number2 = 0;
676 if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) ||
677 !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) {
678 return net::ERR_WS_PROTOCOL_ERROR;
679 }
680
681 // We limit incoming header size so following numbers shall not be too high.
682 int spaces1 = CountSpaces(fields.Get(kKey1FieldName));
683 int spaces2 = CountSpaces(fields.Get(kKey2FieldName));
684 if (spaces1 == 0 ||
685 spaces2 == 0 ||
686 key_number1 % spaces1 != 0 ||
687 key_number2 % spaces2 != 0) {
688 return net::ERR_WS_PROTOCOL_ERROR;
689 }
690
691 char challenge[4 + 4 + sizeof(key3)];
692 int32 part1 = base::HostToNet32(key_number1 / spaces1);
693 int32 part2 = base::HostToNet32(key_number2 / spaces2);
694 memcpy(challenge, &part1, 4);
695 memcpy(challenge + 4, &part2, 4);
696 memcpy(challenge + 4 + 4, key3, sizeof(key3));
697 base::MD5Digest challenge_response;
698 base::MD5Sum(challenge, sizeof(challenge), &challenge_response);
699
700 // Concocting response handshake.
701 class Buffer {
702 public:
703 Buffer()
704 : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)),
705 bytes_written_(0),
706 is_ok_(true) {
707 }
708
709 bool Write(const void* p, int len) {
710 DCHECK(p);
711 DCHECK_GE(len, 0);
712 if (!is_ok_)
713 return false;
714 if (bytes_written_ + len > kHandshakeLimitBytes) {
715 NOTREACHED();
716 is_ok_ = false;
717 return false;
718 }
719 memcpy(io_buf_->data() + bytes_written_, p, len);
720 bytes_written_ += len;
721 return true;
722 }
723
724 bool WriteLine(const char* p) {
725 return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF));
726 }
727
728 operator net::DrainableIOBuffer*() {
729 return new net::DrainableIOBuffer(io_buf_.get(), bytes_written_);
730 }
731
732 bool is_ok() { return is_ok_; }
733
734 private:
735 scoped_refptr<net::IOBuffer> io_buf_;
736 size_t bytes_written_;
737 bool is_ok_;
738 } buffer;
739
740 buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake");
741 buffer.WriteLine("Upgrade: WebSocket");
742 buffer.WriteLine("Connection: Upgrade");
743
744 {
745 // Take care of Location field.
746 char tmp[2048];
747 int rv = base::snprintf(tmp, sizeof(tmp),
748 "%s: %s",
749 kLocationFieldName,
750 location.c_str());
751 if (rv <= 0 || rv + 0u >= sizeof(tmp))
752 return net::ERR_LIMIT_VIOLATION;
753 buffer.WriteLine(tmp);
754 }
755 {
756 // Take care of Origin field.
757 char tmp[2048];
758 int rv = base::snprintf(tmp, sizeof(tmp),
759 "%s: %s",
760 kOriginFieldName,
761 fields.Get(kPlainOriginFieldName).c_str());
762 if (rv <= 0 || rv + 0u >= sizeof(tmp))
763 return net::ERR_LIMIT_VIOLATION;
764 buffer.WriteLine(tmp);
765 }
766 if (!subprotocol.empty()) {
767 char tmp[2048];
768 int rv = base::snprintf(tmp, sizeof(tmp),
769 "%s: %s",
770 kProtocolFieldName,
771 subprotocol.c_str());
772 if (rv <= 0 || rv + 0u >= sizeof(tmp))
773 return net::ERR_LIMIT_VIOLATION;
774 buffer.WriteLine(tmp);
775 }
776 buffer.WriteLine("");
777 buffer.Write(&challenge_response, sizeof(challenge_response));
778
779 if (!buffer.is_ok())
780 return net::ERR_LIMIT_VIOLATION;
781
782 pending_reqs_.push_back(PendingReq(
783 PendingReq::TYPE_WRITE_METADATA, buffer, net::CompletionCallback()));
784 DCHECK_GT(term_pos - buf, 0);
785 return term_pos - buf;
786 }
787
788 // Removes frame delimiters and returns net number of data bytes (or error).
789 // |out| may be equal to |buf|, in that case it is in-place operation.
790 int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) {
791 if (out_len < buf_len) {
792 NOTREACHED();
793 return net::ERR_UNEXPECTED;
794 }
795 int out_pos = 0;
796 for (char* p = buf; p < buf + buf_len; ++p) {
797 switch (phase_) {
798 case PHASE_FRAME_INSIDE: {
799 if (*p == '\x00')
800 return net::ERR_WS_PROTOCOL_ERROR;
801 if (*p == '\xff')
802 phase_ = PHASE_FRAME_OUTSIDE;
803 else
804 out[out_pos++] = *p;
805 break;
806 }
807 case PHASE_FRAME_OUTSIDE: {
808 if (*p == '\x00') {
809 phase_ = PHASE_FRAME_INSIDE;
810 } else if (*p == '\xff') {
811 phase_ = PHASE_FRAME_LENGTH;
812 frame_bytes_remaining_ = 0;
813 }
814 else {
815 return net::ERR_WS_PROTOCOL_ERROR;
816 }
817 break;
818 }
819 case PHASE_FRAME_LENGTH: {
820 static const int kValueBits = 7;
821 static const char kValueMask = (1 << kValueBits) - 1;
822 frame_bytes_remaining_ <<= kValueBits;
823 frame_bytes_remaining_ += (*p & kValueMask);
824 if (*p & ~kValueMask) {
825 // Check that next byte would not overflow.
826 if (frame_bytes_remaining_ >
827 (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) {
828 return net::ERR_LIMIT_VIOLATION;
829 }
830 } else {
831 if (frame_bytes_remaining_ == 0) {
832 phase_ = PHASE_SHUT;
833 return out_pos;
834 } else {
835 phase_ = PHASE_FRAME_SKIP;
836 }
837 }
838 break;
839 }
840 case PHASE_FRAME_SKIP: {
841 DCHECK_GE(frame_bytes_remaining_, 1);
842 frame_bytes_remaining_ -= 1;
843 if (frame_bytes_remaining_ < 1)
844 phase_ = PHASE_FRAME_OUTSIDE;
845 break;
846 }
847 default: {
848 NOTREACHED();
849 }
850 }
851 }
852 return out_pos;
853 }
854
855 // State machinery.
856 Phase phase_;
857
858 // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP.
859 int frame_bytes_remaining_;
860
861 // Underlying socket.
862 scoped_ptr<net::Socket> transport_socket_;
863
864 // Validation is performed via delegate.
865 Delegate* delegate_;
866
867 // IOBuffer used to communicate with transport at initial stage.
868 scoped_refptr<net::IOBuffer> handshake_buf_;
869 scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_;
870 scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_;
871
872 // Pending IO requests we need to complete.
873 std::deque<PendingReq> pending_reqs_;
874
875 // Whether transport requests are pending.
876 bool is_transport_read_pending_;
877 bool is_transport_write_pending_;
878
879 base::WeakPtrFactory<WebSocketServerSocketImpl> weak_factory_;
880
881 DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl);
882 };
883
884 } // namespace
885
886 namespace net {
887
888 WebSocketServerSocket* CreateWebSocketServerSocket(
889 Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) {
890 return new WebSocketServerSocketImpl(transport_socket, delegate);
891 }
892
893 WebSocketServerSocket::~WebSocketServerSocket() {
894 }
895
896 } // namespace net;
OLDNEW
« no previous file with comments | « net/socket/web_socket_server_socket.h ('k') | net/socket/web_socket_server_socket_unittest.cc » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698