diff --git a/google/cloud/internal/async_retry_loop.h b/google/cloud/internal/async_retry_loop.h index 9866947c5b4d3..1d121d05902df 100644 --- a/google/cloud/internal/async_retry_loop.h +++ b/google/cloud/internal/async_retry_loop.h @@ -170,16 +170,23 @@ struct FutureValueType> { * functions. If the value is visible, the retry loop will stop on the next * callback and/or before the next request or timer is issued. */ -template +template < + typename Functor, typename Request, typename RetryPolicyType, + typename ReturnType = google::cloud::internal::invoke_result_t< + Functor, google::cloud::CompletionQueue&, + std::shared_ptr, ImmutableOptions, Request const&>> class AsyncRetryLoopImpl : public std::enable_shared_from_this< AsyncRetryLoopImpl> { public: - AsyncRetryLoopImpl(std::unique_ptr retry_policy, - std::unique_ptr backoff_policy, - Idempotency idempotency, google::cloud::CompletionQueue cq, - Functor&& functor, ImmutableOptions options, - Request request, char const* location) + AsyncRetryLoopImpl( + std::unique_ptr retry_policy, + std::unique_ptr backoff_policy, Idempotency idempotency, + google::cloud::CompletionQueue cq, Functor&& functor, + ImmutableOptions options, Request request, char const* location, + std::function< + bool(typename FutureValueType::value_type const&)> + attempt_predicate = {}) : retry_policy_(std::move(retry_policy)), backoff_policy_(std::move(backoff_policy)), idempotency_(idempotency), @@ -188,11 +195,9 @@ class AsyncRetryLoopImpl functor_(std::forward(functor)), request_(std::move(request)), location_(location), - call_context_(std::move(options)) {} + call_context_(std::move(options)), + attempt_predicate_(std::move(attempt_predicate)) {} - using ReturnType = ::google::cloud::internal::invoke_result_t< - Functor, google::cloud::CompletionQueue&, - std::shared_ptr, ImmutableOptions, Request const&>; using T = typename FutureValueType::value_type; future Start() { @@ -256,8 +261,11 @@ class AsyncRetryLoopImpl } void OnAttempt(T result) { - // A successful attempt, set the value and finish the loop. - if (result.ok()) return SetDone(std::move(result)); + // If the attempt is successful and satisfies the attempt predicate, if + // provided, set the value and finish the loop. + if (result.ok() && (!attempt_predicate_ || attempt_predicate_(result))) { + return SetDone(std::move(result)); + } // Some kind of failure, first verify that it is retryable. last_status_ = GetResultStatus(std::move(result)); auto delay = @@ -325,6 +333,8 @@ class AsyncRetryLoopImpl CallContext call_context_; Status last_status_; promise result_; + std::function::value_type const&)> + attempt_predicate_; // Only the following variables require synchronization, as they coordinate // the work between the retry loop (which would be lock-free) and the cancel @@ -339,17 +349,23 @@ class AsyncRetryLoopImpl /** * Create the right AsyncRetryLoopImpl object and start the retry loop on it. */ -template , - ImmutableOptions, Request const&>::value, - int> = 0> -auto AsyncRetryLoop(std::unique_ptr retry_policy, - std::unique_ptr backoff_policy, - Idempotency idempotency, google::cloud::CompletionQueue cq, - Functor&& functor, ImmutableOptions options, - Request request, char const* location) +template < + typename Functor, typename Request, typename RetryPolicyType, + std::enable_if_t, ImmutableOptions, + Request const&>::value, + int> = 0, + typename ReturnType = google::cloud::internal::invoke_result_t< + Functor, google::cloud::CompletionQueue&, + std::shared_ptr, ImmutableOptions, Request const&>> +auto AsyncRetryLoop( + std::unique_ptr retry_policy, + std::unique_ptr backoff_policy, Idempotency idempotency, + google::cloud::CompletionQueue cq, Functor&& functor, + ImmutableOptions options, Request request, char const* location, + std::function::value_type const&)> + attempt_predicate = {}) -> google::cloud::internal::invoke_result_t< Functor, google::cloud::CompletionQueue&, std::shared_ptr, ImmutableOptions, @@ -358,7 +374,7 @@ auto AsyncRetryLoop(std::unique_ptr retry_policy, std::make_shared>( std::move(retry_policy), std::move(backoff_policy), idempotency, std::move(cq), std::forward(functor), options, - std::move(request), location); + std::move(request), location, std::move(attempt_predicate)); return loop->Start(); } diff --git a/google/cloud/internal/async_retry_loop_test.cc b/google/cloud/internal/async_retry_loop_test.cc index 7e200d4eec613..98f77b68ff07e 100644 --- a/google/cloud/internal/async_retry_loop_test.cc +++ b/google/cloud/internal/async_retry_loop_test.cc @@ -160,6 +160,32 @@ TEST(AsyncRetryLoopTest, TransientThenSuccess) { EXPECT_EQ(84, *actual); } +TEST(AsyncRetryLoopTest, TransientPredicateThenSuccess) { + AutomaticallyCreatedBackgroundThreads background; + ::testing::MockFunction)> mock_predicate; + EXPECT_CALL(mock_predicate, Call) + .WillOnce([](StatusOr const&) { return false; }) + .WillOnce([](StatusOr const&) { return false; }) + .WillOnce([](StatusOr const&) { return true; }); + + auto pending = AsyncRetryLoop( + TestRetryPolicy(), TestBackoffPolicy(), Idempotency::kIdempotent, + background.cq(), + [&](google::cloud::CompletionQueue&, auto, + ImmutableOptions const& options, int request) { + EXPECT_EQ(options->get(), "TransientPredicateThenSuccess"); + return make_ready_future(StatusOr(2 * request)); + }, + MakeImmutableOptions( + Options{}.set("TransientPredicateThenSuccess")), + 42, "error message", mock_predicate.AsStdFunction()); + + OptionsSpan overlay(Options{}.set("uh-oh")); + StatusOr actual = pending.get(); + ASSERT_THAT(actual.status(), IsOk()); + EXPECT_EQ(84, *actual); +} + TEST(AsyncRetryLoopTest, ReturnJustStatus) { int counter = 0; AutomaticallyCreatedBackgroundThreads background;