OLD | NEW |
| (Empty) |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "net/socket/web_socket_server_socket.h" | |
6 | |
7 #include <algorithm> | |
8 #include <deque> | |
9 #include <limits> | |
10 #include <map> | |
11 #include <vector> | |
12 | |
13 #include "base/basictypes.h" | |
14 #include "base/bind.h" | |
15 #include "base/bind_helpers.h" | |
16 #include "base/logging.h" | |
17 #include "base/md5.h" | |
18 #include "base/memory/ref_counted.h" | |
19 #include "base/memory/scoped_ptr.h" | |
20 #include "base/memory/weak_ptr.h" | |
21 #include "base/message_loop.h" | |
22 #include "base/string_util.h" | |
23 #include "base/sys_byteorder.h" | |
24 #include "googleurl/src/gurl.h" | |
25 #include "net/base/completion_callback.h" | |
26 #include "net/base/io_buffer.h" | |
27 #include "net/base/net_errors.h" | |
28 | |
29 namespace { | |
30 | |
31 const size_t kHandshakeLimitBytes = 1 << 14; | |
32 | |
33 const char kCrOctet = '\r'; | |
34 COMPILE_ASSERT(kCrOctet == '\x0d', ASCII); | |
35 const char kLfOctet = '\n'; | |
36 COMPILE_ASSERT(kLfOctet == '\x0a', ASCII); | |
37 const char kSpaceOctet = ' '; | |
38 COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII); | |
39 const char kCommaOctet = ','; | |
40 COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII); | |
41 | |
42 const char kCRLF[] = { kCrOctet, kLfOctet, 0 }; | |
43 const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 }; | |
44 | |
45 const char kPlainHostFieldName[] = "Host"; | |
46 const char kPlainOriginFieldName[] = "Origin"; | |
47 const char kOriginFieldName[] = "Sec-WebSocket-Origin"; | |
48 const char kProtocolFieldName[] = "Sec-WebSocket-Protocol"; | |
49 const char kVersionFieldName[] = "Sec-WebSocket-Version"; | |
50 const char kLocationFieldName[] = "Sec-WebSocket-Location"; | |
51 const char kKey1FieldName[] = "Sec-WebSocket-Key1"; | |
52 const char kKey2FieldName[] = "Sec-WebSocket-Key2"; | |
53 | |
54 int CountSpaces(const std::string& s) { | |
55 return std::count(s.begin(), s.end(), kSpaceOctet); | |
56 } | |
57 | |
58 // Returns true on success. | |
59 bool FetchDecimalDigits(const std::string& s, uint32* result) { | |
60 *result = 0; | |
61 bool got_something = false; | |
62 for (size_t i = 0; i < s.size(); ++i) { | |
63 if (IsAsciiDigit(s[i])) { | |
64 got_something = true; | |
65 if (*result > std::numeric_limits<uint32>::max() / 10) | |
66 return false; | |
67 *result *= 10; | |
68 int digit = s[i] - '0'; | |
69 if (*result > std::numeric_limits<uint32>::max() - digit) | |
70 return false; | |
71 *result += digit; | |
72 } | |
73 } | |
74 return got_something; | |
75 } | |
76 | |
77 // Returns number of fetched subprotocols or negative error code. | |
78 int FetchSubprotocolList( | |
79 const std::string& s, std::vector<std::string>* subprotocol_list) { | |
80 subprotocol_list->clear(); | |
81 subprotocol_list->push_back(std::string()); | |
82 for (size_t i = 0; i < s.size(); ++i) { | |
83 if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet) | |
84 subprotocol_list->back() += s[i]; | |
85 else if (!subprotocol_list->back().empty()) { | |
86 if (subprotocol_list->size() < 16) | |
87 subprotocol_list->push_back(std::string()); | |
88 else | |
89 return net::ERR_LIMIT_VIOLATION; | |
90 } | |
91 } | |
92 if (subprotocol_list->back().empty()) | |
93 subprotocol_list->pop_back(); | |
94 if (subprotocol_list->empty()) | |
95 return net::ERR_WS_PROTOCOL_ERROR; | |
96 | |
97 { | |
98 std::vector<std::string> tmp(*subprotocol_list); | |
99 std::sort(tmp.begin(), tmp.end()); | |
100 if (tmp.end() != std::unique(tmp.begin(), tmp.end())) | |
101 return net::ERR_WS_PROTOCOL_ERROR; | |
102 } | |
103 return subprotocol_list->size(); | |
104 } | |
105 | |
106 class WebSocketServerSocketImpl : public net::WebSocketServerSocket { | |
107 public: | |
108 WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate) | |
109 : phase_(PHASE_NYMPH), | |
110 frame_bytes_remaining_(0), | |
111 transport_socket_(transport_socket), | |
112 delegate_(delegate), | |
113 handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)), | |
114 fill_handshake_buf_(new net::DrainableIOBuffer( | |
115 handshake_buf_, kHandshakeLimitBytes)), | |
116 process_handshake_buf_(new net::DrainableIOBuffer( | |
117 handshake_buf_, kHandshakeLimitBytes)), | |
118 is_transport_read_pending_(false), | |
119 is_transport_write_pending_(false), | |
120 weak_factory_(this) { | |
121 DCHECK(transport_socket); | |
122 DCHECK(delegate); | |
123 } | |
124 | |
125 virtual ~WebSocketServerSocketImpl() { | |
126 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
127 if (it != pending_reqs_.end() && | |
128 it->type == PendingReq::TYPE_READ && | |
129 it->io_buf != NULL && | |
130 it->io_buf->data() != NULL && | |
131 !it->callback.is_null()) { | |
132 it->callback.Run(0); // Report EOF. | |
133 } | |
134 } | |
135 | |
136 private: | |
137 enum Phase { | |
138 // Before Accept() is called. | |
139 PHASE_NYMPH, | |
140 | |
141 // After Accept() is called and until handshake success/fail. | |
142 PHASE_HANDSHAKE, | |
143 | |
144 // Processing data stream. | |
145 PHASE_FRAME_OUTSIDE, // Outside data frame. | |
146 PHASE_FRAME_INSIDE, // Inside text frame. | |
147 PHASE_FRAME_LENGTH, // Reading length of binary frame. | |
148 PHASE_FRAME_SKIP, // Skipping binary frame. | |
149 | |
150 // After termination. | |
151 PHASE_SHUT | |
152 }; | |
153 | |
154 struct PendingReq { | |
155 enum Type { | |
156 // Frame delimiters or handshake (as opposed to user data). | |
157 TYPE_METADATA = 1 << 0, | |
158 // Read request. | |
159 TYPE_READ = 1 << 1, | |
160 // Write request. | |
161 TYPE_WRITE = 1 << 2, | |
162 | |
163 TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA, | |
164 TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA | |
165 }; | |
166 | |
167 PendingReq(Type type, net::DrainableIOBuffer* io_buf, | |
168 const net::CompletionCallback& callback) | |
169 : type(type), | |
170 io_buf(io_buf), | |
171 callback(callback) { | |
172 switch (type) { | |
173 case PendingReq::TYPE_READ: | |
174 case PendingReq::TYPE_WRITE: | |
175 case PendingReq::TYPE_READ_METADATA: | |
176 case PendingReq::TYPE_WRITE_METADATA: { | |
177 DCHECK(io_buf); | |
178 break; | |
179 } | |
180 default: { | |
181 NOTREACHED(); | |
182 break; | |
183 } | |
184 } | |
185 } | |
186 | |
187 Type type; | |
188 scoped_refptr<net::DrainableIOBuffer> io_buf; | |
189 net::CompletionCallback callback; | |
190 }; | |
191 | |
192 // Socket implementation. | |
193 virtual int Read(net::IOBuffer* buf, int buf_len, | |
194 const net::CompletionCallback& callback) OVERRIDE { | |
195 if (buf_len == 0) | |
196 return 0; | |
197 if (buf == NULL || buf_len < 0) { | |
198 NOTREACHED(); | |
199 return net::ERR_INVALID_ARGUMENT; | |
200 } | |
201 while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - | |
202 process_handshake_buf_->BytesConsumed()) { | |
203 DCHECK(!is_transport_read_pending_); | |
204 DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); | |
205 switch (phase_) { | |
206 case PHASE_FRAME_OUTSIDE: | |
207 case PHASE_FRAME_INSIDE: | |
208 case PHASE_FRAME_LENGTH: | |
209 case PHASE_FRAME_SKIP: { | |
210 int n = std::min(bytes_remaining, buf_len); | |
211 int rv = ProcessDataFrames( | |
212 process_handshake_buf_->data(), n, buf->data(), buf_len); | |
213 process_handshake_buf_->DidConsume(n); | |
214 if (rv == 0) { | |
215 // ProcessDataFrames may return zero for non-empty buffer if it | |
216 // contains only frame delimiters without real data. In this case: | |
217 // try again and do not just return zero (zero stands for EOF). | |
218 continue; | |
219 } | |
220 return rv; | |
221 } | |
222 case PHASE_SHUT: { | |
223 return 0; | |
224 } | |
225 case PHASE_NYMPH: | |
226 case PHASE_HANDSHAKE: | |
227 default: { | |
228 NOTREACHED(); | |
229 return net::ERR_UNEXPECTED; | |
230 } | |
231 } | |
232 } | |
233 switch (phase_) { | |
234 case PHASE_FRAME_OUTSIDE: | |
235 case PHASE_FRAME_INSIDE: | |
236 case PHASE_FRAME_LENGTH: | |
237 case PHASE_FRAME_SKIP: { | |
238 pending_reqs_.push_back(PendingReq( | |
239 PendingReq::TYPE_READ, | |
240 new net::DrainableIOBuffer(buf, buf_len), | |
241 callback)); | |
242 ConsiderTransportRead(); | |
243 break; | |
244 } | |
245 case PHASE_SHUT: { | |
246 return 0; | |
247 } | |
248 case PHASE_NYMPH: | |
249 case PHASE_HANDSHAKE: | |
250 default: { | |
251 NOTREACHED(); | |
252 return net::ERR_UNEXPECTED; | |
253 } | |
254 } | |
255 return net::ERR_IO_PENDING; | |
256 } | |
257 | |
258 virtual int Write(net::IOBuffer* buf, int buf_len, | |
259 const net::CompletionCallback& callback) OVERRIDE { | |
260 if (buf_len == 0) | |
261 return 0; | |
262 if (buf == NULL || buf_len < 0) { | |
263 NOTREACHED(); | |
264 return net::ERR_INVALID_ARGUMENT; | |
265 } | |
266 DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), | |
267 buf->data() + buf_len); | |
268 switch (phase_) { | |
269 case PHASE_FRAME_OUTSIDE: | |
270 case PHASE_FRAME_INSIDE: | |
271 case PHASE_FRAME_LENGTH: | |
272 case PHASE_FRAME_SKIP: { | |
273 break; | |
274 } | |
275 case PHASE_SHUT: { | |
276 return net::ERR_SOCKET_NOT_CONNECTED; | |
277 } | |
278 case PHASE_NYMPH: | |
279 case PHASE_HANDSHAKE: | |
280 default: { | |
281 NOTREACHED(); | |
282 return net::ERR_UNEXPECTED; | |
283 } | |
284 } | |
285 | |
286 net::IOBuffer* frame_start = new net::IOBuffer(1); | |
287 frame_start->data()[0] = '\x00'; | |
288 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, | |
289 new net::DrainableIOBuffer(frame_start, 1), | |
290 net::CompletionCallback())); | |
291 | |
292 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, | |
293 new net::DrainableIOBuffer(buf, buf_len), | |
294 callback)); | |
295 | |
296 net::IOBuffer* frame_end = new net::IOBuffer(1); | |
297 frame_end->data()[0] = '\xff'; | |
298 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, | |
299 new net::DrainableIOBuffer(frame_end, 1), | |
300 net::CompletionCallback())); | |
301 | |
302 ConsiderTransportWrite(); | |
303 return net::ERR_IO_PENDING; | |
304 } | |
305 | |
306 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
307 return transport_socket_->SetReceiveBufferSize(size); | |
308 } | |
309 | |
310 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
311 return transport_socket_->SetSendBufferSize(size); | |
312 } | |
313 | |
314 // WebSocketServerSocket implementation. | |
315 virtual int Accept(const net::CompletionCallback& callback) OVERRIDE { | |
316 if (phase_ != PHASE_NYMPH) | |
317 return net::ERR_UNEXPECTED; | |
318 phase_ = PHASE_HANDSHAKE; | |
319 pending_reqs_.push_front(PendingReq( | |
320 PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback)); | |
321 ConsiderTransportRead(); | |
322 return net::ERR_IO_PENDING; | |
323 } | |
324 | |
325 std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) { | |
326 for (std::deque<PendingReq>::iterator it = pending_reqs_.begin(); | |
327 it != pending_reqs_.end(); ++it) { | |
328 if (it->type & type) | |
329 return it; | |
330 } | |
331 return pending_reqs_.end(); | |
332 } | |
333 | |
334 void ConsiderTransportRead() { | |
335 if (pending_reqs_.empty()) | |
336 return; | |
337 if (is_transport_read_pending_) | |
338 return; | |
339 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
340 if (it == pending_reqs_.end()) | |
341 return; | |
342 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { | |
343 NOTREACHED(); | |
344 return; | |
345 } | |
346 is_transport_read_pending_ = true; | |
347 int rv = transport_socket_->Read( | |
348 it->io_buf.get(), it->io_buf->BytesRemaining(), | |
349 base::Bind(&WebSocketServerSocketImpl::OnRead, | |
350 base::Unretained(this))); | |
351 if (rv != net::ERR_IO_PENDING) { | |
352 // PostTask rather than direct call in order to: | |
353 // (1) guarantee calling callback after returning from Read(); | |
354 // (2) avoid potential stack overflow; | |
355 MessageLoop::current()->PostTask( | |
356 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnRead, | |
357 weak_factory_.GetWeakPtr(), rv)); | |
358 } | |
359 } | |
360 | |
361 void ConsiderTransportWrite() { | |
362 if (is_transport_write_pending_) | |
363 return; | |
364 if (pending_reqs_.empty()) | |
365 return; | |
366 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); | |
367 if (it == pending_reqs_.end()) | |
368 return; | |
369 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { | |
370 NOTREACHED(); | |
371 Shut(net::ERR_UNEXPECTED); | |
372 return; | |
373 } | |
374 is_transport_write_pending_ = true; | |
375 int rv = transport_socket_->Write( | |
376 it->io_buf.get(), it->io_buf->BytesRemaining(), | |
377 base::Bind(&WebSocketServerSocketImpl::OnWrite, | |
378 base::Unretained(this))); | |
379 if (rv != net::ERR_IO_PENDING) { | |
380 // PostTask rather than direct call in order to: | |
381 // (1) guarantee calling callback after returning from Read(); | |
382 // (2) avoid potential stack overflow; | |
383 MessageLoop::current()->PostTask( | |
384 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnWrite, | |
385 weak_factory_.GetWeakPtr(), rv)); | |
386 } | |
387 } | |
388 | |
389 void Shut(int result) { | |
390 if (result > 0 || result == net::ERR_IO_PENDING) | |
391 result = net::ERR_UNEXPECTED; | |
392 if (result != 0) { | |
393 while (!pending_reqs_.empty()) { | |
394 PendingReq& req = pending_reqs_.front(); | |
395 if (!req.callback.is_null()) | |
396 req.callback.Run(result); | |
397 pending_reqs_.pop_front(); | |
398 } | |
399 transport_socket_.reset(); // terminate underlying connection. | |
400 } | |
401 phase_ = PHASE_SHUT; | |
402 } | |
403 | |
404 // Callbacks for transport socket. | |
405 void OnRead(int result) { | |
406 if (!is_transport_read_pending_) { | |
407 NOTREACHED(); | |
408 Shut(net::ERR_UNEXPECTED); | |
409 return; | |
410 } | |
411 is_transport_read_pending_ = false; | |
412 | |
413 if (result <= 0) { | |
414 Shut(result); | |
415 return; | |
416 } | |
417 | |
418 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
419 if (it == pending_reqs_.end() || | |
420 it->io_buf == NULL || | |
421 it->io_buf->data() == NULL) { | |
422 NOTREACHED(); | |
423 Shut(net::ERR_UNEXPECTED); | |
424 return; | |
425 } | |
426 if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) { | |
427 NOTREACHED(); | |
428 Shut(net::ERR_UNEXPECTED); | |
429 return; | |
430 } | |
431 | |
432 switch (phase_) { | |
433 case PHASE_HANDSHAKE: { | |
434 if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) { | |
435 NOTREACHED(); | |
436 Shut(net::ERR_UNEXPECTED); | |
437 return; | |
438 } | |
439 fill_handshake_buf_->DidConsume(result); | |
440 // ProcessHandshake invalidates iterators for |pending_reqs_| | |
441 int rv = ProcessHandshake(); | |
442 if (rv > 0) { | |
443 process_handshake_buf_->DidConsume(rv); | |
444 phase_ = PHASE_FRAME_OUTSIDE; | |
445 net::CompletionCallback cb = pending_reqs_.front().callback; | |
446 pending_reqs_.pop_front(); | |
447 ConsiderTransportWrite(); // Schedule answer handshake. | |
448 if (!cb.is_null()) | |
449 cb.Run(0); | |
450 } else if (rv == net::ERR_IO_PENDING) { | |
451 if (fill_handshake_buf_->BytesRemaining() < 1) | |
452 Shut(net::ERR_LIMIT_VIOLATION); | |
453 } else if (rv < 0) { | |
454 Shut(rv); | |
455 } else { | |
456 Shut(net::ERR_UNEXPECTED); | |
457 } | |
458 break; | |
459 } | |
460 case PHASE_FRAME_OUTSIDE: | |
461 case PHASE_FRAME_INSIDE: | |
462 case PHASE_FRAME_LENGTH: | |
463 case PHASE_FRAME_SKIP: { | |
464 int rv = ProcessDataFrames( | |
465 it->io_buf->data(), result, | |
466 it->io_buf->data(), it->io_buf->BytesRemaining()); | |
467 if (rv < 0) { | |
468 Shut(rv); | |
469 return; | |
470 } | |
471 if (rv > 0 || phase_ == PHASE_SHUT) { | |
472 net::CompletionCallback cb = it->callback; | |
473 pending_reqs_.erase(it); | |
474 if (!cb.is_null()) | |
475 cb.Run(rv); | |
476 } | |
477 break; | |
478 } | |
479 case PHASE_NYMPH: | |
480 default: { | |
481 NOTREACHED(); | |
482 Shut(net::ERR_UNEXPECTED); | |
483 break; | |
484 } | |
485 } | |
486 ConsiderTransportRead(); | |
487 } | |
488 | |
489 void OnWrite(int result) { | |
490 if (!is_transport_write_pending_) { | |
491 NOTREACHED(); | |
492 Shut(net::ERR_UNEXPECTED); | |
493 return; | |
494 } | |
495 is_transport_write_pending_ = false; | |
496 | |
497 if (result < 0) { | |
498 Shut(result); | |
499 return; | |
500 } | |
501 | |
502 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); | |
503 if (it == pending_reqs_.end() || | |
504 it->io_buf == NULL || | |
505 it->io_buf->data() == NULL) { | |
506 NOTREACHED(); | |
507 Shut(net::ERR_UNEXPECTED); | |
508 return; | |
509 } | |
510 DCHECK_LE(result, it->io_buf->BytesRemaining()); | |
511 it->io_buf->DidConsume(result); | |
512 if (it->io_buf->BytesRemaining() == 0) { | |
513 net::CompletionCallback cb = it->callback; | |
514 int bytes_written = it->io_buf->BytesConsumed(); | |
515 DCHECK_GT(bytes_written, 0); | |
516 pending_reqs_.erase(it); | |
517 if (!cb.is_null()) | |
518 cb.Run(bytes_written); | |
519 } | |
520 ConsiderTransportWrite(); | |
521 } | |
522 | |
523 // Returns (positive) number of consumed bytes on success. | |
524 // Returns ERR_IO_PENDING in case of incomplete input. | |
525 // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to | |
526 // reasonably parse input. | |
527 int ProcessHandshake() { | |
528 static const char kGetPrefix[] = "GET "; | |
529 static const char kKeyValueDelimiter[] = ": "; | |
530 | |
531 class Fields { | |
532 public: | |
533 bool Has(const std::string& name) { | |
534 return map_.find(StringToLowerASCII(name)) != map_.end(); | |
535 } | |
536 | |
537 std::string Get(const std::string& name) { | |
538 return Has(name) ? map_[StringToLowerASCII(name)] : std::string(); | |
539 } | |
540 | |
541 void Set(const std::string& name, const std::string& value) { | |
542 map_[StringToLowerASCII(name)] = StringToLowerASCII(value); | |
543 } | |
544 | |
545 private: | |
546 std::map<std::string, std::string> map_; | |
547 } fields; | |
548 | |
549 char* buf = process_handshake_buf_->data(); | |
550 size_t buf_size = fill_handshake_buf_->BytesConsumed(); | |
551 | |
552 if (buf_size < 1) | |
553 return net::ERR_IO_PENDING; | |
554 if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)), | |
555 kGetPrefix)) { | |
556 // Data head does not match what is expected. | |
557 return net::ERR_WS_PROTOCOL_ERROR; | |
558 } | |
559 if (buf_size >= kHandshakeLimitBytes) | |
560 return net::ERR_LIMIT_VIOLATION; | |
561 char* buf_end = buf + buf_size; | |
562 | |
563 if (buf_size < strlen(kGetPrefix)) | |
564 return net::ERR_IO_PENDING; | |
565 char* resource_begin = buf + strlen(kGetPrefix); | |
566 char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet); | |
567 if (resource_end == buf_end) | |
568 return net::ERR_IO_PENDING; | |
569 std::string resource(resource_begin, resource_end); | |
570 if (!IsStringUTF8(resource) || | |
571 resource.find_first_of(kCRLF) != std::string::npos) { | |
572 return net::ERR_WS_PROTOCOL_ERROR; | |
573 } | |
574 char* term_pos = std::search( | |
575 buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF)); | |
576 char key3[8]; // Notation (key3) matches websocket RFC. | |
577 size_t message_len = buf_end - term_pos; | |
578 if (message_len < sizeof(key3) + strlen(kCRLFCRLF)) | |
579 return net::ERR_IO_PENDING; | |
580 term_pos += strlen(kCRLFCRLF); | |
581 memcpy(key3, term_pos, sizeof(key3)); | |
582 term_pos += sizeof(key3); | |
583 // First line is "GET resource" line, so skip it. | |
584 char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); | |
585 if (pos == term_pos) | |
586 return net::ERR_WS_PROTOCOL_ERROR; | |
587 for (;;) { | |
588 pos += strlen(kCRLF); | |
589 if (term_pos - pos < | |
590 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { | |
591 return net::ERR_WS_PROTOCOL_ERROR; | |
592 } | |
593 if (term_pos - pos == | |
594 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { | |
595 break; | |
596 } | |
597 char* next_pos = std::search( | |
598 pos, term_pos, kKeyValueDelimiter, | |
599 kKeyValueDelimiter + strlen(kKeyValueDelimiter)); | |
600 if (next_pos == term_pos) | |
601 return net::ERR_WS_PROTOCOL_ERROR; | |
602 std::string key(pos, next_pos); | |
603 if (!IsStringASCII(key) || | |
604 key.find_first_of(kCRLF) != std::string::npos) { | |
605 return net::ERR_WS_PROTOCOL_ERROR; | |
606 } | |
607 pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos, | |
608 kCRLF, kCRLF + strlen(kCRLF)); | |
609 if (pos == term_pos) | |
610 return net::ERR_WS_PROTOCOL_ERROR; | |
611 if (!key.empty()) { | |
612 std::string value(next_pos, pos); | |
613 if (!IsStringASCII(value) || | |
614 value.find_first_of(kCRLF) != std::string::npos) { | |
615 return net::ERR_WS_PROTOCOL_ERROR; | |
616 } | |
617 fields.Set(key, value); | |
618 } | |
619 } | |
620 | |
621 // Values of Upgrade and Connection fields are hardcoded in the protocol. | |
622 if (fields.Get("Upgrade") != "websocket" || | |
623 fields.Get("Connection") != "upgrade") { | |
624 return net::ERR_WS_PROTOCOL_ERROR; | |
625 } | |
626 if (fields.Has(kVersionFieldName)) { | |
627 NOTIMPLEMENTED(); // new protocol. | |
628 return net::ERR_NOT_IMPLEMENTED; | |
629 } | |
630 | |
631 if (!fields.Has(kPlainOriginFieldName)) | |
632 return net::ERR_CONNECTION_REFUSED; | |
633 // Normalize (e.g. w.r.t. leading slashes) origin. | |
634 GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin(); | |
635 if (!origin.is_valid()) | |
636 return net::ERR_WS_PROTOCOL_ERROR; | |
637 std::string normalized_origin = origin.spec(); | |
638 | |
639 if (!fields.Has(kPlainHostFieldName)) | |
640 return net::ERR_CONNECTION_REFUSED; | |
641 | |
642 std::vector<std::string> subprotocol_list; | |
643 if (fields.Has(kProtocolFieldName)) { | |
644 int rv = FetchSubprotocolList( | |
645 fields.Get(kProtocolFieldName), &subprotocol_list); | |
646 if (rv < 0) | |
647 return rv; | |
648 DCHECK(subprotocol_list.end() == std::find( | |
649 subprotocol_list.begin(), subprotocol_list.end(), "")); | |
650 } | |
651 | |
652 std::string location; | |
653 std::string subprotocol; | |
654 if (!delegate_->ValidateWebSocket(resource, | |
655 normalized_origin, | |
656 fields.Get(kPlainHostFieldName), | |
657 subprotocol_list, | |
658 &location, | |
659 &subprotocol)) { | |
660 return net::ERR_CONNECTION_REFUSED; | |
661 } | |
662 if (subprotocol_list.empty()) { | |
663 DCHECK(subprotocol.empty()); | |
664 } else { | |
665 if (!subprotocol.empty()) { | |
666 if (subprotocol_list.end() == std::find( | |
667 subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) { | |
668 NOTREACHED() << "delegate must pick subprotocol from given list"; | |
669 return net::ERR_UNEXPECTED; | |
670 } | |
671 } | |
672 } | |
673 | |
674 uint32 key_number1 = 0; | |
675 uint32 key_number2 = 0; | |
676 if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) || | |
677 !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) { | |
678 return net::ERR_WS_PROTOCOL_ERROR; | |
679 } | |
680 | |
681 // We limit incoming header size so following numbers shall not be too high. | |
682 int spaces1 = CountSpaces(fields.Get(kKey1FieldName)); | |
683 int spaces2 = CountSpaces(fields.Get(kKey2FieldName)); | |
684 if (spaces1 == 0 || | |
685 spaces2 == 0 || | |
686 key_number1 % spaces1 != 0 || | |
687 key_number2 % spaces2 != 0) { | |
688 return net::ERR_WS_PROTOCOL_ERROR; | |
689 } | |
690 | |
691 char challenge[4 + 4 + sizeof(key3)]; | |
692 int32 part1 = base::HostToNet32(key_number1 / spaces1); | |
693 int32 part2 = base::HostToNet32(key_number2 / spaces2); | |
694 memcpy(challenge, &part1, 4); | |
695 memcpy(challenge + 4, &part2, 4); | |
696 memcpy(challenge + 4 + 4, key3, sizeof(key3)); | |
697 base::MD5Digest challenge_response; | |
698 base::MD5Sum(challenge, sizeof(challenge), &challenge_response); | |
699 | |
700 // Concocting response handshake. | |
701 class Buffer { | |
702 public: | |
703 Buffer() | |
704 : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)), | |
705 bytes_written_(0), | |
706 is_ok_(true) { | |
707 } | |
708 | |
709 bool Write(const void* p, int len) { | |
710 DCHECK(p); | |
711 DCHECK_GE(len, 0); | |
712 if (!is_ok_) | |
713 return false; | |
714 if (bytes_written_ + len > kHandshakeLimitBytes) { | |
715 NOTREACHED(); | |
716 is_ok_ = false; | |
717 return false; | |
718 } | |
719 memcpy(io_buf_->data() + bytes_written_, p, len); | |
720 bytes_written_ += len; | |
721 return true; | |
722 } | |
723 | |
724 bool WriteLine(const char* p) { | |
725 return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF)); | |
726 } | |
727 | |
728 operator net::DrainableIOBuffer*() { | |
729 return new net::DrainableIOBuffer(io_buf_.get(), bytes_written_); | |
730 } | |
731 | |
732 bool is_ok() { return is_ok_; } | |
733 | |
734 private: | |
735 scoped_refptr<net::IOBuffer> io_buf_; | |
736 size_t bytes_written_; | |
737 bool is_ok_; | |
738 } buffer; | |
739 | |
740 buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake"); | |
741 buffer.WriteLine("Upgrade: WebSocket"); | |
742 buffer.WriteLine("Connection: Upgrade"); | |
743 | |
744 { | |
745 // Take care of Location field. | |
746 char tmp[2048]; | |
747 int rv = base::snprintf(tmp, sizeof(tmp), | |
748 "%s: %s", | |
749 kLocationFieldName, | |
750 location.c_str()); | |
751 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
752 return net::ERR_LIMIT_VIOLATION; | |
753 buffer.WriteLine(tmp); | |
754 } | |
755 { | |
756 // Take care of Origin field. | |
757 char tmp[2048]; | |
758 int rv = base::snprintf(tmp, sizeof(tmp), | |
759 "%s: %s", | |
760 kOriginFieldName, | |
761 fields.Get(kPlainOriginFieldName).c_str()); | |
762 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
763 return net::ERR_LIMIT_VIOLATION; | |
764 buffer.WriteLine(tmp); | |
765 } | |
766 if (!subprotocol.empty()) { | |
767 char tmp[2048]; | |
768 int rv = base::snprintf(tmp, sizeof(tmp), | |
769 "%s: %s", | |
770 kProtocolFieldName, | |
771 subprotocol.c_str()); | |
772 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
773 return net::ERR_LIMIT_VIOLATION; | |
774 buffer.WriteLine(tmp); | |
775 } | |
776 buffer.WriteLine(""); | |
777 buffer.Write(&challenge_response, sizeof(challenge_response)); | |
778 | |
779 if (!buffer.is_ok()) | |
780 return net::ERR_LIMIT_VIOLATION; | |
781 | |
782 pending_reqs_.push_back(PendingReq( | |
783 PendingReq::TYPE_WRITE_METADATA, buffer, net::CompletionCallback())); | |
784 DCHECK_GT(term_pos - buf, 0); | |
785 return term_pos - buf; | |
786 } | |
787 | |
788 // Removes frame delimiters and returns net number of data bytes (or error). | |
789 // |out| may be equal to |buf|, in that case it is in-place operation. | |
790 int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) { | |
791 if (out_len < buf_len) { | |
792 NOTREACHED(); | |
793 return net::ERR_UNEXPECTED; | |
794 } | |
795 int out_pos = 0; | |
796 for (char* p = buf; p < buf + buf_len; ++p) { | |
797 switch (phase_) { | |
798 case PHASE_FRAME_INSIDE: { | |
799 if (*p == '\x00') | |
800 return net::ERR_WS_PROTOCOL_ERROR; | |
801 if (*p == '\xff') | |
802 phase_ = PHASE_FRAME_OUTSIDE; | |
803 else | |
804 out[out_pos++] = *p; | |
805 break; | |
806 } | |
807 case PHASE_FRAME_OUTSIDE: { | |
808 if (*p == '\x00') { | |
809 phase_ = PHASE_FRAME_INSIDE; | |
810 } else if (*p == '\xff') { | |
811 phase_ = PHASE_FRAME_LENGTH; | |
812 frame_bytes_remaining_ = 0; | |
813 } | |
814 else { | |
815 return net::ERR_WS_PROTOCOL_ERROR; | |
816 } | |
817 break; | |
818 } | |
819 case PHASE_FRAME_LENGTH: { | |
820 static const int kValueBits = 7; | |
821 static const char kValueMask = (1 << kValueBits) - 1; | |
822 frame_bytes_remaining_ <<= kValueBits; | |
823 frame_bytes_remaining_ += (*p & kValueMask); | |
824 if (*p & ~kValueMask) { | |
825 // Check that next byte would not overflow. | |
826 if (frame_bytes_remaining_ > | |
827 (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) { | |
828 return net::ERR_LIMIT_VIOLATION; | |
829 } | |
830 } else { | |
831 if (frame_bytes_remaining_ == 0) { | |
832 phase_ = PHASE_SHUT; | |
833 return out_pos; | |
834 } else { | |
835 phase_ = PHASE_FRAME_SKIP; | |
836 } | |
837 } | |
838 break; | |
839 } | |
840 case PHASE_FRAME_SKIP: { | |
841 DCHECK_GE(frame_bytes_remaining_, 1); | |
842 frame_bytes_remaining_ -= 1; | |
843 if (frame_bytes_remaining_ < 1) | |
844 phase_ = PHASE_FRAME_OUTSIDE; | |
845 break; | |
846 } | |
847 default: { | |
848 NOTREACHED(); | |
849 } | |
850 } | |
851 } | |
852 return out_pos; | |
853 } | |
854 | |
855 // State machinery. | |
856 Phase phase_; | |
857 | |
858 // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP. | |
859 int frame_bytes_remaining_; | |
860 | |
861 // Underlying socket. | |
862 scoped_ptr<net::Socket> transport_socket_; | |
863 | |
864 // Validation is performed via delegate. | |
865 Delegate* delegate_; | |
866 | |
867 // IOBuffer used to communicate with transport at initial stage. | |
868 scoped_refptr<net::IOBuffer> handshake_buf_; | |
869 scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_; | |
870 scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_; | |
871 | |
872 // Pending IO requests we need to complete. | |
873 std::deque<PendingReq> pending_reqs_; | |
874 | |
875 // Whether transport requests are pending. | |
876 bool is_transport_read_pending_; | |
877 bool is_transport_write_pending_; | |
878 | |
879 base::WeakPtrFactory<WebSocketServerSocketImpl> weak_factory_; | |
880 | |
881 DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl); | |
882 }; | |
883 | |
884 } // namespace | |
885 | |
886 namespace net { | |
887 | |
888 WebSocketServerSocket* CreateWebSocketServerSocket( | |
889 Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) { | |
890 return new WebSocketServerSocketImpl(transport_socket, delegate); | |
891 } | |
892 | |
893 WebSocketServerSocket::~WebSocketServerSocket() { | |
894 } | |
895 | |
896 } // namespace net; | |
OLD | NEW |