aws_smithy_runtime/client/retries/strategy/standard.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::sync::Mutex;
use std::time::{Duration, SystemTime};

use aws_smithy_async::time::SharedTimeSource;
use tokio::sync::OwnedSemaphorePermit;
use tracing::{debug, trace};

use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
    BeforeTransmitInterceptorContextMut, InterceptorContext,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::retries::classifiers::{RetryAction, RetryReason};
use aws_smithy_runtime_api::client::retries::{RequestAttempts, RetryStrategy, ShouldAttempt};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
use aws_smithy_types::retry::{ErrorKind, RetryConfig, RetryMode};

use crate::client::retries::classifiers::run_classifiers_on_ctx;
use crate::client::retries::client_rate_limiter::{ClientRateLimiter, RequestReason};
use crate::client::retries::strategy::standard::ReleaseResult::{
    APermitWasReleased, NoPermitWasReleased,
};
use crate::client::retries::token_bucket::TokenBucket;
use crate::client::retries::{ClientRateLimiterPartition, RetryPartition, RetryPartitionInner};
use crate::static_partition_map::StaticPartitionMap;

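/// Process-wide cache of [`ClientRateLimiter`]s used for adaptive retry, keyed by client rate
/// limiter partition.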
static CLIENT_RATE_LIMITER: StaticPartitionMap<ClientRateLimiterPartition, ClientRateLimiter> =
    StaticPartitionMap::new();

/// Used by the token bucket interceptor to ensure a `TokenBucket` always exists in the config bag.
static TOKEN_BUCKET: StaticPartitionMap<RetryPartition, TokenBucket> = StaticPartitionMap::new();

/// Retry strategy with exponential backoff, max attempts, and a token bucket.
#[derive(Debug, Default)]
pub struct StandardRetryStrategy {
    retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
}

impl Storable for StandardRetryStrategy {
    type Storer = StoreReplace<Self>;
}

impl StandardRetryStrategy {
    /// Create a new standard retry strategy.
    pub fn new() -> Self {
        Default::default()
    }

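    /// Performs token bucket bookkeeping after a successful attempt.
    ///
    /// If a retry permit is held from a previous attempt, the success is rewarded when the bucket
    /// is configured with a success reward (and the permit is forgotten); otherwise the permit is
    /// dropped back into the bucket. If no permit is held (first-attempt success), the bucket is
    /// either rewarded or regenerates a single token.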
    fn release_retry_permit(&self, token_bucket: &TokenBucket) -> ReleaseResult {
        let mut retry_permit = self.retry_permit.lock().unwrap();
        match retry_permit.take() {
            Some(p) => {
                // Retry succeeded: reward success and forget permit if configured, otherwise release permit back
                if token_bucket.success_reward() > 0.0 {
                    token_bucket.reward_success();
                    p.forget();
                } else {
                    drop(p); // Original behavior - release back to bucket
                }
                APermitWasReleased
            }
            None => {
                // First-attempt success: reward success or regenerate token
                if token_bucket.success_reward() > 0.0 {
                    token_bucket.reward_success();
                } else {
                    token_bucket.regenerate_a_token();
                }
                NoPermitWasReleased
            }
        }
    }

    fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
        let mut old_retry_permit = self.retry_permit.lock().unwrap();
        if let Some(p) = old_retry_permit.replace(new_retry_permit) {
            // Whenever we set a new retry permit, and it replaces the old one, we need to "forget"
            // the old permit, removing it from the bucket forever.
            p.forget()
        }
    }

    /// Returns a [`ClientRateLimiter`] if adaptive retry is configured.
    fn adaptive_retry_rate_limiter(
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Option<ClientRateLimiter> {
        let retry_config = cfg.load::<RetryConfig>().expect("retry config is required");
        if retry_config.mode() == RetryMode::Adaptive {
            if let Some(time_source) = runtime_components.time_source() {
                let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");
                let seconds_since_unix_epoch = time_source
                    .now()
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .expect("the present takes place after the UNIX_EPOCH")
                    .as_secs_f64();
                let client_rate_limiter = match &retry_partition.inner {
                    RetryPartitionInner::Default(_) => {
                        let client_rate_limiter_partition =
                            ClientRateLimiterPartition::new(retry_partition.clone());
                        CLIENT_RATE_LIMITER.get_or_init(client_rate_limiter_partition, || {
                            ClientRateLimiter::new(seconds_since_unix_epoch)
                        })
                    }
                    RetryPartitionInner::Custom {
                        client_rate_limiter,
                        ..
                    } => client_rate_limiter.clone(),
                };
                return Some(client_rate_limiter);
            }
        }
        None
    }

    fn calculate_backoff(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
        retry_cfg: &RetryConfig,
        retry_reason: &RetryAction,
    ) -> Result<Duration, ShouldAttempt> {
        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        match retry_reason {
            RetryAction::RetryIndicated(RetryReason::RetryableError { kind, retry_after }) => {
                if let Some(delay) = *retry_after {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("explicit request from server to delay {delay:?} before retrying");
                    Ok(delay)
                } else if let Some(delay) =
                    check_rate_limiter_for_delay(runtime_components, cfg, *kind)
                {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("rate limiter has requested a {delay:?} delay before retrying");
                    Ok(delay)
                } else {
                    let base = if retry_cfg.use_static_exponential_base() {
                        1.0
                    } else {
                        fastrand::f64()
                    };
                    Ok(calculate_exponential_backoff(
                        // Generate a random base multiplier to create jitter
                        base,
                        // Get the backoff time multiplier in seconds (with fractional seconds)
                        retry_cfg.initial_backoff().as_secs_f64(),
                        // `request_attempts` counts every request made, including the initial attempt.
                        // The initial attempt shouldn't count towards backoff calculations, so we subtract it.
                        request_attempts - 1,
                        // Maximum backoff duration as a fallback to prevent overflow when calculating a power
                        retry_cfg.max_backoff(),
                    ))
                }
            }
            RetryAction::RetryForbidden | RetryAction::NoActionIndicated => {
                debug!(
                    attempts = request_attempts,
                    max_attempts = retry_cfg.max_attempts(),
                    "encountered un-retryable error"
                );
                Err(ShouldAttempt::No)
            }
            _ => unreachable!("RetryAction is non-exhaustive"),
        }
    }
}

enum ReleaseResult {
    APermitWasReleased,
    NoPermitWasReleased,
}

impl RetryStrategy for StandardRetryStrategy {
    fn should_attempt_initial_request(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        if let Some(crl) = Self::adaptive_retry_rate_limiter(runtime_components, cfg) {
            let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
            if let Err(delay) = crl.acquire_permission_to_send_a_request(
                seconds_since_unix_epoch,
                RequestReason::InitialRequest,
            ) {
                return Ok(ShouldAttempt::YesAfterDelay(delay));
            }
        } else {
            debug!("no client rate limiter configured, so no token is required for the initial request.");
        }

        Ok(ShouldAttempt::Yes)
    }

    fn should_attempt_retry(
        &self,
        ctx: &InterceptorContext,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        let retry_cfg = cfg.load::<RetryConfig>().expect("retry config is required");

        // bookkeeping
        let token_bucket = cfg.load::<TokenBucket>().expect("token bucket is required");
        // run the classifier against the context to determine if we should retry
        let retry_classifiers = runtime_components.retry_classifiers();
        let classifier_result = run_classifiers_on_ctx(retry_classifiers, ctx);

        // (adaptive only): update fill rate
        // NOTE: The SEP says to do bookkeeping before asking whether we should retry. However, adaptive
        // retry bookkeeping needs to know whether the error was a throttling error, so we take advantage
        // of that information being available via the classifier result.
        let error_kind = error_kind(&classifier_result);
        let is_throttling_error = error_kind
            .map(|kind| kind == ErrorKind::ThrottlingError)
            .unwrap_or(false);
        update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error);

        // on success release any retry quota held by previous attempts, reward success when indicated
        if !ctx.is_failed() {
            self.release_retry_permit(token_bucket);
        }
        // end bookkeeping

        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        // check if retry should be attempted
        if !classifier_result.should_retry() {
            debug!(
                "attempt #{request_attempts} classified as {:?}, not retrying",
                classifier_result
            );
            return Ok(ShouldAttempt::No);
        }

        // check if we're out of attempts
        if request_attempts >= retry_cfg.max_attempts() {
            debug!(
                attempts = request_attempts,
                max_attempts = retry_cfg.max_attempts(),
                "not retrying because we are out of attempts"
            );
            return Ok(ShouldAttempt::No);
        }

        // acquire permit for retry
        let error_kind = error_kind.expect("result was classified retryable");
        match token_bucket.acquire(&error_kind) {
            Some(permit) => self.set_retry_permit(permit),
            None => {
                debug!("attempt #{request_attempts} failed with {error_kind:?}; however, not enough retry quota is available for another attempt, so no retry will be attempted.");
                return Ok(ShouldAttempt::No);
            }
        }

        // calculate delay until next attempt
        let backoff =
            match self.calculate_backoff(runtime_components, cfg, retry_cfg, &classifier_result) {
                Ok(value) => value,
                // In some cases, backoff calculation will decide that we shouldn't retry at all.
                Err(value) => return Ok(value),
            };

        debug!(
            "attempt #{request_attempts} failed with {:?}; retrying after {:?}",
            classifier_result, backoff
        );
        Ok(ShouldAttempt::YesAfterDelay(backoff))
    }
}

/// extract the error kind from the classifier result if available
fn error_kind(classifier_result: &RetryAction) -> Option<ErrorKind> {
    match classifier_result {
        RetryAction::RetryIndicated(RetryReason::RetryableError { kind, .. }) => Some(*kind),
        _ => None,
    }
}

fn update_rate_limiter_if_exists(
    runtime_components: &RuntimeComponents,
    cfg: &ConfigBag,
    is_throttling_error: bool,
) {
    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
        let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
        crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error);
    }
}

fn check_rate_limiter_for_delay(
    runtime_components: &RuntimeComponents,
    cfg: &ConfigBag,
    kind: ErrorKind,
) -> Option<Duration> {
    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
        let retry_reason = if kind == ErrorKind::ThrottlingError {
            RequestReason::RetryTimeout
        } else {
            RequestReason::Retry
        };
        if let Err(delay) = crl.acquire_permission_to_send_a_request(
            get_seconds_since_unix_epoch(runtime_components),
            retry_reason,
        ) {
            return Some(delay);
        }
    }

    None
}

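/// Computes the delay before the next attempt as `base * initial_backoff * 2^retry_attempts`,
/// clamped to `max_backoff` (which is also used as a fallback if the power or the resulting
/// `Duration` would overflow).
///
/// Worked example: with a static base of 1.0 and a 1-second initial backoff, retry attempts
/// 0, 1, 2, and 3 yield delays of 1s, 2s, 4s, and 8s respectively.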
pub(super) fn calculate_exponential_backoff(
    base: f64,
    initial_backoff: f64,
    retry_attempts: u32,
    max_backoff: Duration,
) -> Duration {
    let result = match 2_u32
        .checked_pow(retry_attempts)
        .map(|power| (power as f64) * initial_backoff)
    {
        Some(backoff) => match Duration::try_from_secs_f64(backoff) {
            Ok(result) => result.min(max_backoff),
            Err(e) => {
                tracing::warn!("falling back to {max_backoff:?} as `Duration` could not be created for exponential backoff: {e}");
                max_backoff
            }
        },
        None => max_backoff,
    };

    // Apply the jitter multiplier `base` to `result`; note that this can also scale `max_backoff`
    // when it was used as the fallback. This won't panic because `base` is either in the range
    // 0..1 or a constant 1.0 when a static exponential base is configured (e.g., in tests).
    result.mul_f64(base)
}

pub(super) fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 {
    let request_time = runtime_components
        .time_source()
        .expect("time source required for retries");
    request_time
        .now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .as_secs_f64()
}

/// Interceptor registered by the default retry plugin that ensures a token bucket exists in the
/// config bag for every operation. The token bucket provided is partitioned by the retry partition
/// **in the config bag** at the time an operation is executed.
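///
/// Buckets for default partitions are cached in a process-wide [`StaticPartitionMap`] keyed by
/// [`RetryPartition`], so clients that share a retry partition also share a token bucket (and
/// therefore retry quota), while custom partitions carry their own bucket.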
#[derive(Debug)]
pub(crate) struct TokenBucketProvider {
    default_partition: RetryPartition,
    token_bucket: TokenBucket,
}

impl TokenBucketProvider {
    /// Create a new token bucket provider with the given default retry partition.
    ///
    /// NOTE: This partition should be the one used for every operation on a client
    /// unless config is overridden.
    pub(crate) fn new(default_partition: RetryPartition, time_source: SharedTimeSource) -> Self {
        let token_bucket = TOKEN_BUCKET.get_or_init(default_partition.clone(), || {
            let mut tb = TokenBucket::default();
            tb.update_time_source(time_source);
            tb
        });
        Self {
            default_partition,
            token_bucket,
        }
    }
}

impl Intercept for TokenBucketProvider {
    fn name(&self) -> &'static str {
        "TokenBucketProvider"
    }

    fn modify_before_retry_loop(
        &self,
        _context: &mut BeforeTransmitInterceptorContextMut<'_>,
        runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");

        let tb = match &retry_partition.inner {
            RetryPartitionInner::Default(name) => {
                // we store the retry partition configured on the client at creation time, along with
                // its associated token bucket, so that we can avoid locking on _every_ request from
                // _every_ client
                if name == self.default_partition.name() {
                    // avoid contention on the global lock
                    self.token_bucket.clone()
                } else {
                    TOKEN_BUCKET.get_or_init(retry_partition.clone(), || {
                        let mut tb = TokenBucket::default();
                        tb.update_time_source(runtime_components.time_source().unwrap_or_default());
                        tb
                    })
                }
            }
            RetryPartitionInner::Custom { token_bucket, .. } => token_bucket.clone(),
        };

        trace!("token bucket for {retry_partition:?} added to config bag");
        let mut layer = Layer::new("token_bucket_partition");
        layer.store_put(tb);
        cfg.push_layer(layer);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[allow(unused_imports)] // will be unused with `--no-default-features --features client`
    use std::fmt;
    use std::sync::Mutex;
    use std::time::Duration;

    use aws_smithy_async::time::SystemTimeSource;
    use aws_smithy_runtime_api::client::interceptors::context::{
        Input, InterceptorContext, Output,
    };
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_runtime_api::client::retries::{
        AlwaysRetry, RequestAttempts, RetryStrategy, ShouldAttempt,
    };
    use aws_smithy_runtime_api::client::runtime_components::{
        RuntimeComponents, RuntimeComponentsBuilder,
    };
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::retry::{ErrorKind, RetryConfig};

    use super::{calculate_exponential_backoff, StandardRetryStrategy};
    use crate::client::retries::{ClientRateLimiter, RetryPartition, TokenBucket};

    #[test]
    fn no_retry_necessary_for_ok_result() {
        let cfg = ConfigBag::of_layers(vec![{
            let mut layer = Layer::new("test");
            layer.store_put(RetryConfig::standard());
            layer.store_put(RequestAttempts::new(1));
            layer.store_put(TokenBucket::default());
            layer
        }]);
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::default();
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }

    fn set_up_cfg_and_context(
        error_kind: ErrorKind,
        current_request_attempts: u32,
        retry_config: RetryConfig,
    ) -> (InterceptorContext, RuntimeComponents, ConfigBag) {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(AlwaysRetry(error_kind)))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(RequestAttempts::new(current_request_attempts));
        layer.store_put(retry_config);
        layer.store_put(TokenBucket::default());
        let cfg = ConfigBag::of_layers(vec![layer]);

        (ctx, rc, cfg)
    }

    // Test that error kinds produce the correct "retry after X seconds" output.
    // All error kinds are handled in the same way for the standard strategy.
    fn test_should_retry_error_kind(error_kind: ErrorKind) {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            error_kind,
            3,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(4),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
    }

    #[test]
    fn should_retry_transient_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::TransientError);
    }

    #[test]
    fn should_retry_client_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ClientError);
    }

    #[test]
    fn should_retry_server_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ServerError);
    }

    #[test]
    fn should_retry_throttling_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ThrottlingError);
    }

    #[test]
    fn dont_retry_when_out_of_attempts() {
        let current_attempts = 4;
        let max_attempts = current_attempts;
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            current_attempts,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(max_attempts),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }

    #[test]
    fn should_not_panic_when_exponential_backoff_duration_could_not_be_created() {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            // After subtracting 1 in `calculate_backoff`, the exponent is 32, which overflows `2_u32.checked_pow` in `calculate_exponential_backoff`
            33,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(100), // Any value greater than 33 will do
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(MAX_BACKOFF), actual);
    }

    #[test]
    fn should_yield_client_rate_limiter_from_custom_partition() {
        let expected = ClientRateLimiter::builder().token_refill_rate(3.14).build();
        let cfg = ConfigBag::of_layers(vec![
            // Emulate a default config layer overridden by a user config layer
            {
                let mut layer = Layer::new("default");
                layer.store_put(RetryPartition::new("default"));
                layer
            },
            {
                let mut layer = Layer::new("user");
                layer.store_put(RetryConfig::adaptive());
                layer.store_put(
                    RetryPartition::custom("user")
                        .client_rate_limiter(expected.clone())
                        .build(),
                );
                layer
            },
        ]);
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();
        let actual = StandardRetryStrategy::adaptive_retry_rate_limiter(&rc, &cfg)
            .expect("should yield client rate limiter from custom partition");
        assert!(std::sync::Arc::ptr_eq(&expected.inner, &actual.inner));
    }

    #[allow(dead_code)] // will be unused with `--no-default-features --features client`
    #[derive(Debug)]
    struct PresetReasonRetryClassifier {
        retry_actions: Mutex<Vec<RetryAction>>,
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    impl PresetReasonRetryClassifier {
        fn new(mut retry_reasons: Vec<RetryAction>) -> Self {
            // We'll pop the retry_reasons in reverse order, so we reverse the list to fix that.
            retry_reasons.reverse();
            Self {
                retry_actions: Mutex::new(retry_reasons),
            }
        }
    }

    impl ClassifyRetry for PresetReasonRetryClassifier {
        fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
            // Check for a result
            let output_or_error = ctx.output_or_error();
            // Check for an error
            match output_or_error {
                Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
                _ => (),
            };

            let mut retry_actions = self.retry_actions.lock().unwrap();
            if retry_actions.len() == 1 {
                retry_actions.first().unwrap().clone()
            } else {
                retry_actions.pop().unwrap()
            }
        }

        fn name(&self) -> &'static str {
            "Always returns a preset retry reason"
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    fn setup_test(
        retry_reasons: Vec<RetryAction>,
        retry_config: RetryConfig,
    ) -> (ConfigBag, RuntimeComponents, InterceptorContext) {
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(
                PresetReasonRetryClassifier::new(retry_reasons),
            ))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        // This type doesn't matter b/c the classifier will just return whatever we tell it to.
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));

        (cfg, rc, ctx)
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn eventual_success() {
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 495);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_more_attempts() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(3),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 490);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn successful_request_and_deser_should_be_retryable() {
        #[derive(Clone, Copy, Debug)]
        enum LongRunningOperationStatus {
            Running,
            Complete,
        }

        #[derive(Debug)]
        struct LongRunningOperationOutput {
            status: Option<LongRunningOperationStatus>,
        }

        impl LongRunningOperationOutput {
            fn status(&self) -> Option<LongRunningOperationStatus> {
                self.status
            }
        }

        struct WaiterRetryClassifier {}

        impl WaiterRetryClassifier {
            fn new() -> Self {
                WaiterRetryClassifier {}
            }
        }

        impl fmt::Debug for WaiterRetryClassifier {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "WaiterRetryClassifier")
            }
        }
        impl ClassifyRetry for WaiterRetryClassifier {
            fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
                let status: Option<LongRunningOperationStatus> =
                    ctx.output_or_error().and_then(|res| {
                        res.ok().and_then(|output| {
                            output
                                .downcast_ref::<LongRunningOperationOutput>()
                                .and_then(|output| output.status())
                        })
                    });

                if let Some(LongRunningOperationStatus::Running) = status {
                    return RetryAction::server_error();
                };

                RetryAction::NoActionIndicated
            }

            fn name(&self) -> &'static str {
                "waiter retry classifier"
            }
        }

        let retry_config = RetryConfig::standard()
            .with_use_static_exponential_base(true)
            .with_max_attempts(5);

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(WaiterRetryClassifier::new()))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let mut cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::new();

        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Running),
        })));

        cfg.interceptor_state().store_put(TokenBucket::new(5));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 0);

        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Complete),
        })));
        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        should_retry.expect_no();
        assert_eq!(token_bucket.available_permits(), 5);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_quota() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::new(5));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 0);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 0);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn quota_replenishes_on_success() {
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![
                RetryAction::transient_error(),
                RetryAction::retryable_error_with_explicit_delay(
                    ErrorKind::TransientError,
                    Duration::from_secs(1),
                ),
            ],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::new(100));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 90);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 80);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);

        assert_eq!(token_bucket.available_permits(), 90);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn quota_replenishes_on_first_try_success() {
        const PERMIT_COUNT: usize = 20;
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::transient_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(u32::MAX),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state()
            .store_put(TokenBucket::new(PERMIT_COUNT));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        let mut attempt = 1;

        // Drain all available permits with failed attempts
        while token_bucket.available_permits() > 0 {
            // Draining should complete in 2 attempts
            if attempt > 2 {
                panic!("This test should have completed by now (drain)");
            }

            cfg.interceptor_state()
                .store_put(RequestAttempts::new(attempt));
            let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
            assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
            attempt += 1;
        }

        // Forget the permit so that we can only refill by "success on first try".
        let permit = strategy.retry_permit.lock().unwrap().take().unwrap();
        permit.forget();

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        // Replenish permits until we get back to `PERMIT_COUNT`
        while token_bucket.available_permits() < PERMIT_COUNT {
            if attempt > 23 {
                panic!("This test should have completed by now (fill-up)");
            }

            cfg.interceptor_state()
                .store_put(RequestAttempts::new(attempt));
            let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
            assert_eq!(no_retry, ShouldAttempt::No);
            attempt += 1;
        }

        assert_eq!(attempt, 23);
        assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn backoff_timing() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(4));
        assert_eq!(token_bucket.available_permits(), 485);

        cfg.interceptor_state().store_put(RequestAttempts::new(4));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(8));
        assert_eq!(token_bucket.available_permits(), 480);

        cfg.interceptor_state().store_put(RequestAttempts::new(5));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 480);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn max_backoff_time() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5)
                .with_initial_backoff(Duration::from_secs(1))
                .with_max_backoff(Duration::from_secs(3)),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(token_bucket.available_permits(), 485);

        cfg.interceptor_state().store_put(RequestAttempts::new(4));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(token_bucket.available_permits(), 480);

        cfg.interceptor_state().store_put(RequestAttempts::new(5));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 480);
    }

    const MAX_BACKOFF: Duration = Duration::from_secs(20);

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_one() {
        let initial_backoff = 1.0;

        for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
        let initial_backoff = 3.0;

        for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
        let initial_backoff = 0.03;

        for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_backoff_overflow_should_gracefully_fallback_to_max_backoff() {
        // avoid overflow for a silly large number of retry attempts
        assert_eq!(
            MAX_BACKOFF,
            calculate_exponential_backoff(1_f64, 10_f64, 100000, MAX_BACKOFF),
        );
    }
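
    // Illustrative sketch: the `base` argument acts as a jitter multiplier on the computed
    // exponential delay, so a fractional base scales the delay down proportionally.
    #[test]
    fn calculate_exponential_backoff_applies_jitter_base() {
        // 0.5 * (2^2 * 1.0) = 2 seconds
        assert_eq!(
            Duration::from_secs_f64(2.0),
            calculate_exponential_backoff(0.5, 1.0, 2, MAX_BACKOFF),
        );
    }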
}