aws_smithy_runtime/client/retries/strategy/standard.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::sync::Mutex;
use std::time::{Duration, SystemTime};

use tokio::sync::OwnedSemaphorePermit;
use tracing::{debug, trace};

use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
    BeforeTransmitInterceptorContextMut, InterceptorContext,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::retries::classifiers::{RetryAction, RetryReason};
use aws_smithy_runtime_api::client::retries::{RequestAttempts, RetryStrategy, ShouldAttempt};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
use aws_smithy_types::retry::{ErrorKind, RetryConfig, RetryMode};

use crate::client::retries::classifiers::run_classifiers_on_ctx;
use crate::client::retries::client_rate_limiter::{ClientRateLimiter, RequestReason};
use crate::client::retries::strategy::standard::ReleaseResult::{
    APermitWasReleased, NoPermitWasReleased,
};
use crate::client::retries::token_bucket::TokenBucket;
use crate::client::retries::{ClientRateLimiterPartition, RetryPartition, RetryPartitionInner};
use crate::static_partition_map::StaticPartitionMap;

static CLIENT_RATE_LIMITER: StaticPartitionMap<ClientRateLimiterPartition, ClientRateLimiter> =
    StaticPartitionMap::new();

/// Used by the token bucket interceptor to ensure a `TokenBucket` always exists in the config bag.
static TOKEN_BUCKET: StaticPartitionMap<RetryPartition, TokenBucket> = StaticPartitionMap::new();

/// Retry strategy with exponential backoff, max attempts, and a token bucket.
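///
/// On each failed attempt, the strategy consults the configured retry classifiers, attempts to
/// acquire a permit from the retry partition's token bucket (giving up if none is available), and
/// schedules the next attempt using either a server-provided `retry_after` hint, a delay requested
/// by the adaptive client rate limiter, or exponential backoff with jitter.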
#[derive(Debug, Default)]
pub struct StandardRetryStrategy {
    retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
}

impl Storable for StandardRetryStrategy {
    type Storer = StoreReplace<Self>;
}

impl StandardRetryStrategy {
    /// Create a new standard retry strategy.
    pub fn new() -> Self {
        Default::default()
    }
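    /// Handles bookkeeping for a successful response: any retry permit held from a previous
    /// failed attempt is either forgotten (when the bucket rewards success) or released back to
    /// the bucket, and a first-attempt success either rewards the bucket or regenerates a token.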
    fn release_retry_permit(&self, token_bucket: &TokenBucket) -> ReleaseResult {
        let mut retry_permit = self.retry_permit.lock().unwrap();
        match retry_permit.take() {
            Some(p) => {
                // Retry succeeded: if a success reward is configured, reward the bucket and
                // forget the permit; otherwise release the permit back to the bucket
                if token_bucket.success_reward() > 0.0 {
                    token_bucket.reward_success();
                    p.forget();
                } else {
                    drop(p); // Original behavior - release back to bucket
                }
                APermitWasReleased
            }
            None => {
                // First-attempt success: reward success or regenerate token
                if token_bucket.success_reward() > 0.0 {
                    token_bucket.reward_success();
                } else {
                    token_bucket.regenerate_a_token();
                }
                NoPermitWasReleased
            }
        }
    }

    fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
        let mut old_retry_permit = self.retry_permit.lock().unwrap();
        if let Some(p) = old_retry_permit.replace(new_retry_permit) {
            // Whenever we set a new retry permit, and it replaces the old one, we need to "forget"
            // the old permit, removing it from the bucket forever.
            p.forget()
        }
    }

    /// Returns a [`ClientRateLimiter`] if adaptive retry is configured.
    fn adaptive_retry_rate_limiter(
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Option<ClientRateLimiter> {
        let retry_config = cfg.load::<RetryConfig>().expect("retry config is required");
        if retry_config.mode() == RetryMode::Adaptive {
            if let Some(time_source) = runtime_components.time_source() {
                let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");
                let seconds_since_unix_epoch = time_source
                    .now()
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .expect("the present takes place after the UNIX_EPOCH")
                    .as_secs_f64();
                let client_rate_limiter = match &retry_partition.inner {
                    RetryPartitionInner::Default(_) => {
                        let client_rate_limiter_partition =
                            ClientRateLimiterPartition::new(retry_partition.clone());
                        CLIENT_RATE_LIMITER.get_or_init(client_rate_limiter_partition, || {
                            ClientRateLimiter::new(seconds_since_unix_epoch)
                        })
                    }
                    RetryPartitionInner::Custom {
                        client_rate_limiter,
                        ..
                    } => client_rate_limiter.clone(),
                };
                return Some(client_rate_limiter);
            }
        }
        None
    }
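    /// Determines the delay before the next attempt: an explicit server-provided `retry_after`
    /// hint takes precedence, followed by any delay requested by the adaptive client rate
    /// limiter, and finally exponential backoff with jitter. Returns `Err(ShouldAttempt::No)`
    /// when the classifier result indicates that no retry should be attempted.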
    fn calculate_backoff(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
        retry_cfg: &RetryConfig,
        retry_reason: &RetryAction,
    ) -> Result<Duration, ShouldAttempt> {
        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        match retry_reason {
            RetryAction::RetryIndicated(RetryReason::RetryableError { kind, retry_after }) => {
                if let Some(delay) = *retry_after {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("explicit request from server to delay {delay:?} before retrying");
                    Ok(delay)
                } else if let Some(delay) =
                    check_rate_limiter_for_delay(runtime_components, cfg, *kind)
                {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("rate limiter has requested a {delay:?} delay before retrying");
                    Ok(delay)
                } else {
                    let base = if retry_cfg.use_static_exponential_base() {
                        1.0
                    } else {
                        fastrand::f64()
                    };
                    Ok(calculate_exponential_backoff(
                        // Generate a random base multiplier to create jitter
                        base,
                        // Get the backoff time multiplier in seconds (with fractional seconds)
                        retry_cfg.initial_backoff().as_secs_f64(),
                        // `request_attempts` tracks the number of requests made, including the initial request.
                        // The initial attempt shouldn't count towards backoff calculations, so we subtract it.
                        request_attempts - 1,
                        // Maximum backoff duration as a fallback to prevent overflow when calculating a power
                        retry_cfg.max_backoff(),
                    ))
                }
            }
            RetryAction::RetryForbidden | RetryAction::NoActionIndicated => {
                debug!(
                    attempts = request_attempts,
                    max_attempts = retry_cfg.max_attempts(),
                    "encountered un-retryable error"
                );
                Err(ShouldAttempt::No)
            }
            _ => unreachable!("RetryAction is non-exhaustive"),
        }
    }
}

enum ReleaseResult {
    APermitWasReleased,
    NoPermitWasReleased,
}

impl RetryStrategy for StandardRetryStrategy {
    fn should_attempt_initial_request(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        if let Some(crl) = Self::adaptive_retry_rate_limiter(runtime_components, cfg) {
            let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
            if let Err(delay) = crl.acquire_permission_to_send_a_request(
                seconds_since_unix_epoch,
                RequestReason::InitialRequest,
            ) {
                return Ok(ShouldAttempt::YesAfterDelay(delay));
            }
        } else {
            debug!("no client rate limiter configured, so no token is required for the initial request.");
        }

        Ok(ShouldAttempt::Yes)
    }

    fn should_attempt_retry(
        &self,
        ctx: &InterceptorContext,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        let retry_cfg = cfg.load::<RetryConfig>().expect("retry config is required");

        // bookkeeping
        let token_bucket = cfg.load::<TokenBucket>().expect("token bucket is required");
        // run the classifier against the context to determine if we should retry
        let retry_classifiers = runtime_components.retry_classifiers();
        let classifier_result = run_classifiers_on_ctx(retry_classifiers, ctx);

        // (adaptive only): update the fill rate
        // NOTE: The SEP indicates that bookkeeping should happen before deciding whether to retry.
        // However, adaptive retry bookkeeping needs to know whether the error was a throttling
        // error, so we take advantage of that information being available via the classifier result.
        let error_kind = error_kind(&classifier_result);
        let is_throttling_error = error_kind
            .map(|kind| kind == ErrorKind::ThrottlingError)
            .unwrap_or(false);
        update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error);

        // on success release any retry quota held by previous attempts, reward success when indicated
        if !ctx.is_failed() {
            self.release_retry_permit(token_bucket);
        }
        // end bookkeeping

        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        // check if retry should be attempted
        if !classifier_result.should_retry() {
            debug!(
                "attempt #{request_attempts} classified as {:?}, not retrying",
                classifier_result
            );
            return Ok(ShouldAttempt::No);
        }

        // check if we're out of attempts
        if request_attempts >= retry_cfg.max_attempts() {
            debug!(
                attempts = request_attempts,
                max_attempts = retry_cfg.max_attempts(),
                "not retrying because we are out of attempts"
            );
            return Ok(ShouldAttempt::No);
        }

        // acquire a permit for the retry
        let error_kind = error_kind.expect("result was classified retryable");
        match token_bucket.acquire(&error_kind) {
            Some(permit) => self.set_retry_permit(permit),
            None => {
                debug!("attempt #{request_attempts} failed with {error_kind:?}; however, not enough retry quota is available for another attempt, so no retry will be attempted.");
                return Ok(ShouldAttempt::No);
            }
        }

        // calculate delay until next attempt
        let backoff =
            match self.calculate_backoff(runtime_components, cfg, retry_cfg, &classifier_result) {
                Ok(value) => value,
                // In some cases, backoff calculation will decide that we shouldn't retry at all.
                Err(value) => return Ok(value),
            };

        debug!(
            "attempt #{request_attempts} failed with {:?}; retrying after {:?}",
            classifier_result, backoff
        );
        Ok(ShouldAttempt::YesAfterDelay(backoff))
    }
}

/// Extract the error kind from the classifier result, if available.
fn error_kind(classifier_result: &RetryAction) -> Option<ErrorKind> {
    match classifier_result {
        RetryAction::RetryIndicated(RetryReason::RetryableError { kind, .. }) => Some(*kind),
        _ => None,
    }
}

fn update_rate_limiter_if_exists(
    runtime_components: &RuntimeComponents,
    cfg: &ConfigBag,
    is_throttling_error: bool,
) {
    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
        let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
        crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error);
    }
}

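/// If adaptive retry is configured, asks the client rate limiter for permission to send the next
/// request and returns any delay it requires. Throttling errors are reported to the rate limiter
/// as retry timeouts; other errors are reported as normal retries.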
fn check_rate_limiter_for_delay(
    runtime_components: &RuntimeComponents,
    cfg: &ConfigBag,
    kind: ErrorKind,
) -> Option<Duration> {
    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
        let retry_reason = if kind == ErrorKind::ThrottlingError {
            RequestReason::RetryTimeout
        } else {
            RequestReason::Retry
        };
        if let Err(delay) = crl.acquire_permission_to_send_a_request(
            get_seconds_since_unix_epoch(runtime_components),
            retry_reason,
        ) {
            return Some(delay);
        }
    }

    None
}

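/// Calculates the exponential backoff delay as
/// `min(initial_backoff * 2^retry_attempts, max_backoff) * base`, where `base` is a jitter
/// multiplier in the range 0..1 (or a constant 1.0 when a static exponential base is configured).
/// For example, with `initial_backoff = 1.0`, `base = 1.0`, and a 20-second `max_backoff`,
/// retry attempts 0 through 4 produce delays of 1s, 2s, 4s, 8s, and 16s.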
pub(super) fn calculate_exponential_backoff(
    base: f64,
    initial_backoff: f64,
    retry_attempts: u32,
    max_backoff: Duration,
) -> Duration {
    let result = match 2_u32
        .checked_pow(retry_attempts)
        .map(|power| (power as f64) * initial_backoff)
    {
        Some(backoff) => match Duration::try_from_secs_f64(backoff) {
            Ok(result) => result.min(max_backoff),
            Err(e) => {
                tracing::warn!("falling back to {max_backoff:?} as `Duration` could not be created for exponential backoff: {e}");
                max_backoff
            }
        },
        None => max_backoff,
    };

    // Apply jitter to `result`; note that the jitter is also applied when `result` has been
    // clamped to `max_backoff`. This won't panic because `base` is either in the range 0..1 or a
    // constant 1.0 when a static exponential base is configured (as in tests).
    result.mul_f64(base)
}

pub(super) fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 {
    let request_time = runtime_components
        .time_source()
        .expect("time source required for retries");
    request_time
        .now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .as_secs_f64()
}

/// Interceptor registered by the default retry plugin to ensure a token bucket exists in the
/// config bag for every operation. The token bucket provided is partitioned by the retry
/// partition **in the config bag** at the time an operation is executed.
#[derive(Debug)]
pub(crate) struct TokenBucketProvider {
    default_partition: RetryPartition,
    token_bucket: TokenBucket,
}

impl TokenBucketProvider {
    /// Create a new token bucket provider with the given default retry partition.
    ///
    /// NOTE: This partition should be the one used for every operation on a client
    /// unless config is overridden.
    pub(crate) fn new(default_partition: RetryPartition) -> Self {
        let token_bucket = TOKEN_BUCKET.get_or_init_default(default_partition.clone());
        Self {
            default_partition,
            token_bucket,
        }
    }
}

impl Intercept for TokenBucketProvider {
    fn name(&self) -> &'static str {
        "TokenBucketProvider"
    }

    fn modify_before_retry_loop(
        &self,
        _context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");

        let tb = match &retry_partition.inner {
            RetryPartitionInner::Default(name) => {
                // We store the retry partition configured for the client (and its associated
                // token bucket) at construction time so that we can avoid locking on _every_
                // request from _every_ client.
                if name == self.default_partition.name() {
                    // avoid contention on the global lock
                    self.token_bucket.clone()
                } else {
                    TOKEN_BUCKET.get_or_init_default(retry_partition.clone())
                }
            }
            RetryPartitionInner::Custom { token_bucket, .. } => token_bucket.clone(),
        };

        trace!("token bucket for {retry_partition:?} added to config bag");
        let mut layer = Layer::new("token_bucket_partition");
        layer.store_put(tb);
        cfg.push_layer(layer);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[allow(unused_imports)] // will be unused with `--no-default-features --features client`
    use std::fmt;
    use std::sync::Mutex;
    use std::time::Duration;

    use aws_smithy_async::time::SystemTimeSource;
    use aws_smithy_runtime_api::client::interceptors::context::{
        Input, InterceptorContext, Output,
    };
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_runtime_api::client::retries::{
        AlwaysRetry, RequestAttempts, RetryStrategy, ShouldAttempt,
    };
    use aws_smithy_runtime_api::client::runtime_components::{
        RuntimeComponents, RuntimeComponentsBuilder,
    };
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::retry::{ErrorKind, RetryConfig};

    use super::{calculate_exponential_backoff, StandardRetryStrategy};
    use crate::client::retries::{ClientRateLimiter, RetryPartition, TokenBucket};

    #[test]
    fn no_retry_necessary_for_ok_result() {
        let cfg = ConfigBag::of_layers(vec![{
            let mut layer = Layer::new("test");
            layer.store_put(RetryConfig::standard());
            layer.store_put(RequestAttempts::new(1));
            layer.store_put(TokenBucket::default());
            layer
        }]);
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::default();
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }

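    // Builds a context whose response is an error that the classifier always marks with the given
    // error kind, plus runtime components and a config bag seeded with the given attempt count,
    // retry config, and a default token bucket.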
    fn set_up_cfg_and_context(
        error_kind: ErrorKind,
        current_request_attempts: u32,
        retry_config: RetryConfig,
    ) -> (InterceptorContext, RuntimeComponents, ConfigBag) {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(AlwaysRetry(error_kind)))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(RequestAttempts::new(current_request_attempts));
        layer.store_put(retry_config);
        layer.store_put(TokenBucket::default());
        let cfg = ConfigBag::of_layers(vec![layer]);

        (ctx, rc, cfg)
    }

    // Test that error kinds produce the correct "retry after X seconds" output.
    // All error kinds are handled in the same way for the standard strategy.
    fn test_should_retry_error_kind(error_kind: ErrorKind) {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            error_kind,
            3,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(4),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
    }

    #[test]
    fn should_retry_transient_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::TransientError);
    }

    #[test]
    fn should_retry_client_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ClientError);
    }

    #[test]
    fn should_retry_server_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ServerError);
    }

    #[test]
    fn should_retry_throttling_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ThrottlingError);
    }

    #[test]
    fn dont_retry_when_out_of_attempts() {
        let current_attempts = 4;
        let max_attempts = current_attempts;
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            current_attempts,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(max_attempts),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }

    #[test]
    fn should_not_panic_when_exponential_backoff_duration_could_not_be_created() {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            // 33 attempts become 32 retry attempts after subtracting 1 in `calculate_backoff`,
            // causing `2_u32.checked_pow` to overflow in `calculate_exponential_backoff`
            33,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(100), // Any value greater than 33 will do
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(MAX_BACKOFF), actual);
    }

    #[test]
    fn should_yield_client_rate_limiter_from_custom_partition() {
        let expected = ClientRateLimiter::builder().token_refill_rate(3.14).build();
        let cfg = ConfigBag::of_layers(vec![
            // Emulate the default config layer being overridden by a user config layer
            {
                let mut layer = Layer::new("default");
                layer.store_put(RetryPartition::new("default"));
                layer
            },
            {
                let mut layer = Layer::new("user");
                layer.store_put(RetryConfig::adaptive());
                layer.store_put(
                    RetryPartition::custom("user")
                        .client_rate_limiter(expected.clone())
                        .build(),
                );
                layer
            },
        ]);
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();
        let actual = StandardRetryStrategy::adaptive_retry_rate_limiter(&rc, &cfg)
            .expect("should yield client rate limiter from custom partition");
        assert!(std::sync::Arc::ptr_eq(&expected.inner, &actual.inner));
    }

    #[allow(dead_code)] // will be unused with `--no-default-features --features client`
    #[derive(Debug)]
    struct PresetReasonRetryClassifier {
        retry_actions: Mutex<Vec<RetryAction>>,
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    impl PresetReasonRetryClassifier {
        fn new(mut retry_reasons: Vec<RetryAction>) -> Self {
            // We'll pop the retry_reasons in reverse order, so we reverse the list to fix that.
            retry_reasons.reverse();
            Self {
                retry_actions: Mutex::new(retry_reasons),
            }
        }
    }

    impl ClassifyRetry for PresetReasonRetryClassifier {
        fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
            // Check for a result
            let output_or_error = ctx.output_or_error();
            // Check for an error
            match output_or_error {
                Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
                _ => (),
            };

            let mut retry_actions = self.retry_actions.lock().unwrap();
            if retry_actions.len() == 1 {
                retry_actions.first().unwrap().clone()
            } else {
                retry_actions.pop().unwrap()
            }
        }

        fn name(&self) -> &'static str {
            "Always returns a preset retry reason"
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    fn setup_test(
        retry_reasons: Vec<RetryAction>,
        retry_config: RetryConfig,
    ) -> (ConfigBag, RuntimeComponents, InterceptorContext) {
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(
                PresetReasonRetryClassifier::new(retry_reasons),
            ))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        // This type doesn't matter b/c the classifier will just return whatever we tell it to.
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));

        (cfg, rc, ctx)
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn eventual_success() {
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
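        // `TokenBucket::default()` starts with 500 permits; each server-error retry below costs 5
        // permits, and the most recently held permit is released back to the bucket on success.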

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 495);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_more_attempts() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(3),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 490);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn successful_request_and_deser_should_be_retryable() {
        #[derive(Clone, Copy, Debug)]
        enum LongRunningOperationStatus {
            Running,
            Complete,
        }

        #[derive(Debug)]
        struct LongRunningOperationOutput {
            status: Option<LongRunningOperationStatus>,
        }

        impl LongRunningOperationOutput {
            fn status(&self) -> Option<LongRunningOperationStatus> {
                self.status
            }
        }

        struct WaiterRetryClassifier {}

        impl WaiterRetryClassifier {
            fn new() -> Self {
                WaiterRetryClassifier {}
            }
        }

        impl fmt::Debug for WaiterRetryClassifier {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "WaiterRetryClassifier")
            }
        }
        impl ClassifyRetry for WaiterRetryClassifier {
            fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
                let status: Option<LongRunningOperationStatus> =
                    ctx.output_or_error().and_then(|res| {
                        res.ok().and_then(|output| {
                            output
                                .downcast_ref::<LongRunningOperationOutput>()
                                .and_then(|output| output.status())
                        })
                    });

                if let Some(LongRunningOperationStatus::Running) = status {
                    return RetryAction::server_error();
                };

                RetryAction::NoActionIndicated
            }

            fn name(&self) -> &'static str {
                "waiter retry classifier"
            }
        }

        let retry_config = RetryConfig::standard()
            .with_use_static_exponential_base(true)
            .with_max_attempts(5);

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(WaiterRetryClassifier::new()))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let mut cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::new();

        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Running),
        })));

        cfg.interceptor_state().store_put(TokenBucket::new(5));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 0);

        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Complete),
        })));
        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        should_retry.expect_no();
        assert_eq!(token_bucket.available_permits(), 5);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_quota() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::new(5));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
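        // `TokenBucket::new(5)` holds exactly one server-error retry's worth of quota (cost 5),
        // so after the first retry drains it, the next attempt cannot acquire a permit.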

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 0);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 0);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn quota_replenishes_on_success() {
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![
                RetryAction::transient_error(),
                RetryAction::retryable_error_with_explicit_delay(
                    ErrorKind::TransientError,
                    Duration::from_secs(1),
                ),
            ],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::new(100));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
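        // Transient-error retries cost 10 permits each (100 -> 90 -> 80); on success the most
        // recently held permit is released back, restoring the bucket to 90.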

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 90);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 80);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);

        assert_eq!(token_bucket.available_permits(), 90);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn quota_replenishes_on_first_try_success() {
        const PERMIT_COUNT: usize = 20;
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::transient_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(u32::MAX),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state()
            .store_put(TokenBucket::new(PERMIT_COUNT));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        let mut attempt = 1;

        // Drain all available permits with failed attempts
        while token_bucket.available_permits() > 0 {
            // Draining should complete in 2 attempts
            if attempt > 2 {
                panic!("This test should have completed by now (drain)");
            }

            cfg.interceptor_state()
                .store_put(RequestAttempts::new(attempt));
            let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
            assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
            attempt += 1;
        }

        // Forget the permit so that we can only refill by "success on first try".
        let permit = strategy.retry_permit.lock().unwrap().take().unwrap();
        permit.forget();

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        // Replenish permits until we get back to `PERMIT_COUNT`
        while token_bucket.available_permits() < PERMIT_COUNT {
            if attempt > 23 {
                panic!("This test should have completed by now (fill-up)");
            }

            cfg.interceptor_state()
                .store_put(RequestAttempts::new(attempt));
            let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
            assert_eq!(no_retry, ShouldAttempt::No);
            attempt += 1;
        }

        assert_eq!(attempt, 23);
        assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn backoff_timing() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(4));
        assert_eq!(token_bucket.available_permits(), 485);

        cfg.interceptor_state().store_put(RequestAttempts::new(4));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(8));
        assert_eq!(token_bucket.available_permits(), 480);

        cfg.interceptor_state().store_put(RequestAttempts::new(5));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 480);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn max_backoff_time() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5)
                .with_initial_backoff(Duration::from_secs(1))
                .with_max_backoff(Duration::from_secs(3)),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(token_bucket.available_permits(), 485);

        cfg.interceptor_state().store_put(RequestAttempts::new(4));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(3));
        assert_eq!(token_bucket.available_permits(), 480);

        cfg.interceptor_state().store_put(RequestAttempts::new(5));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 480);
    }

    const MAX_BACKOFF: Duration = Duration::from_secs(20);

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_one() {
        let initial_backoff = 1.0;

        for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
        let initial_backoff = 3.0;

        for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
        let initial_backoff = 0.03;

        for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() {
            let actual_backoff =
                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
        }
    }

    #[test]
    fn calculate_backoff_overflow_should_gracefully_fallback_to_max_backoff() {
        // avoid overflow for a silly large number of retry attempts
        assert_eq!(
            MAX_BACKOFF,
            calculate_exponential_backoff(1_f64, 10_f64, 100000, MAX_BACKOFF),
        );
    }
}