// aws_smithy_runtime/client/retries/strategy/standard.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use std::sync::Mutex;
7use std::time::{Duration, SystemTime};
8
9use tokio::sync::OwnedSemaphorePermit;
10use tracing::{debug, trace};
11
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14    BeforeTransmitInterceptorContextMut, InterceptorContext,
15};
16use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
17use aws_smithy_runtime_api::client::retries::classifiers::{RetryAction, RetryReason};
18use aws_smithy_runtime_api::client::retries::{RequestAttempts, RetryStrategy, ShouldAttempt};
19use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
20use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
21use aws_smithy_types::retry::{ErrorKind, RetryConfig, RetryMode};
22
23use crate::client::retries::classifiers::run_classifiers_on_ctx;
24use crate::client::retries::client_rate_limiter::{ClientRateLimiter, RequestReason};
25use crate::client::retries::strategy::standard::ReleaseResult::{
26    APermitWasReleased, NoPermitWasReleased,
27};
28use crate::client::retries::token_bucket::TokenBucket;
29use crate::client::retries::{ClientRateLimiterPartition, RetryPartition, RetryPartitionInner};
30use crate::static_partition_map::StaticPartitionMap;
31
/// Global, partition-keyed map of client rate limiters used by adaptive retry, so
/// clients that share a (default) retry partition also share throttling state.
static CLIENT_RATE_LIMITER: StaticPartitionMap<ClientRateLimiterPartition, ClientRateLimiter> =
    StaticPartitionMap::new();

/// Used by token bucket interceptor to ensure a TokenBucket always exists in config bag
static TOKEN_BUCKET: StaticPartitionMap<RetryPartition, TokenBucket> = StaticPartitionMap::new();
37
/// Retry strategy with exponential backoff, max attempts, and a token bucket.
///
/// Holds at most one token-bucket permit at a time — the quota reserved for the
/// in-flight retry — behind a `Mutex` since the strategy is used through `&self`.
#[derive(Debug, Default)]
pub struct StandardRetryStrategy {
    // Permit acquired from the token bucket for the current retry attempt, if any.
    retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
}
43
/// Allows a `StandardRetryStrategy` to be stored in a `ConfigBag`; storing a new
/// value replaces any previously stored one.
impl Storable for StandardRetryStrategy {
    type Storer = StoreReplace<Self>;
}
47
48impl StandardRetryStrategy {
49    /// Create a new standard retry strategy with the given config.
50    pub fn new() -> Self {
51        Default::default()
52    }
53
54    fn release_retry_permit(&self, token_bucket: &TokenBucket) -> ReleaseResult {
55        let mut retry_permit = self.retry_permit.lock().unwrap();
56        match retry_permit.take() {
57            Some(p) => {
58                // Retry succeeded: reward success and forget permit if configured, otherwise release permit back
59                if token_bucket.success_reward() > 0.0 {
60                    token_bucket.reward_success();
61                    p.forget();
62                } else {
63                    drop(p); // Original behavior - release back to bucket
64                }
65                APermitWasReleased
66            }
67            None => {
68                // First-attempt success: reward success or regenerate token
69                if token_bucket.success_reward() > 0.0 {
70                    token_bucket.reward_success();
71                } else {
72                    token_bucket.regenerate_a_token();
73                }
74                NoPermitWasReleased
75            }
76        }
77    }
78
79    fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
80        let mut old_retry_permit = self.retry_permit.lock().unwrap();
81        if let Some(p) = old_retry_permit.replace(new_retry_permit) {
82            // Whenever we set a new retry permit, and it replaces the old one, we need to "forget"
83            // the old permit, removing it from the bucket forever.
84            p.forget()
85        }
86    }
87
88    /// Returns a [`ClientRateLimiter`] if adaptive retry is configured.
89    fn adaptive_retry_rate_limiter(
90        runtime_components: &RuntimeComponents,
91        cfg: &ConfigBag,
92    ) -> Option<ClientRateLimiter> {
93        let retry_config = cfg.load::<RetryConfig>().expect("retry config is required");
94        if retry_config.mode() == RetryMode::Adaptive {
95            if let Some(time_source) = runtime_components.time_source() {
96                let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");
97                let seconds_since_unix_epoch = time_source
98                    .now()
99                    .duration_since(SystemTime::UNIX_EPOCH)
100                    .expect("the present takes place after the UNIX_EPOCH")
101                    .as_secs_f64();
102                let client_rate_limiter = match &retry_partition.inner {
103                    RetryPartitionInner::Default(_) => {
104                        let client_rate_limiter_partition =
105                            ClientRateLimiterPartition::new(retry_partition.clone());
106                        CLIENT_RATE_LIMITER.get_or_init(client_rate_limiter_partition, || {
107                            ClientRateLimiter::new(seconds_since_unix_epoch)
108                        })
109                    }
110                    RetryPartitionInner::Custom {
111                        client_rate_limiter,
112                        ..
113                    } => client_rate_limiter.clone(),
114                };
115                return Some(client_rate_limiter);
116            }
117        }
118        None
119    }
120
121    fn calculate_backoff(
122        &self,
123        runtime_components: &RuntimeComponents,
124        cfg: &ConfigBag,
125        retry_cfg: &RetryConfig,
126        retry_reason: &RetryAction,
127    ) -> Result<Duration, ShouldAttempt> {
128        let request_attempts = cfg
129            .load::<RequestAttempts>()
130            .expect("at least one request attempt is made before any retry is attempted")
131            .attempts();
132
133        match retry_reason {
134            RetryAction::RetryIndicated(RetryReason::RetryableError { kind, retry_after }) => {
135                if let Some(delay) = *retry_after {
136                    let delay = delay.min(retry_cfg.max_backoff());
137                    debug!("explicit request from server to delay {delay:?} before retrying");
138                    Ok(delay)
139                } else if let Some(delay) =
140                    check_rate_limiter_for_delay(runtime_components, cfg, *kind)
141                {
142                    let delay = delay.min(retry_cfg.max_backoff());
143                    debug!("rate limiter has requested a {delay:?} delay before retrying");
144                    Ok(delay)
145                } else {
146                    let base = if retry_cfg.use_static_exponential_base() {
147                        1.0
148                    } else {
149                        fastrand::f64()
150                    };
151                    Ok(calculate_exponential_backoff(
152                        // Generate a random base multiplier to create jitter
153                        base,
154                        // Get the backoff time multiplier in seconds (with fractional seconds)
155                        retry_cfg.initial_backoff().as_secs_f64(),
156                        // `self.local.attempts` tracks number of requests made including the initial request
157                        // The initial attempt shouldn't count towards backoff calculations, so we subtract it
158                        request_attempts - 1,
159                        // Maximum backoff duration as a fallback to prevent overflow when calculating a power
160                        retry_cfg.max_backoff(),
161                    ))
162                }
163            }
164            RetryAction::RetryForbidden | RetryAction::NoActionIndicated => {
165                debug!(
166                    attempts = request_attempts,
167                    max_attempts = retry_cfg.max_attempts(),
168                    "encountered un-retryable error"
169                );
170                Err(ShouldAttempt::No)
171            }
172            _ => unreachable!("RetryAction is non-exhaustive"),
173        }
174    }
175}
176
/// Outcome of `StandardRetryStrategy::release_retry_permit`: whether a held retry
/// permit was settled, or no permit was being held (first-attempt success).
enum ReleaseResult {
    APermitWasReleased,
    NoPermitWasReleased,
}
181
impl RetryStrategy for StandardRetryStrategy {
    /// Decides whether the very first attempt may proceed.
    ///
    /// Only the adaptive client rate limiter can delay the initial request; the
    /// token bucket is not consulted until a retry is considered.
    fn should_attempt_initial_request(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        if let Some(crl) = Self::adaptive_retry_rate_limiter(runtime_components, cfg) {
            let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
            if let Err(delay) = crl.acquire_permission_to_send_a_request(
                seconds_since_unix_epoch,
                RequestReason::InitialRequest,
            ) {
                return Ok(ShouldAttempt::YesAfterDelay(delay));
            }
        } else {
            debug!("no client rate limiter configured, so no token is required for the initial request.");
        }

        Ok(ShouldAttempt::Yes)
    }

    /// Decides whether the just-completed attempt should be retried.
    ///
    /// Order matters here: bookkeeping (adaptive rate-limiter update, permit release
    /// on success) runs before the retry decision, then the classifier verdict, the
    /// attempt budget, token-bucket permit acquisition, and finally the backoff
    /// calculation are consulted in that order.
    fn should_attempt_retry(
        &self,
        ctx: &InterceptorContext,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        let retry_cfg = cfg.load::<RetryConfig>().expect("retry config is required");

        // bookkeeping
        let token_bucket = cfg.load::<TokenBucket>().expect("token bucket is required");
        // run the classifier against the context to determine if we should retry
        let retry_classifiers = runtime_components.retry_classifiers();
        let classifier_result = run_classifiers_on_ctx(retry_classifiers, ctx);

        // (adaptive only): update fill rate
        // NOTE: SEP indicates doing bookkeeping before asking if we should retry. We need to know if
        // the error was a throttling error though to do adaptive retry bookkeeping so we take
        // advantage of that information being available via the classifier result
        let error_kind = error_kind(&classifier_result);
        let is_throttling_error = error_kind
            .map(|kind| kind == ErrorKind::ThrottlingError)
            .unwrap_or(false);
        update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error);

        // on success release any retry quota held by previous attempts, reward success when indicated
        if !ctx.is_failed() {
            self.release_retry_permit(token_bucket);
        }
        // end bookkeeping

        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        // check if retry should be attempted
        if !classifier_result.should_retry() {
            debug!(
                "attempt #{request_attempts} classified as {:?}, not retrying",
                classifier_result
            );
            return Ok(ShouldAttempt::No);
        }

        // check if we're out of attempts
        if request_attempts >= retry_cfg.max_attempts() {
            debug!(
                attempts = request_attempts,
                max_attempts = retry_cfg.max_attempts(),
                "not retrying because we are out of attempts"
            );
            return Ok(ShouldAttempt::No);
        }

        // acquire permit for retry; if the bucket can't afford another attempt, give up
        let error_kind = error_kind.expect("result was classified retryable");
        match token_bucket.acquire(
            &error_kind,
            &runtime_components.time_source().unwrap_or_default(),
        ) {
            Some(permit) => self.set_retry_permit(permit),
            None => {
                debug!("attempt #{request_attempts} failed with {error_kind:?}; However, not enough retry quota is available for another attempt so no retry will be attempted.");
                return Ok(ShouldAttempt::No);
            }
        }

        // calculate delay until next attempt
        let backoff =
            match self.calculate_backoff(runtime_components, cfg, retry_cfg, &classifier_result) {
                Ok(value) => value,
                // In some cases, backoff calculation will decide that we shouldn't retry at all.
                Err(value) => return Ok(value),
            };

        debug!(
            "attempt #{request_attempts} failed with {:?}; retrying after {:?}",
            classifier_result, backoff
        );
        Ok(ShouldAttempt::YesAfterDelay(backoff))
    }
}
285
286/// extract the error kind from the classifier result if available
287fn error_kind(classifier_result: &RetryAction) -> Option<ErrorKind> {
288    match classifier_result {
289        RetryAction::RetryIndicated(RetryReason::RetryableError { kind, .. }) => Some(*kind),
290        _ => None,
291    }
292}
293
294fn update_rate_limiter_if_exists(
295    runtime_components: &RuntimeComponents,
296    cfg: &ConfigBag,
297    is_throttling_error: bool,
298) {
299    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
300        let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
301        crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error);
302    }
303}
304
305fn check_rate_limiter_for_delay(
306    runtime_components: &RuntimeComponents,
307    cfg: &ConfigBag,
308    kind: ErrorKind,
309) -> Option<Duration> {
310    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
311        let retry_reason = if kind == ErrorKind::ThrottlingError {
312            RequestReason::RetryTimeout
313        } else {
314            RequestReason::Retry
315        };
316        if let Err(delay) = crl.acquire_permission_to_send_a_request(
317            get_seconds_since_unix_epoch(runtime_components),
318            retry_reason,
319        ) {
320            return Some(delay);
321        }
322    }
323
324    None
325}
326
327pub(super) fn calculate_exponential_backoff(
328    base: f64,
329    initial_backoff: f64,
330    retry_attempts: u32,
331    max_backoff: Duration,
332) -> Duration {
333    let result = match 2_u32
334        .checked_pow(retry_attempts)
335        .map(|power| (power as f64) * initial_backoff)
336    {
337        Some(backoff) => match Duration::try_from_secs_f64(backoff) {
338            Ok(result) => result.min(max_backoff),
339            Err(e) => {
340                tracing::warn!("falling back to {max_backoff:?} as `Duration` could not be created for exponential backoff: {e}");
341                max_backoff
342            }
343        },
344        None => max_backoff,
345    };
346
347    // Apply jitter to `result`, and note that it can be applied to `max_backoff`.
348    // Won't panic because `base` is either in range 0..1 or a constant 1 in testing (if configured).
349    result.mul_f64(base)
350}
351
352pub(super) fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 {
353    let request_time = runtime_components
354        .time_source()
355        .expect("time source required for retries");
356    request_time
357        .now()
358        .duration_since(SystemTime::UNIX_EPOCH)
359        .unwrap()
360        .as_secs_f64()
361}
362
/// Interceptor registered in default retry plugin that ensures a token bucket exists in config
/// bag for every operation. Token bucket provided is partitioned by the retry partition **in the
/// config bag** at the time an operation is executed.
#[derive(Debug)]
pub(crate) struct TokenBucketProvider {
    // The client's configured retry partition, resolved at construction time.
    default_partition: RetryPartition,
    // Cached bucket for `default_partition`, so the common path skips the global map lock.
    token_bucket: TokenBucket,
}
371
372impl TokenBucketProvider {
373    /// Create a new token bucket provider with the given default retry partition.
374    ///
375    /// NOTE: This partition should be the one used for every operation on a client
376    /// unless config is overridden.
377    pub(crate) fn new(default_partition: RetryPartition) -> Self {
378        let token_bucket = TOKEN_BUCKET.get_or_init_default(default_partition.clone());
379        Self {
380            default_partition,
381            token_bucket,
382        }
383    }
384}
385
386#[dyn_dispatch_hint]
387impl Intercept for TokenBucketProvider {
388    fn name(&self) -> &'static str {
389        "TokenBucketProvider"
390    }
391
392    fn modify_before_retry_loop(
393        &self,
394        _context: &mut BeforeTransmitInterceptorContextMut<'_>,
395        _runtime_components: &RuntimeComponents,
396        cfg: &mut ConfigBag,
397    ) -> Result<(), BoxError> {
398        let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");
399
400        let tb = match &retry_partition.inner {
401            RetryPartitionInner::Default(name) => {
402                // we store the original retry partition configured and associated token bucket
403                // for the client when created so that we can avoid locking on _every_ request
404                // from _every_ client
405                if name == self.default_partition.name() {
406                    // avoid contention on the global lock
407                    self.token_bucket.clone()
408                } else {
409                    TOKEN_BUCKET.get_or_init_default(retry_partition.clone())
410                }
411            }
412            RetryPartitionInner::Custom { token_bucket, .. } => token_bucket.clone(),
413        };
414
415        trace!("token bucket for {retry_partition:?} added to config bag");
416        let mut layer = Layer::new("token_bucket_partition");
417        layer.store_put(tb);
418        cfg.push_layer(layer);
419        Ok(())
420    }
421}
422
423#[cfg(test)]
424mod tests {
425    #[allow(unused_imports)] // will be unused with `--no-default-features --features client`
426    use std::fmt;
427    use std::sync::Mutex;
428    use std::time::Duration;
429
430    use aws_smithy_async::time::SystemTimeSource;
431    use aws_smithy_runtime_api::client::interceptors::context::{
432        Input, InterceptorContext, Output,
433    };
434    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
435    use aws_smithy_runtime_api::client::retries::classifiers::{
436        ClassifyRetry, RetryAction, SharedRetryClassifier,
437    };
438    use aws_smithy_runtime_api::client::retries::{
439        AlwaysRetry, RequestAttempts, RetryStrategy, ShouldAttempt,
440    };
441    use aws_smithy_runtime_api::client::runtime_components::{
442        RuntimeComponents, RuntimeComponentsBuilder,
443    };
444    use aws_smithy_types::config_bag::{ConfigBag, Layer};
445    use aws_smithy_types::retry::{ErrorKind, RetryConfig};
446
447    use super::{calculate_exponential_backoff, StandardRetryStrategy};
448    use crate::client::retries::{ClientRateLimiter, RetryPartition, TokenBucket};
449
    // A successful (Ok) response is never retried: the strategy must return
    // `ShouldAttempt::No` even though attempts remain in the budget.
    #[test]
    fn no_retry_necessary_for_ok_result() {
        let cfg = ConfigBag::of_layers(vec![{
            let mut layer = Layer::new("test");
            layer.store_put(RetryConfig::standard());
            layer.store_put(RequestAttempts::new(1));
            layer.store_put(TokenBucket::default());
            layer
        }]);
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::default();
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }
469
    /// Builds a context/runtime-components/config triple where the classifier always
    /// retries with `error_kind`, the config records `current_request_attempts` made
    /// so far, and a default token bucket is available.
    fn set_up_cfg_and_context(
        error_kind: ErrorKind,
        current_request_attempts: u32,
        retry_config: RetryConfig,
    ) -> (InterceptorContext, RuntimeComponents, ConfigBag) {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(AlwaysRetry(error_kind)))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(RequestAttempts::new(current_request_attempts));
        layer.store_put(retry_config);
        layer.store_put(TokenBucket::default());
        let cfg = ConfigBag::of_layers(vec![layer]);

        (ctx, rc, cfg)
    }
489
    // Test that error kinds produce the correct "retry after X seconds" output.
    // All error kinds are handled in the same way for the standard strategy.
    // With 3 attempts already made and a static exponential base, the expected
    // delay is 2^(3-1) * 1s = 4s.
    fn test_should_retry_error_kind(error_kind: ErrorKind) {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            error_kind,
            3,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(4),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
    }
506
507    #[test]
508    fn should_retry_transient_error_result_after_2s() {
509        test_should_retry_error_kind(ErrorKind::TransientError);
510    }
511
512    #[test]
513    fn should_retry_client_error_result_after_2s() {
514        test_should_retry_error_kind(ErrorKind::ClientError);
515    }
516
517    #[test]
518    fn should_retry_server_error_result_after_2s() {
519        test_should_retry_error_kind(ErrorKind::ServerError);
520    }
521
522    #[test]
523    fn should_retry_throttling_error_result_after_2s() {
524        test_should_retry_error_kind(ErrorKind::ThrottlingError);
525    }
526
    // When the number of attempts already made equals `max_attempts`, even a
    // retryable (transient) error must not be retried.
    #[test]
    fn dont_retry_when_out_of_attempts() {
        let current_attempts = 4;
        let max_attempts = current_attempts;
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            current_attempts,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(max_attempts),
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }
544
    // Regression test: an attempt count that overflows `2_u32.checked_pow` in
    // `calculate_exponential_backoff` must fall back to the max backoff, not panic.
    // NOTE(review): `MAX_BACKOFF` is not visible in this chunk; presumably a
    // test-module constant defined elsewhere — verify.
    #[test]
    fn should_not_panic_when_exponential_backoff_duration_could_not_be_created() {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            // Greater than 32 when subtracted by 1 in `calculate_backoff`, causing overflow in `calculate_exponential_backoff`
            33,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(100), // Any value greater than 33 will do
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(MAX_BACKOFF), actual);
    }
561
    // A custom retry partition carries its own client rate limiter; the strategy
    // must return exactly that instance (identity checked via `Arc::ptr_eq`), not
    // a fresh limiter from the global partition map.
    #[test]
    fn should_yield_client_rate_limiter_from_custom_partition() {
        let expected = ClientRateLimiter::builder().token_refill_rate(3.14).build();
        let cfg = ConfigBag::of_layers(vec![
            // Emulate default config layer overridden by a user config layer
            {
                let mut layer = Layer::new("default");
                layer.store_put(RetryPartition::new("default"));
                layer
            },
            {
                let mut layer = Layer::new("user");
                layer.store_put(RetryConfig::adaptive());
                layer.store_put(
                    RetryPartition::custom("user")
                        .client_rate_limiter(expected.clone())
                        .build(),
                );
                layer
            },
        ]);
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();
        let actual = StandardRetryStrategy::adaptive_retry_rate_limiter(&rc, &cfg)
            .expect("should yield client rate limiter from custom partition");
        assert!(std::sync::Arc::ptr_eq(&expected.inner, &actual.inner));
    }
591
    /// Test classifier that replays a scripted sequence of retry actions.
    #[allow(dead_code)] // will be unused with `--no-default-features --features client`
    #[derive(Debug)]
    struct PresetReasonRetryClassifier {
        // Scripted actions, stored reversed so `pop` yields them in submission order.
        retry_actions: Mutex<Vec<RetryAction>>,
    }
597
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    impl PresetReasonRetryClassifier {
        /// Creates a classifier that replays `retry_reasons` in the given order.
        fn new(mut retry_reasons: Vec<RetryAction>) -> Self {
            // We'll pop the retry_reasons in reverse order, so we reverse the list to fix that.
            retry_reasons.reverse();
            Self {
                retry_actions: Mutex::new(retry_reasons),
            }
        }
    }
608
    impl ClassifyRetry for PresetReasonRetryClassifier {
        fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
            // Check for a result
            let output_or_error = ctx.output_or_error();
            // Check for an error; successful (or absent) results never indicate a retry.
            match output_or_error {
                Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
                _ => (),
            };

            let mut retry_actions = self.retry_actions.lock().unwrap();
            // The last scripted action is "sticky": once a single action remains it
            // is cloned (not popped) for every subsequent classification.
            if retry_actions.len() == 1 {
                retry_actions.first().unwrap().clone()
            } else {
                retry_actions.pop().unwrap()
            }
        }

        fn name(&self) -> &'static str {
            "Always returns a preset retry reason"
        }
    }
631
    /// Builds a config/runtime-components/context triple whose classifier replays
    /// the given `retry_reasons`; the context starts out holding an error result.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    fn setup_test(
        retry_reasons: Vec<RetryAction>,
        retry_config: RetryConfig,
    ) -> (ConfigBag, RuntimeComponents, InterceptorContext) {
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(
                PresetReasonRetryClassifier::new(retry_reasons),
            ))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        // This type doesn't matter b/c the classifier will just return whatever we tell it to.
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));

        (cfg, rc, ctx)
    }
652
    // Two failed attempts, each holding retry quota (the asserts show 5 permits per
    // retry: 495 then 490), followed by a success that releases the held permit
    // back to the bucket (490 -> 495).
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn eventual_success() {
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 495);
    }
685
    // Two failed attempts consume retry quota (495 then 490); the third attempt is
    // refused because `max_attempts` is reached, and — with no success — the held
    // quota is not released (permits stay at 490).
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_more_attempts() {
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(3),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 490);
    }
716
717    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
718    #[test]
719    fn successful_request_and_deser_should_be_retryable() {
720        #[derive(Clone, Copy, Debug)]
721        enum LongRunningOperationStatus {
722            Running,
723            Complete,
724        }
725
726        #[derive(Debug)]
727        struct LongRunningOperationOutput {
728            status: Option<LongRunningOperationStatus>,
729        }
730
731        impl LongRunningOperationOutput {
732            fn status(&self) -> Option<LongRunningOperationStatus> {
733                self.status
734            }
735        }
736
737        struct WaiterRetryClassifier {}
738
739        impl WaiterRetryClassifier {
740            fn new() -> Self {
741                WaiterRetryClassifier {}
742            }
743        }
744
745        impl fmt::Debug for WaiterRetryClassifier {
746            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
747                write!(f, "WaiterRetryClassifier")
748            }
749        }
750        impl ClassifyRetry for WaiterRetryClassifier {
751            fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
752                let status: Option<LongRunningOperationStatus> =
753                    ctx.output_or_error().and_then(|res| {
754                        res.ok().and_then(|output| {
755                            output
756                                .downcast_ref::<LongRunningOperationOutput>()
757                                .and_then(|output| output.status())
758                        })
759                    });
760
761                if let Some(LongRunningOperationStatus::Running) = status {
762                    return RetryAction::server_error();
763                };
764
765                RetryAction::NoActionIndicated
766            }
767
768            fn name(&self) -> &'static str {
769                "waiter retry classifier"
770            }
771        }
772
773        let retry_config = RetryConfig::standard()
774            .with_use_static_exponential_base(true)
775            .with_max_attempts(5);
776
777        let rc = RuntimeComponentsBuilder::for_tests()
778            .with_retry_classifier(SharedRetryClassifier::new(WaiterRetryClassifier::new()))
779            .build()
780            .unwrap();
781        let mut layer = Layer::new("test");
782        layer.store_put(retry_config);
783        let mut cfg = ConfigBag::of_layers(vec![layer]);
784        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
785        let strategy = StandardRetryStrategy::new();
786
787        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
788            status: Some(LongRunningOperationStatus::Running),
789        })));
790
791        cfg.interceptor_state().store_put(TokenBucket::new(5));
792        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
793
794        cfg.interceptor_state().store_put(RequestAttempts::new(1));
795        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
796        let dur = should_retry.expect_delay();
797        assert_eq!(dur, Duration::from_secs(1));
798        assert_eq!(token_bucket.available_permits(), 0);
799
800        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
801            status: Some(LongRunningOperationStatus::Complete),
802        })));
803        cfg.interceptor_state().store_put(RequestAttempts::new(2));
804        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
805        should_retry.expect_no();
806        assert_eq!(token_bucket.available_permits(), 5);
807    }
808
809    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
810    #[test]
811    fn no_quota() {
812        let (mut cfg, rc, ctx) = setup_test(
813            vec![RetryAction::server_error()],
814            RetryConfig::standard()
815                .with_use_static_exponential_base(true)
816                .with_max_attempts(5),
817        );
818        let strategy = StandardRetryStrategy::new();
819        cfg.interceptor_state().store_put(TokenBucket::new(5));
820        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
821
822        cfg.interceptor_state().store_put(RequestAttempts::new(1));
823        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
824        let dur = should_retry.expect_delay();
825        assert_eq!(dur, Duration::from_secs(1));
826        assert_eq!(token_bucket.available_permits(), 0);
827
828        cfg.interceptor_state().store_put(RequestAttempts::new(2));
829        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
830        assert_eq!(no_retry, ShouldAttempt::No);
831        assert_eq!(token_bucket.available_permits(), 0);
832    }
833
834    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
835    #[test]
836    fn quota_replenishes_on_success() {
837        let (mut cfg, rc, mut ctx) = setup_test(
838            vec![
839                RetryAction::transient_error(),
840                RetryAction::retryable_error_with_explicit_delay(
841                    ErrorKind::TransientError,
842                    Duration::from_secs(1),
843                ),
844            ],
845            RetryConfig::standard()
846                .with_use_static_exponential_base(true)
847                .with_max_attempts(5),
848        );
849        let strategy = StandardRetryStrategy::new();
850        cfg.interceptor_state().store_put(TokenBucket::new(100));
851        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
852
853        cfg.interceptor_state().store_put(RequestAttempts::new(1));
854        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
855        let dur = should_retry.expect_delay();
856        assert_eq!(dur, Duration::from_secs(1));
857        assert_eq!(token_bucket.available_permits(), 90);
858
859        cfg.interceptor_state().store_put(RequestAttempts::new(2));
860        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
861        let dur = should_retry.expect_delay();
862        assert_eq!(dur, Duration::from_secs(1));
863        assert_eq!(token_bucket.available_permits(), 80);
864
865        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
866
867        cfg.interceptor_state().store_put(RequestAttempts::new(3));
868        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
869        assert_eq!(no_retry, ShouldAttempt::No);
870
871        assert_eq!(token_bucket.available_permits(), 90);
872    }
873
874    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
875    #[test]
876    fn quota_replenishes_on_first_try_success() {
877        const PERMIT_COUNT: usize = 20;
878        let (mut cfg, rc, mut ctx) = setup_test(
879            vec![RetryAction::transient_error()],
880            RetryConfig::standard()
881                .with_use_static_exponential_base(true)
882                .with_max_attempts(u32::MAX),
883        );
884        let strategy = StandardRetryStrategy::new();
885        cfg.interceptor_state()
886            .store_put(TokenBucket::new(PERMIT_COUNT));
887        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
888
889        let mut attempt = 1;
890
891        // Drain all available permits with failed attempts
892        while token_bucket.available_permits() > 0 {
893            // Draining should complete in 2 attempts
894            if attempt > 2 {
895                panic!("This test should have completed by now (drain)");
896            }
897
898            cfg.interceptor_state()
899                .store_put(RequestAttempts::new(attempt));
900            let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
901            assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
902            attempt += 1;
903        }
904
905        // Forget the permit so that we can only refill by "success on first try".
906        let permit = strategy.retry_permit.lock().unwrap().take().unwrap();
907        permit.forget();
908
909        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
910
911        // Replenish permits until we get back to `PERMIT_COUNT`
912        while token_bucket.available_permits() < PERMIT_COUNT {
913            if attempt > 23 {
914                panic!("This test should have completed by now (fill-up)");
915            }
916
917            cfg.interceptor_state()
918                .store_put(RequestAttempts::new(attempt));
919            let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
920            assert_eq!(no_retry, ShouldAttempt::No);
921            attempt += 1;
922        }
923
924        assert_eq!(attempt, 23);
925        assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
926    }
927
928    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
929    #[test]
930    fn backoff_timing() {
931        let (mut cfg, rc, ctx) = setup_test(
932            vec![RetryAction::server_error()],
933            RetryConfig::standard()
934                .with_use_static_exponential_base(true)
935                .with_max_attempts(5),
936        );
937        let strategy = StandardRetryStrategy::new();
938        cfg.interceptor_state().store_put(TokenBucket::default());
939        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
940
941        cfg.interceptor_state().store_put(RequestAttempts::new(1));
942        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
943        let dur = should_retry.expect_delay();
944        assert_eq!(dur, Duration::from_secs(1));
945        assert_eq!(token_bucket.available_permits(), 495);
946
947        cfg.interceptor_state().store_put(RequestAttempts::new(2));
948        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
949        let dur = should_retry.expect_delay();
950        assert_eq!(dur, Duration::from_secs(2));
951        assert_eq!(token_bucket.available_permits(), 490);
952
953        cfg.interceptor_state().store_put(RequestAttempts::new(3));
954        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
955        let dur = should_retry.expect_delay();
956        assert_eq!(dur, Duration::from_secs(4));
957        assert_eq!(token_bucket.available_permits(), 485);
958
959        cfg.interceptor_state().store_put(RequestAttempts::new(4));
960        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
961        let dur = should_retry.expect_delay();
962        assert_eq!(dur, Duration::from_secs(8));
963        assert_eq!(token_bucket.available_permits(), 480);
964
965        cfg.interceptor_state().store_put(RequestAttempts::new(5));
966        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
967        assert_eq!(no_retry, ShouldAttempt::No);
968        assert_eq!(token_bucket.available_permits(), 480);
969    }
970
971    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
972    #[test]
973    fn max_backoff_time() {
974        let (mut cfg, rc, ctx) = setup_test(
975            vec![RetryAction::server_error()],
976            RetryConfig::standard()
977                .with_use_static_exponential_base(true)
978                .with_max_attempts(5)
979                .with_initial_backoff(Duration::from_secs(1))
980                .with_max_backoff(Duration::from_secs(3)),
981        );
982        let strategy = StandardRetryStrategy::new();
983        cfg.interceptor_state().store_put(TokenBucket::default());
984        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
985
986        cfg.interceptor_state().store_put(RequestAttempts::new(1));
987        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
988        let dur = should_retry.expect_delay();
989        assert_eq!(dur, Duration::from_secs(1));
990        assert_eq!(token_bucket.available_permits(), 495);
991
992        cfg.interceptor_state().store_put(RequestAttempts::new(2));
993        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
994        let dur = should_retry.expect_delay();
995        assert_eq!(dur, Duration::from_secs(2));
996        assert_eq!(token_bucket.available_permits(), 490);
997
998        cfg.interceptor_state().store_put(RequestAttempts::new(3));
999        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
1000        let dur = should_retry.expect_delay();
1001        assert_eq!(dur, Duration::from_secs(3));
1002        assert_eq!(token_bucket.available_permits(), 485);
1003
1004        cfg.interceptor_state().store_put(RequestAttempts::new(4));
1005        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
1006        let dur = should_retry.expect_delay();
1007        assert_eq!(dur, Duration::from_secs(3));
1008        assert_eq!(token_bucket.available_permits(), 480);
1009
1010        cfg.interceptor_state().store_put(RequestAttempts::new(5));
1011        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
1012        assert_eq!(no_retry, ShouldAttempt::No);
1013        assert_eq!(token_bucket.available_permits(), 480);
1014    }
1015
    // Upper bound shared by the `calculate_exponential_backoff` tests below.
    const MAX_BACKOFF: Duration = Duration::from_secs(20);
1017
1018    #[test]
1019    fn calculate_exponential_backoff_where_initial_backoff_is_one() {
1020        let initial_backoff = 1.0;
1021
1022        for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() {
1023            let actual_backoff =
1024                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1025            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1026        }
1027    }
1028
1029    #[test]
1030    fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
1031        let initial_backoff = 3.0;
1032
1033        for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() {
1034            let actual_backoff =
1035                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1036            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1037        }
1038    }
1039
1040    #[test]
1041    fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
1042        let initial_backoff = 0.03;
1043
1044        for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() {
1045            let actual_backoff =
1046                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1047            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1048        }
1049    }
1050
1051    #[test]
1052    fn calculate_backoff_overflow_should_gracefully_fallback_to_max_backoff() {
1053        // avoid overflow for a silly large amount of retry attempts
1054        assert_eq!(
1055            MAX_BACKOFF,
1056            calculate_exponential_backoff(1_f64, 10_f64, 100000, MAX_BACKOFF),
1057        );
1058    }
1059}