Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(107)

Side by Side Diff: net/socket/web_socket_server_socket_unittest.cc

Issue 10894004: Remove WebSocketServerSocket. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Remove error code. Created 8 years, 3 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/web_socket_server_socket.h"
6
7 #include <stdlib.h>
8 #include <algorithm>
9
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/message_loop.h"
15 #include "base/string_util.h"
16 #include "base/time.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "testing/gtest/include/gtest/gtest.h"
20
21 namespace {
22
23 const char* kSampleHandshakeRequest[] = {
24 "GET /demo HTTP/1.1",
25 "Upgrade: WebSocket",
26 "Connection: Upgrade",
27 "Host: example.com",
28 "Origin: http://example.com",
29 "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5",
30 "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00",
31 "",
32 "^n:ds[4U"
33 };
34
35 const char kSampleHandshakeAnswer[] =
36 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
37 "Upgrade: WebSocket\r\n"
38 "Connection: Upgrade\r\n"
39 "Sec-WebSocket-Location: ws://example.com/demo\r\n"
40 "Sec-WebSocket-Origin: http://example.com\r\n"
41 "\r\n"
42 "8jKS'y:G*Co,Wxa-";
43
44 const int kHandshakeBufBytes = 1 << 12;
45
46 const char kCRLF[] = "\r\n";
47 const char kCRLFCRLF[] = "\r\n\r\n";
48 const char kSpaceOctet = '\x20';
49
50 const int kReadSalt = 7;
51 const int kWriteSalt = 5;
52
53 int GetRand(int min, int max) {
54 CHECK(max >= min);
55 CHECK(max - min < RAND_MAX);
56 return rand() % (max - min + 1) + min;
57 }
58
59 class RandIntClass {
60 public:
61 int operator() (int range) {
62 return GetRand(0, range - 1);
63 }
64 } g_rand;
65
66 net::DrainableIOBuffer* ResizeIOBuffer(net::DrainableIOBuffer* buf, int len) {
67 net::DrainableIOBuffer* rv = new net::DrainableIOBuffer(
68 new net::IOBuffer(len), len);
69 std::copy(buf->data(), buf->data() + std::min(len, buf->BytesRemaining()),
70 rv->data());
71 return rv;
72 }
73
74 // TODO(dilmah): consider switching to socket_test_util.h
75 // Simulates reading from |sample| stream; data supplied in Write() calls are
76 // stored in |answer| buffer.
77 class TestingTransportSocket : public net::Socket {
78 public:
79 TestingTransportSocket(
80 net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer)
81 : sample_(sample),
82 answer_(answer),
83 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
84 }
85
86 ~TestingTransportSocket() {
87 if (!final_read_callback_.is_null()) {
88 MessageLoop::current()->PostTask(FROM_HERE,
89 base::Bind(&TestingTransportSocket::DoReadCallback,
90 weak_factory_.GetWeakPtr(),
91 final_read_callback_, 0));
92 }
93 }
94
95 // Socket implementation.
96 virtual int Read(net::IOBuffer* buf, int buf_len,
97 const net::CompletionCallback& callback) {
98 CHECK_GT(buf_len, 0);
99 int remaining = sample_->BytesRemaining();
100 if (remaining < 1) {
101 if (!final_read_callback_.is_null())
102 return 0;
103 final_read_callback_ = callback;
104 return net::ERR_IO_PENDING;
105 }
106 int lot = GetRand(1, std::min(remaining, buf_len));
107 std::copy(sample_->data(), sample_->data() + lot, buf->data());
108 sample_->DidConsume(lot);
109 if (GetRand(0, 1)) {
110 return lot;
111 }
112 MessageLoop::current()->PostTask(
113 FROM_HERE,
114 base::Bind(&TestingTransportSocket::DoReadCallback,
115 weak_factory_.GetWeakPtr(), callback, lot));
116 return net::ERR_IO_PENDING;
117 }
118
119 virtual int Write(net::IOBuffer* buf, int buf_len,
120 const net::CompletionCallback& callback) {
121 CHECK_GT(buf_len, 0);
122 int remaining = answer_->BytesRemaining();
123 CHECK_GE(remaining, buf_len);
124 int lot = std::min(remaining, buf_len);
125 if (GetRand(0, 1))
126 lot = GetRand(1, lot);
127 std::copy(buf->data(), buf->data() + lot, answer_->data());
128 answer_->DidConsume(lot);
129 if (GetRand(0, 1)) {
130 return lot;
131 }
132 MessageLoop::current()->PostTask(
133 FROM_HERE,
134 base::Bind(&TestingTransportSocket::DoWriteCallback,
135 weak_factory_.GetWeakPtr(), callback, lot));
136 return net::ERR_IO_PENDING;
137 }
138
139 virtual bool SetReceiveBufferSize(int32 size) {
140 return true;
141 }
142
143 virtual bool SetSendBufferSize(int32 size) {
144 return true;
145 }
146
147 net::DrainableIOBuffer* answer() { return answer_.get(); }
148
149 void DoReadCallback(const net::CompletionCallback& callback, int result) {
150 if (result == 0 && !is_closed_) {
151 MessageLoop::current()->PostTask(
152 FROM_HERE,
153 base::Bind(
154 &TestingTransportSocket::DoReadCallback,
155 weak_factory_.GetWeakPtr(), callback, 0));
156 } else {
157 if (!callback.is_null())
158 callback.Run(result);
159 }
160 }
161
162 void DoWriteCallback(const net::CompletionCallback& callback, int result) {
163 if (!callback.is_null())
164 callback.Run(result);
165 }
166
167 bool is_closed_;
168
169 // Data to return for Read requests.
170 scoped_refptr<net::DrainableIOBuffer> sample_;
171
172 // Data pushed to us by server socket (using Write calls).
173 scoped_refptr<net::DrainableIOBuffer> answer_;
174
175 // Final read callback to report zero (zero stands for EOF).
176 net::CompletionCallback final_read_callback_;
177
178 base::WeakPtrFactory<TestingTransportSocket> weak_factory_;
179 };
180
181 class Validator : public net::WebSocketServerSocket::Delegate {
182 public:
183 Validator(const std::string& resource,
184 const std::string& origin,
185 const std::string& host)
186 : resource_(resource), origin_(origin), host_(host) {
187 }
188
189 // WebSocketServerSocket::Delegate implementation.
190 virtual bool ValidateWebSocket(
191 const std::string& resource,
192 const std::string& origin,
193 const std::string& host,
194 const std::vector<std::string>& subprotocol_list,
195 std::string* location_out,
196 std::string* subprotocol_out) {
197 if (resource != resource_ || origin != origin_ || host != host_)
198 return false;
199 if (!subprotocol_list.empty())
200 *subprotocol_out = subprotocol_list.front();
201
202 char tmp[2048];
203 base::snprintf(
204 tmp, sizeof(tmp), "ws://%s%s", host.c_str(), resource.c_str());
205 location_out->assign(tmp);
206 return true;
207 }
208
209 private:
210 std::string resource_;
211 std::string origin_;
212 std::string host_;
213 };
214
215 char ReferenceSeq(unsigned n, unsigned salt) {
216 return (salt * 2 + n * 3) % ('z' - 'a') + 'a';
217 }
218
219 class ReadWriteTracker {
220 public:
221 ReadWriteTracker(
222 net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write)
223 : ws_(ws),
224 buf_size_(1 << 14),
225 read_buf_(new net::IOBuffer(buf_size_)),
226 write_buf_(new net::IOBuffer(buf_size_)),
227 bytes_remaining_to_read_(bytes_to_read),
228 bytes_remaining_to_write_(bytes_to_write),
229 got_final_zero_(false) {
230 int rv = ws_->Accept(
231 base::Bind(&ReadWriteTracker::OnAccept, base::Unretained(this)));
232 if (rv != net::ERR_IO_PENDING)
233 OnAccept(rv);
234 }
235
236 ~ReadWriteTracker() {
237 CHECK_EQ(bytes_remaining_to_write_, 0);
238 CHECK_EQ(bytes_remaining_to_read_, 0);
239 }
240
241 void OnAccept(int result) {
242 ASSERT_EQ(result, 0);
243 if (GetRand(0, 1)) {
244 DoRead();
245 DoWrite();
246 } else {
247 DoWrite();
248 DoRead();
249 }
250 }
251
252 void DoWrite() {
253 if (bytes_remaining_to_write_ < 1)
254 return;
255 int lot = GetRand(1, bytes_remaining_to_write_);
256 lot = std::min(lot, buf_size_);
257 for (int i = 0; i < lot; ++i)
258 write_buf_->data()[i] = ReferenceSeq(
259 bytes_remaining_to_write_ - i - 1, kWriteSalt);
260 int rv = ws_->Write(write_buf_, lot, base::Bind(&ReadWriteTracker::OnWrite,
261 base::Unretained(this)));
262 if (rv != net::ERR_IO_PENDING)
263 OnWrite(rv);
264 }
265
266 void DoRead() {
267 int lot = GetRand(1, buf_size_);
268 if (bytes_remaining_to_read_ < 1) {
269 if (got_final_zero_)
270 return;
271 } else {
272 lot = GetRand(1, bytes_remaining_to_read_);
273 lot = std::min(lot, buf_size_);
274 }
275 int rv = ws_->Read(read_buf_, lot, base::Bind(&ReadWriteTracker::OnRead,
276 base::Unretained(this)));
277 if (rv != net::ERR_IO_PENDING)
278 OnRead(rv);
279 }
280
281 void OnWrite(int result) {
282 ASSERT_GT(result, 0);
283 ASSERT_LE(result, bytes_remaining_to_write_);
284 bytes_remaining_to_write_ -= result;
285 DoWrite();
286 }
287
288 void OnRead(int result) {
289 ASSERT_LE(result, bytes_remaining_to_read_);
290 if (bytes_remaining_to_read_ < 1) {
291 ASSERT_FALSE(got_final_zero_);
292 ASSERT_EQ(result, 0);
293 got_final_zero_ = true;
294 return;
295 }
296 for (int i = 0; i < result; ++i) {
297 ASSERT_EQ(read_buf_->data()[i], ReferenceSeq(
298 bytes_remaining_to_read_ - i - 1, kReadSalt));
299 }
300 bytes_remaining_to_read_ -= result;
301 DoRead();
302 }
303
304 private:
305 net::WebSocketServerSocket* const ws_;
306 int const buf_size_;
307 scoped_refptr<net::IOBuffer> read_buf_;
308 scoped_refptr<net::IOBuffer> write_buf_;
309 int bytes_remaining_to_read_;
310 int bytes_remaining_to_write_;
311 bool got_final_zero_;
312 };
313
314 } // namespace
315
316 namespace net {
317
318 class WebSocketServerSocketTest : public testing::Test {
319 public:
320 virtual ~WebSocketServerSocketTest() {
321 }
322
323 virtual void SetUp() {
324 count_ = 0;
325 }
326
327 virtual void TearDown() {
328 }
329
330 void OnAccept0(int result) {
331 ASSERT_EQ(result, 0);
332 ASSERT_LT(count_, 99999);
333 count_ += 1;
334 }
335
336 void OnAccept1(int result) {
337 ASSERT_TRUE(result == ERR_CONNECTION_REFUSED ||
338 result == ERR_ACCESS_DENIED);
339 ASSERT_LT(count_, 99999);
340 count_ += 1;
341 }
342
343 int count_;
344 };
345
346 TEST_F(WebSocketServerSocketTest, Handshake) {
347 srand(2523456);
348 std::vector<Socket*> kill_list;
349 std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
350 Validator validator("/demo", "http://example.com/", "example.com");
351 count_ = 0;
352 const int kNumTests = 300;
353 for (int run = kNumTests; run--;) {
354 scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
355 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
356 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
357 std::copy(kSampleHandshakeRequest[i],
358 kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]),
359 sample->data());
360 sample->DidConsume(strlen(kSampleHandshakeRequest[i]));
361 if (i != arraysize(kSampleHandshakeRequest) - 1) {
362 std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
363 sample->DidConsume(strlen(kCRLF));
364 }
365 }
366 int sample_len = sample->BytesConsumed();
367 sample->SetOffset(0);
368 DrainableIOBuffer* answer = new DrainableIOBuffer(
369 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
370 answer_list.push_back(answer);
371 TestingTransportSocket* transport = new TestingTransportSocket(
372 ResizeIOBuffer(sample.get(), sample_len), answer);
373 WebSocketServerSocket* ws = CreateWebSocketServerSocket(
374 transport, &validator);
375 ASSERT_TRUE(ws != NULL);
376 kill_list.push_back(ws);
377
378 int rv = ws->Accept(base::Bind(&WebSocketServerSocketTest::OnAccept0,
379 base::Unretained(this)));
380 if (rv != ERR_IO_PENDING)
381 OnAccept0(rv);
382 }
383 MessageLoop::current()->RunAllPending();
384 ASSERT_EQ(count_, kNumTests);
385 for (size_t i = answer_list.size(); i--;) {
386 ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u,
387 strlen(kSampleHandshakeAnswer));
388 ASSERT_TRUE(std::equal(
389 answer_list[i]->data() - answer_list[i]->BytesConsumed(),
390 answer_list[i]->data(), kSampleHandshakeAnswer));
391 }
392 for (size_t i = kill_list.size(); i--;)
393 delete kill_list[i];
394 MessageLoop::current()->RunAllPending();
395 }
396
397 TEST_F(WebSocketServerSocketTest, BadCred) {
398 srand(9034958);
399 std::vector<Socket*> kill_list;
400 std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
401 Validator *validator[] = {
402 new Validator("/demo", "http://gooogle.com/", "example.com"),
403 new Validator("/tcpproxy", "http://example.com/", "example.com"),
404 new Validator("/tcpproxy", "http://gooogle.com/", "example.com"),
405 new Validator("/demo", "http://example.com/", "exmple.com"),
406 new Validator("/demo", "http://gooogle.com/", "gooogle.com")
407 };
408 count_ = 0;
409 for (int run = arraysize(validator); run--;) {
410 scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
411 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
412 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
413 std::copy(kSampleHandshakeRequest[i],
414 kSampleHandshakeRequest[i] + strlen(kSampleHandshakeRequest[i]),
415 sample->data());
416 sample->DidConsume(strlen(kSampleHandshakeRequest[i]));
417 if (i != arraysize(kSampleHandshakeRequest) - 1) {
418 std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
419 sample->DidConsume(strlen(kCRLF));
420 }
421 }
422 int sample_len = sample->BytesConsumed();
423 sample->SetOffset(0);
424 DrainableIOBuffer* answer = new DrainableIOBuffer(
425 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
426 answer_list.push_back(answer);
427 TestingTransportSocket* transport = new TestingTransportSocket(
428 ResizeIOBuffer(sample.get(), sample_len), answer);
429 WebSocketServerSocket* ws = CreateWebSocketServerSocket(
430 transport, validator[run]);
431 ASSERT_TRUE(ws != NULL);
432 kill_list.push_back(ws);
433
434 int rv = ws->Accept(base::Bind(&WebSocketServerSocketTest::OnAccept1,
435 base::Unretained(this)));
436 if (rv != ERR_IO_PENDING)
437 OnAccept1(rv);
438 }
439 MessageLoop::current()->RunAllPending();
440 ASSERT_EQ(count_ + 0u, arraysize(validator));
441 for (size_t i = answer_list.size(); i--;)
442 ASSERT_EQ(answer_list[i]->BytesConsumed(), 0);
443 for (size_t i = kill_list.size(); i--;)
444 delete kill_list[i];
445 for (size_t i = arraysize(validator); i--;)
446 delete validator[i];
447 MessageLoop::current()->RunAllPending();
448 }
449
450 TEST_F(WebSocketServerSocketTest, ReorderedHandshake) {
451 srand(205643459);
452 std::vector<Socket*> kill_list;
453 std::vector< scoped_refptr<DrainableIOBuffer> > answer_list;
454 Validator validator("/demo", "http://example.com/", "example.com");
455 count_ = 0;
456 const int kNumTests = 200;
457 for (int run = kNumTests; run--;) {
458 scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
459 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
460
461 std::vector<size_t> fields_order;
462 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i)
463 fields_order.push_back(i);
464 // One leading and two trailing lines of request are special, leave them.
465 std::random_shuffle(fields_order.begin() + 1,
466 fields_order.begin() + fields_order.size() - 3,
467 g_rand);
468
469 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
470 size_t j = fields_order[i];
471 std::copy(kSampleHandshakeRequest[j],
472 kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]),
473 sample->data());
474 sample->DidConsume(strlen(kSampleHandshakeRequest[j]));
475 if (i != arraysize(kSampleHandshakeRequest) - 1) {
476 std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
477 sample->DidConsume(strlen(kCRLF));
478 }
479 }
480 int sample_len = sample->BytesConsumed();
481 sample->SetOffset(0);
482 DrainableIOBuffer* answer = new DrainableIOBuffer(
483 new IOBuffer(kHandshakeBufBytes), kHandshakeBufBytes);
484 answer_list.push_back(answer);
485 TestingTransportSocket* transport = new TestingTransportSocket(
486 ResizeIOBuffer(sample.get(), sample_len), answer);
487 WebSocketServerSocket* ws = CreateWebSocketServerSocket(
488 transport, &validator);
489 ASSERT_TRUE(ws != NULL);
490 kill_list.push_back(ws);
491
492 int rv = ws->Accept(base::Bind(&WebSocketServerSocketTest::OnAccept0,
493 base::Unretained(this)));
494 if (rv != ERR_IO_PENDING)
495 OnAccept0(rv);
496 }
497 MessageLoop::current()->RunAllPending();
498 ASSERT_EQ(count_, kNumTests);
499 for (size_t i = answer_list.size(); i--;) {
500 ASSERT_EQ(answer_list[i]->BytesConsumed() + 0u,
501 strlen(kSampleHandshakeAnswer));
502 ASSERT_TRUE(std::equal(
503 answer_list[i]->data() - answer_list[i]->BytesConsumed(),
504 answer_list[i]->data(), kSampleHandshakeAnswer));
505 }
506 for (size_t i = kill_list.size(); i--;)
507 delete kill_list[i];
508 MessageLoop::current()->RunAllPending();
509 }
510
511 TEST_F(WebSocketServerSocketTest, ConveyData) {
512 srand(8234523);
513 std::vector<Socket*> kill_list;
514 std::vector<ReadWriteTracker*> tracker_list;
515 Validator validator("/demo", "http://example.com/", "example.com");
516 count_ = 0;
517 const int kNumTests = 150;
518 for (int run = kNumTests; run--;) {
519 int bytes_to_read = GetRand(1, 1 << 14);
520 int bytes_to_write = GetRand(1, 1 << 14);
521 int frames_limit = GetRand(1, 1 << 10);
522 int sample_limit = kHandshakeBufBytes + bytes_to_write + frames_limit * 2;
523 scoped_refptr<DrainableIOBuffer> sample = new DrainableIOBuffer(
524 new IOBuffer(sample_limit), sample_limit);
525
526 std::vector<size_t> fields_order;
527 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i)
528 fields_order.push_back(i);
529 // One leading and two trailing lines of request are special, leave them.
530 std::random_shuffle(fields_order.begin() + 1,
531 fields_order.begin() + fields_order.size() - 3,
532 g_rand);
533
534 for (size_t i = 0; i < arraysize(kSampleHandshakeRequest); ++i) {
535 size_t j = fields_order[i];
536 std::copy(kSampleHandshakeRequest[j],
537 kSampleHandshakeRequest[j] + strlen(kSampleHandshakeRequest[j]),
538 sample->data());
539 sample->DidConsume(strlen(kSampleHandshakeRequest[j]));
540 if (i != arraysize(kSampleHandshakeRequest) - 1) {
541 std::copy(kCRLF, kCRLF + strlen(kCRLF), sample->data());
542 sample->DidConsume(strlen(kCRLF));
543 }
544 }
545 {
546 bool outside_frame = true;
547 int pos = 0;
548 for (int i = 0; i < bytes_to_write; ++i) {
549 if (outside_frame) {
550 sample->data()[pos++] = '\x00';
551 outside_frame = false;
552 CHECK_GE(frames_limit, 1);
553 frames_limit -= 1;
554 }
555 sample->data()[pos++] = ReferenceSeq(bytes_to_write - i - 1, kReadSalt);
556 if ((frames_limit > 1 &&
557 GetRand(0, 1 + (bytes_to_write - i) / frames_limit) == 0) ||
558 i == bytes_to_write - 1) {
559 sample->data()[pos++] = '\xff';
560 outside_frame = true;
561 }
562 }
563 sample->DidConsume(pos);
564 }
565
566 int sample_len = sample->BytesConsumed();
567 sample->SetOffset(0);
568 int answer_limit = kHandshakeBufBytes + bytes_to_read * 3;
569 DrainableIOBuffer* answer = new DrainableIOBuffer(
570 new IOBuffer(answer_limit), answer_limit);
571 TestingTransportSocket* transport = new TestingTransportSocket(
572 ResizeIOBuffer(sample.get(), sample_len), answer);
573 WebSocketServerSocket* ws = CreateWebSocketServerSocket(
574 transport, &validator);
575 ASSERT_TRUE(ws != NULL);
576 kill_list.push_back(ws);
577
578 ReadWriteTracker* tracker = new ReadWriteTracker(
579 ws, bytes_to_write, bytes_to_read);
580 tracker_list.push_back(tracker);
581 }
582 MessageLoop::current()->RunAllPending();
583
584 for (size_t i = kill_list.size(); i--;)
585 delete kill_list[i];
586 for (size_t i = tracker_list.size(); i--;)
587 delete tracker_list[i];
588 MessageLoop::current()->RunAllPending();
589 }
590
591 } // namespace net
OLDNEW
« net/base/net_error_list.h ('K') | « net/socket/web_socket_server_socket.cc ('k') | no next file » | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698