OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/dns/dns_test_util.h" | 5 #include "net/dns/dns_test_util.h" |
6 | 6 |
7 #include <string> | 7 #include <string> |
8 | 8 |
9 #include "base/bind.h" | 9 #include "base/bind.h" |
10 #include "base/memory/weak_ptr.h" | 10 #include "base/memory/weak_ptr.h" |
11 #include "base/message_loop/message_loop.h" | 11 #include "base/message_loop/message_loop.h" |
12 #include "base/sys_byteorder.h" | 12 #include "base/sys_byteorder.h" |
13 #include "net/base/big_endian.h" | 13 #include "net/base/big_endian.h" |
14 #include "net/base/dns_util.h" | 14 #include "net/base/dns_util.h" |
15 #include "net/base/io_buffer.h" | 15 #include "net/base/io_buffer.h" |
16 #include "net/base/net_errors.h" | 16 #include "net/base/net_errors.h" |
17 #include "net/dns/address_sorter.h" | 17 #include "net/dns/address_sorter.h" |
18 #include "net/dns/dns_client.h" | |
19 #include "net/dns/dns_config_service.h" | |
20 #include "net/dns/dns_protocol.h" | |
21 #include "net/dns/dns_query.h" | 18 #include "net/dns/dns_query.h" |
22 #include "net/dns/dns_response.h" | 19 #include "net/dns/dns_response.h" |
23 #include "net/dns/dns_transaction.h" | 20 #include "net/dns/dns_transaction.h" |
24 #include "testing/gtest/include/gtest/gtest.h" | 21 #include "testing/gtest/include/gtest/gtest.h" |
25 | 22 |
26 namespace net { | 23 namespace net { |
27 namespace { | 24 namespace { |
28 | 25 |
| 26 class MockAddressSorter : public AddressSorter { |
| 27 public: |
| 28 virtual ~MockAddressSorter() {} |
| 29 virtual void Sort(const AddressList& list, |
| 30 const CallbackType& callback) const OVERRIDE { |
| 31 // Do nothing. |
| 32 callback.Run(true, list); |
| 33 } |
| 34 }; |
| 35 |
29 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. | 36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
30 class MockTransaction : public DnsTransaction, | 37 class MockTransaction : public DnsTransaction, |
31 public base::SupportsWeakPtr<MockTransaction> { | 38 public base::SupportsWeakPtr<MockTransaction> { |
32 public: | 39 public: |
33 MockTransaction(const MockDnsClientRuleList& rules, | 40 MockTransaction(const MockDnsClientRuleList& rules, |
34 const std::string& hostname, | 41 const std::string& hostname, |
35 uint16 qtype, | 42 uint16 qtype, |
36 const DnsTransactionFactory::CallbackType& callback) | 43 const DnsTransactionFactory::CallbackType& callback) |
37 : result_(MockDnsClientRule::FAIL), | 44 : result_(MockDnsClientRule::FAIL), |
38 hostname_(hostname), | 45 hostname_(hostname), |
39 qtype_(qtype), | 46 qtype_(qtype), |
40 callback_(callback), | 47 callback_(callback), |
41 started_(false) { | 48 started_(false), |
| 49 delayed_(false) { |
42 // Find the relevant rule which matches |qtype| and prefix of |hostname|. | 50 // Find the relevant rule which matches |qtype| and prefix of |hostname|. |
43 for (size_t i = 0; i < rules.size(); ++i) { | 51 for (size_t i = 0; i < rules.size(); ++i) { |
44 const std::string& prefix = rules[i].prefix; | 52 const std::string& prefix = rules[i].prefix; |
45 if ((rules[i].qtype == qtype) && | 53 if ((rules[i].qtype == qtype) && |
46 (hostname.size() >= prefix.size()) && | 54 (hostname.size() >= prefix.size()) && |
47 (hostname.compare(0, prefix.size(), prefix) == 0)) { | 55 (hostname.compare(0, prefix.size(), prefix) == 0)) { |
48 result_ = rules[i].result; | 56 result_ = rules[i].result; |
| 57 delayed_ = rules[i].delay; |
49 break; | 58 break; |
50 } | 59 } |
51 } | 60 } |
52 } | 61 } |
53 | 62 |
54 virtual const std::string& GetHostname() const OVERRIDE { | 63 virtual const std::string& GetHostname() const OVERRIDE { |
55 return hostname_; | 64 return hostname_; |
56 } | 65 } |
57 | 66 |
58 virtual uint16 GetType() const OVERRIDE { | 67 virtual uint16 GetType() const OVERRIDE { |
59 return qtype_; | 68 return qtype_; |
60 } | 69 } |
61 | 70 |
62 virtual void Start() OVERRIDE { | 71 virtual void Start() OVERRIDE { |
63 EXPECT_FALSE(started_); | 72 EXPECT_FALSE(started_); |
64 started_ = true; | 73 started_ = true; |
| 74 if (delayed_) |
| 75 return; |
65 // Using WeakPtr to cleanly cancel when transaction is destroyed. | 76 // Using WeakPtr to cleanly cancel when transaction is destroyed. |
66 base::MessageLoop::current()->PostTask( | 77 base::MessageLoop::current()->PostTask( |
67 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); | 78 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); |
68 } | 79 } |
69 | 80 |
| 81 void FinishDelayedTransaction() { |
| 82 EXPECT_TRUE(delayed_); |
| 83 delayed_ = false; |
| 84 Finish(); |
| 85 } |
| 86 |
| 87 bool delayed() const { return delayed_; } |
| 88 |
70 private: | 89 private: |
71 void Finish() { | 90 void Finish() { |
72 switch (result_) { | 91 switch (result_) { |
73 case MockDnsClientRule::EMPTY: | 92 case MockDnsClientRule::EMPTY: |
74 case MockDnsClientRule::OK: { | 93 case MockDnsClientRule::OK: { |
75 std::string qname; | 94 std::string qname; |
76 DNSDomainFromDot(hostname_, &qname); | 95 DNSDomainFromDot(hostname_, &qname); |
77 DnsQuery query(0, qname, qtype_); | 96 DnsQuery query(0, qname, qtype_); |
78 | 97 |
79 DnsResponse response; | 98 DnsResponse response; |
(...skipping 49 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
129 NOTREACHED(); | 148 NOTREACHED(); |
130 break; | 149 break; |
131 } | 150 } |
132 } | 151 } |
133 | 152 |
134 MockDnsClientRule::Result result_; | 153 MockDnsClientRule::Result result_; |
135 const std::string hostname_; | 154 const std::string hostname_; |
136 const uint16 qtype_; | 155 const uint16 qtype_; |
137 DnsTransactionFactory::CallbackType callback_; | 156 DnsTransactionFactory::CallbackType callback_; |
138 bool started_; | 157 bool started_; |
| 158 bool delayed_; |
139 }; | 159 }; |
140 | 160 |
| 161 } // namespace |
141 | 162 |
142 // A DnsTransactionFactory which creates MockTransaction. | 163 // A DnsTransactionFactory which creates MockTransaction. |
143 class MockTransactionFactory : public DnsTransactionFactory { | 164 class MockTransactionFactory : public DnsTransactionFactory { |
144 public: | 165 public: |
145 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) | 166 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
146 : rules_(rules) {} | 167 : rules_(rules) {} |
| 168 |
147 virtual ~MockTransactionFactory() {} | 169 virtual ~MockTransactionFactory() {} |
148 | 170 |
149 virtual scoped_ptr<DnsTransaction> CreateTransaction( | 171 virtual scoped_ptr<DnsTransaction> CreateTransaction( |
150 const std::string& hostname, | 172 const std::string& hostname, |
151 uint16 qtype, | 173 uint16 qtype, |
152 const DnsTransactionFactory::CallbackType& callback, | 174 const DnsTransactionFactory::CallbackType& callback, |
153 const BoundNetLog&) OVERRIDE { | 175 const BoundNetLog&) OVERRIDE { |
154 return scoped_ptr<DnsTransaction>( | 176 MockTransaction* transaction = |
155 new MockTransaction(rules_, hostname, qtype, callback)); | 177 new MockTransaction(rules_, hostname, qtype, callback); |
| 178 if (transaction->delayed()) |
| 179 delayed_transactions_.push_back(transaction->AsWeakPtr()); |
| 180 return scoped_ptr<DnsTransaction>(transaction); |
| 181 } |
| 182 |
| 183 void CompleteDelayedTransactions() { |
| 184 DelayedTransactionList old_delayed_transactions; |
| 185 old_delayed_transactions.swap(delayed_transactions_); |
| 186 for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); |
| 187 it != old_delayed_transactions.end(); ++it) { |
| 188 if (it->get()) |
| 189 (*it)->FinishDelayedTransaction(); |
| 190 } |
156 } | 191 } |
157 | 192 |
158 private: | 193 private: |
| 194 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; |
| 195 |
159 MockDnsClientRuleList rules_; | 196 MockDnsClientRuleList rules_; |
| 197 DelayedTransactionList delayed_transactions_; |
160 }; | 198 }; |
161 | 199 |
162 class MockAddressSorter : public AddressSorter { | 200 MockDnsClient::MockDnsClient(const DnsConfig& config, |
163 public: | 201 const MockDnsClientRuleList& rules) |
164 virtual ~MockAddressSorter() {} | 202 : config_(config), |
165 virtual void Sort(const AddressList& list, | 203 factory_(new MockTransactionFactory(rules)), |
166 const CallbackType& callback) const OVERRIDE { | 204 address_sorter_(new MockAddressSorter()) { |
167 // Do nothing. | 205 } |
168 callback.Run(true, list); | |
169 } | |
170 }; | |
171 | 206 |
172 // MockDnsClient provides MockTransactionFactory. | 207 MockDnsClient::~MockDnsClient() {} |
173 class MockDnsClient : public DnsClient { | |
174 public: | |
175 MockDnsClient(const DnsConfig& config, | |
176 const MockDnsClientRuleList& rules) | |
177 : config_(config), factory_(rules) {} | |
178 virtual ~MockDnsClient() {} | |
179 | 208 |
180 virtual void SetConfig(const DnsConfig& config) OVERRIDE { | 209 void MockDnsClient::SetConfig(const DnsConfig& config) { |
181 config_ = config; | 210 config_ = config; |
182 } | 211 } |
183 | 212 |
184 virtual const DnsConfig* GetConfig() const OVERRIDE { | 213 const DnsConfig* MockDnsClient::GetConfig() const { |
185 return config_.IsValid() ? &config_ : NULL; | 214 return config_.IsValid() ? &config_ : NULL; |
186 } | 215 } |
187 | 216 |
188 virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { | 217 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { |
189 return config_.IsValid() ? &factory_ : NULL; | 218 return config_.IsValid() ? factory_.get() : NULL; |
190 } | 219 } |
191 | 220 |
192 virtual AddressSorter* GetAddressSorter() OVERRIDE { | 221 AddressSorter* MockDnsClient::GetAddressSorter() { |
193 return &address_sorter_; | 222 return address_sorter_.get(); |
194 } | 223 } |
195 | 224 |
196 private: | 225 void MockDnsClient::CompleteDelayedTransactions() { |
197 DnsConfig config_; | 226 factory_->CompleteDelayedTransactions(); |
198 MockTransactionFactory factory_; | |
199 MockAddressSorter address_sorter_; | |
200 }; | |
201 | |
202 } // namespace | |
203 | |
204 // static | |
205 scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, | |
206 const MockDnsClientRuleList& rules) { | |
207 return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); | |
208 } | 227 } |
209 | 228 |
210 } // namespace net | 229 } // namespace net |
OLD | NEW |