Index: components/certificate_transparency/log_dns_client_unittest.cc |
diff --git a/components/certificate_transparency/log_dns_client_unittest.cc b/components/certificate_transparency/log_dns_client_unittest.cc |
index 00782bb5e72bb99dbe567d0243a1a55d6b3a35b7..773b21358100ca6dcb34ed04b14f8fd26357a6b9 100644 |
--- a/components/certificate_transparency/log_dns_client_unittest.cc |
+++ b/components/certificate_transparency/log_dns_client_unittest.cc |
@@ -89,24 +89,47 @@ std::vector<std::string> GetSampleAuditProof(size_t length) { |
} |
return audit_proof; |
} |
-// MockAuditProofCallback can be used as an AuditProofCallback. |
-// It will record the arguments it is invoked with and provides a helpful |
-// method for pumping the message loop until it is invoked. |
-class MockAuditProofCallback { |
+// MockCallback can be used as a base::Callback. |
+// It will record the arguments it is invoked with, which can be examined by |
+// calling args() or arg<N>(). |
+// It only expects to be called once, but can be reused by calling Reset(). |
+// Example: |
+// MockCallback<int> mock; |
+// foo.RegisterCallback(mock.AsCallback()); |
+// foo.DoSomething(); |
+// mock.WaitUntilRun(TestTimeouts::action_max_timeout()); |
+// ASSERT_TRUE(mock.called()); |
+// ASSERT_EQ(123, mock.arg<0>()); |
+template <typename... Args> |
+class MockCallback { |
public: |
- MockAuditProofCallback() : called_(false) {} |
+ MockCallback() : called_(false) {} |
+ // Returns true if the callback has been invoked. |
bool called() const { return called_; } |
- net::Error result() const { return result_; } |
- const net::ct::MerkleAuditProof* proof() const { return proof_.get(); } |
- // Get this callback as an AuditProofCallback. |
- LogDnsClient::AuditProofCallback AsCallback() { |
- return base::Bind(&MockAuditProofCallback::Run, base::Unretained(this)); |
+ // The arguments that the callback was called with. |
+ const std::tuple<Args...>& args() const { |
+ DCHECK(called_); |
+ return args_; |
+ } |
+ |
+ // Gets a particular argument that the callback was invoked with. |
+ // For example, to get the first argument: mock_callback.arg<0>(); |
+ template <size_t N> |
+ const typename std::tuple_element<N, std::tuple<Args...>>::type& arg() const { |
+ DCHECK(called_); |
+ return std::get<N>(args_); |
+ } |
+ |
+ // Convert to a base::Callback. |
+ // TODO(robpercival): Could this reasonably be an implicit conversion? |
+ base::Callback<void(Args...)> AsCallback() { |
+ return base::Bind(&MockCallback::Run, base::Unretained(this)); |
} |
// Wait until either the callback is invoked or the message loop goes idle |
// (after a specified |timeout|). Returns immediately if the callback has |
// already been invoked. |
@@ -123,32 +146,44 @@ class MockAuditProofCallback { |
quit_closure, timeout); |
run_loop_->Run(); |
run_loop_.reset(); |
} |
+ void Reset() { |
+ called_ = false; |
+ args_ = std::tuple<Args...>(); |
+ } |
+ |
private: |
- void Run(net::Error result, |
- std::unique_ptr<net::ct::MerkleAuditProof> proof) { |
+ void Run(Args... args) { |
EXPECT_FALSE(called_); |
called_ = true; |
- result_ = result; |
- proof_ = std::move(proof); |
+ args_ = std::make_tuple(std::forward<Args>(args)...); |
if (run_loop_) { |
run_loop_->Quit(); |
} |
} |
// True if the callback has been invoked. |
bool called_; |
// The arguments that the callback was invoked with. |
- net::Error result_; |
- std::unique_ptr<net::ct::MerkleAuditProof> proof_; |
+ std::tuple<Args...> args_; |
// The RunLoop currently being used to pump the message loop, as a means to |
// execute this callback. |
std::unique_ptr<base::RunLoop> run_loop_; |
}; |
+class MockAuditProofCallback |
+ : public MockCallback<net::Error, |
+ std::unique_ptr<net::ct::MerkleAuditProof>> { |
+ public: |
+ net::Error result() const { return arg<0>(); } |
+ const net::ct::MerkleAuditProof* proof() const { return arg<1>().get(); } |
+}; |
+ |
+class MockClosure : public MockCallback<> {}; |
+ |
class LogDnsClientTest : public ::testing::TestWithParam<net::IoMode> { |
protected: |
LogDnsClientTest() |
: network_change_notifier_(net::NetworkChangeNotifier::CreateMock()) { |
mock_dns_.SetSocketReadMode(GetParam()); |
@@ -880,10 +915,67 @@ TEST_P(LogDnsClientTest, CanBeThrottledToOneQueryAtATime) { |
// TODO(robpercival): Enable this once MerkleAuditProof has tree_size. |
// EXPECT_THAT(callback3.proof()->tree_size, Eq(999999)); |
EXPECT_THAT(callback3.proof()->nodes, Eq(audit_proof)); |
} |
+TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) { |
+ const std::vector<std::string> audit_proof = GetSampleAuditProof(20); |
+ |
+ mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 123456); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.", |
+ audit_proof.begin(), |
+ audit_proof.begin() + 7); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("7.123456.999999.tree.ct.test.", |
+ audit_proof.begin() + 7, |
+ audit_proof.begin() + 14); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("14.123456.999999.tree.ct.test.", |
+ audit_proof.begin() + 14, |
+ audit_proof.end()); |
+ |
+ const size_t kMaxConcurrentQueries = 1; |
+ std::unique_ptr<LogDnsClient> log_client = |
+ CreateLogDnsClient(kMaxConcurrentQueries); |
+ |
+ // Start a query. |
+ MockAuditProofCallback proof_callback1; |
+ ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999, |
+ proof_callback1.AsCallback()), |
+ IsError(net::ERR_IO_PENDING)); |
+ |
+ MockClosure not_throttled_callback; |
+ log_client->NotifyWhenNotThrottled(not_throttled_callback.AsCallback()); |
+ |
+ proof_callback1.WaitUntilRun(TestTimeouts::action_max_timeout()); |
+ ASSERT_TRUE(proof_callback1.called()); |
+ ASSERT_TRUE(not_throttled_callback.called()); |
+ |
+ // Start another query to check |not_throttled_callback| doesn't fire again. |
+ not_throttled_callback.Reset(); |
+ |
+ mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[1], 666); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("0.666.999999.tree.ct.test.", |
+ audit_proof.begin(), |
+ audit_proof.begin() + 7); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("7.666.999999.tree.ct.test.", |
+ audit_proof.begin() + 7, |
+ audit_proof.begin() + 14); |
+ mock_dns_.ExpectAuditProofRequestAndResponse("14.666.999999.tree.ct.test.", |
+ audit_proof.begin() + 14, |
+ audit_proof.end()); |
+ |
+ MockAuditProofCallback proof_callback2; |
+ ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999, |
+ proof_callback2.AsCallback()), |
+ IsError(net::ERR_IO_PENDING)); |
+ |
+ // Give the query a chance to run. |
+ proof_callback2.WaitUntilRun(TestTimeouts::action_max_timeout()); |
+ |
+ ASSERT_TRUE(proof_callback2.called()); |
+ ASSERT_FALSE(not_throttled_callback.called()); |
+} |
+ |
INSTANTIATE_TEST_CASE_P(ReadMode, |
LogDnsClientTest, |
::testing::Values(net::IoMode::ASYNC, |
net::IoMode::SYNCHRONOUS)); |