Index: net/dns/dns_test_util.cc |
=================================================================== |
--- net/dns/dns_test_util.cc (revision 219192) |
+++ net/dns/dns_test_util.cc (working copy) |
@@ -15,9 +15,6 @@ |
#include "net/base/io_buffer.h" |
#include "net/base/net_errors.h" |
#include "net/dns/address_sorter.h" |
-#include "net/dns/dns_client.h" |
-#include "net/dns/dns_config_service.h" |
-#include "net/dns/dns_protocol.h" |
#include "net/dns/dns_query.h" |
#include "net/dns/dns_response.h" |
#include "net/dns/dns_transaction.h" |
@@ -26,6 +23,16 @@ |
namespace net { |
namespace { |
+class MockAddressSorter : public AddressSorter { |
+ public: |
+ virtual ~MockAddressSorter() {} |
+ virtual void Sort(const AddressList& list, |
+ const CallbackType& callback) const OVERRIDE { |
+ // Do nothing. |
+ callback.Run(true, list); |
+ } |
+}; |
+ |
// A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
class MockTransaction : public DnsTransaction, |
public base::SupportsWeakPtr<MockTransaction> { |
@@ -38,7 +45,8 @@ |
hostname_(hostname), |
qtype_(qtype), |
callback_(callback), |
- started_(false) { |
+ started_(false), |
+ delayed_(false) { |
// Find the relevant rule which matches |qtype| and prefix of |hostname|. |
for (size_t i = 0; i < rules.size(); ++i) { |
const std::string& prefix = rules[i].prefix; |
@@ -46,6 +54,7 @@ |
(hostname.size() >= prefix.size()) && |
(hostname.compare(0, prefix.size(), prefix) == 0)) { |
result_ = rules[i].result; |
+ delayed_ = rules[i].delay; |
break; |
} |
} |
@@ -62,11 +71,21 @@ |
virtual void Start() OVERRIDE { |
EXPECT_FALSE(started_); |
started_ = true; |
+ if (delayed_) |
+ return; |
// Using WeakPtr to cleanly cancel when transaction is destroyed. |
base::MessageLoop::current()->PostTask( |
FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); |
} |
+ void FinishDelayedTransaction() { |
+ EXPECT_TRUE(delayed_); |
+ delayed_ = false; |
+ Finish(); |
+ } |
+ |
+ bool delayed() const { return delayed_; } |
+ |
private: |
void Finish() { |
switch (result_) { |
@@ -136,14 +155,17 @@ |
const uint16 qtype_; |
DnsTransactionFactory::CallbackType callback_; |
bool started_; |
+ bool delayed_; |
}; |
+} // namespace |
// A DnsTransactionFactory which creates MockTransaction. |
class MockTransactionFactory : public DnsTransactionFactory { |
public: |
explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
: rules_(rules) {} |
+ |
virtual ~MockTransactionFactory() {} |
virtual scoped_ptr<DnsTransaction> CreateTransaction( |
@@ -151,60 +173,57 @@ |
uint16 qtype, |
const DnsTransactionFactory::CallbackType& callback, |
const BoundNetLog&) OVERRIDE { |
- return scoped_ptr<DnsTransaction>( |
- new MockTransaction(rules_, hostname, qtype, callback)); |
+ MockTransaction* transaction = |
+ new MockTransaction(rules_, hostname, qtype, callback); |
+ if (transaction->delayed()) |
+ delayed_transactions_.push_back(transaction->AsWeakPtr()); |
+ return scoped_ptr<DnsTransaction>(transaction); |
} |
+ void CompleteDelayedTransactions() { |
+ DelayedTransactionList old_delayed_transactions; |
+ old_delayed_transactions.swap(delayed_transactions_); |
+ for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); |
+ it != old_delayed_transactions.end(); ++it) { |
+ if (it->get()) |
+ (*it)->FinishDelayedTransaction(); |
+ } |
+ } |
+ |
private: |
+ typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; |
+ |
MockDnsClientRuleList rules_; |
+ DelayedTransactionList delayed_transactions_; |
}; |
-class MockAddressSorter : public AddressSorter { |
- public: |
- virtual ~MockAddressSorter() {} |
- virtual void Sort(const AddressList& list, |
- const CallbackType& callback) const OVERRIDE { |
- // Do nothing. |
- callback.Run(true, list); |
- } |
-}; |
+MockDnsClient::MockDnsClient(const DnsConfig& config, |
+ const MockDnsClientRuleList& rules) |
+ : config_(config), |
+ factory_(new MockTransactionFactory(rules)), |
+ address_sorter_(new MockAddressSorter()) { |
+} |
-// MockDnsClient provides MockTransactionFactory. |
-class MockDnsClient : public DnsClient { |
- public: |
- MockDnsClient(const DnsConfig& config, |
- const MockDnsClientRuleList& rules) |
- : config_(config), factory_(rules) {} |
- virtual ~MockDnsClient() {} |
+MockDnsClient::~MockDnsClient() {} |
- virtual void SetConfig(const DnsConfig& config) OVERRIDE { |
- config_ = config; |
- } |
+void MockDnsClient::SetConfig(const DnsConfig& config) { |
+ config_ = config; |
+} |
- virtual const DnsConfig* GetConfig() const OVERRIDE { |
- return config_.IsValid() ? &config_ : NULL; |
- } |
+const DnsConfig* MockDnsClient::GetConfig() const { |
+ return config_.IsValid() ? &config_ : NULL; |
+} |
- virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { |
- return config_.IsValid() ? &factory_ : NULL; |
- } |
+DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { |
+ return config_.IsValid() ? factory_.get() : NULL; |
+} |
- virtual AddressSorter* GetAddressSorter() OVERRIDE { |
- return &address_sorter_; |
- } |
+AddressSorter* MockDnsClient::GetAddressSorter() { |
+ return address_sorter_.get(); |
+} |
- private: |
- DnsConfig config_; |
- MockTransactionFactory factory_; |
- MockAddressSorter address_sorter_; |
-}; |
- |
-} // namespace |
- |
-// static |
-scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, |
- const MockDnsClientRuleList& rules) { |
- return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); |
+void MockDnsClient::CompleteDelayedTransactions() { |
+ factory_->CompleteDelayedTransactions(); |
} |
} // namespace net |