// aws_smithy_runtime/client/retries/strategy/standard.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use std::sync::Mutex;
7use std::time::{Duration, SystemTime};
8
9use tokio::sync::OwnedSemaphorePermit;
10use tracing::{debug, trace};
11
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14    BeforeTransmitInterceptorContextMut, InterceptorContext,
15};
16use aws_smithy_runtime_api::client::interceptors::Intercept;
17use aws_smithy_runtime_api::client::retries::classifiers::{RetryAction, RetryReason};
18use aws_smithy_runtime_api::client::retries::{RequestAttempts, RetryStrategy, ShouldAttempt};
19use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
20use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
21use aws_smithy_types::retry::{ErrorKind, RetryConfig, RetryMode};
22
23use crate::client::retries::classifiers::run_classifiers_on_ctx;
24use crate::client::retries::client_rate_limiter::{ClientRateLimiter, RequestReason};
25use crate::client::retries::strategy::standard::ReleaseResult::{
26    APermitWasReleased, NoPermitWasReleased,
27};
28use crate::client::retries::token_bucket::TokenBucket;
29use crate::client::retries::{ClientRateLimiterPartition, RetryPartition, RetryPartitionInner};
30use crate::static_partition_map::StaticPartitionMap;
31
/// Process-wide client rate limiters for adaptive retry, one per retry partition.
/// Shared across clients so that all clients in a partition throttle together.
static CLIENT_RATE_LIMITER: StaticPartitionMap<ClientRateLimiterPartition, ClientRateLimiter> =
    StaticPartitionMap::new();

/// Used by token bucket interceptor to ensure a TokenBucket always exists in config bag
static TOKEN_BUCKET: StaticPartitionMap<RetryPartition, TokenBucket> = StaticPartitionMap::new();
37
/// Retry strategy with exponential backoff, max attempts, and a token bucket.
#[derive(Debug, Default)]
pub struct StandardRetryStrategy {
    // Permit held for the current in-flight retry. Released (dropped) when an
    // attempt succeeds; replaced and "forgotten" when another retry begins.
    retry_permit: Mutex<Option<OwnedSemaphorePermit>>,
}
43
impl Storable for StandardRetryStrategy {
    // Stored in the config bag; storing again replaces any earlier value.
    type Storer = StoreReplace<Self>;
}
47
impl StandardRetryStrategy {
    /// Create a new standard retry strategy with the given config.
    pub fn new() -> Self {
        Default::default()
    }

    /// Releases the currently-held retry permit, if any, and reports whether
    /// one was actually released. Dropping the permit returns its quota to the
    /// token bucket (see `set_retry_permit` for the contrasting `forget` path).
    fn release_retry_permit(&self) -> ReleaseResult {
        let mut retry_permit = self.retry_permit.lock().unwrap();
        match retry_permit.take() {
            Some(p) => {
                drop(p);
                APermitWasReleased
            }
            None => NoPermitWasReleased,
        }
    }

    /// Stores a permit for the retry that is about to run, replacing (and
    /// permanently forgetting) any permit held from a previous attempt.
    fn set_retry_permit(&self, new_retry_permit: OwnedSemaphorePermit) {
        let mut old_retry_permit = self.retry_permit.lock().unwrap();
        if let Some(p) = old_retry_permit.replace(new_retry_permit) {
            // Whenever we set a new retry permit, and it replaces the old one, we need to "forget"
            // the old permit, removing it from the bucket forever.
            p.forget()
        }
    }

    /// Returns a [`ClientRateLimiter`] if adaptive retry is configured.
    ///
    /// Default partitions share a limiter from a process-wide static map keyed
    /// by partition; custom partitions carry their own limiter. Returns `None`
    /// when retry mode is not adaptive or no time source is available.
    fn adaptive_retry_rate_limiter(
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Option<ClientRateLimiter> {
        let retry_config = cfg.load::<RetryConfig>().expect("retry config is required");
        if retry_config.mode() == RetryMode::Adaptive {
            if let Some(time_source) = runtime_components.time_source() {
                let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");
                let seconds_since_unix_epoch = time_source
                    .now()
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .expect("the present takes place after the UNIX_EPOCH")
                    .as_secs_f64();
                let client_rate_limiter = match &retry_partition.inner {
                    RetryPartitionInner::Default(_) => {
                        let client_rate_limiter_partition =
                            ClientRateLimiterPartition::new(retry_partition.clone());
                        CLIENT_RATE_LIMITER.get_or_init(client_rate_limiter_partition, || {
                            ClientRateLimiter::new(seconds_since_unix_epoch)
                        })
                    }
                    RetryPartitionInner::Custom {
                        client_rate_limiter,
                        ..
                    } => client_rate_limiter.clone(),
                };
                return Some(client_rate_limiter);
            }
        }
        None
    }

    /// Computes the delay before the next retry attempt.
    ///
    /// Delay sources in priority order: an explicit server-provided
    /// `retry_after`, a delay requested by the adaptive client rate limiter,
    /// then exponential backoff with jitter. Every delay is capped at the
    /// configured max backoff. Returns `Err(ShouldAttempt::No)` when the
    /// classifier result forbids (or does not indicate) a retry.
    fn calculate_backoff(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
        retry_cfg: &RetryConfig,
        retry_reason: &RetryAction,
    ) -> Result<Duration, ShouldAttempt> {
        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        match retry_reason {
            RetryAction::RetryIndicated(RetryReason::RetryableError { kind, retry_after }) => {
                if let Some(delay) = *retry_after {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("explicit request from server to delay {delay:?} before retrying");
                    Ok(delay)
                } else if let Some(delay) =
                    check_rate_limiter_for_delay(runtime_components, cfg, *kind)
                {
                    let delay = delay.min(retry_cfg.max_backoff());
                    debug!("rate limiter has requested a {delay:?} delay before retrying");
                    Ok(delay)
                } else {
                    let base = if retry_cfg.use_static_exponential_base() {
                        1.0
                    } else {
                        fastrand::f64()
                    };
                    Ok(calculate_exponential_backoff(
                        // Generate a random base multiplier to create jitter
                        base,
                        // Get the backoff time multiplier in seconds (with fractional seconds)
                        retry_cfg.initial_backoff().as_secs_f64(),
                        // `self.local.attempts` tracks number of requests made including the initial request
                        // The initial attempt shouldn't count towards backoff calculations, so we subtract it
                        request_attempts - 1,
                        // Maximum backoff duration as a fallback to prevent overflow when calculating a power
                        retry_cfg.max_backoff(),
                    ))
                }
            }
            RetryAction::RetryForbidden | RetryAction::NoActionIndicated => {
                debug!(
                    attempts = request_attempts,
                    max_attempts = retry_cfg.max_attempts(),
                    "encountered un-retryable error"
                );
                Err(ShouldAttempt::No)
            }
            _ => unreachable!("RetryAction is non-exhaustive"),
        }
    }
}
162
/// Outcome of [`StandardRetryStrategy::release_retry_permit`].
enum ReleaseResult {
    /// A held permit was dropped, returning its quota to the token bucket.
    APermitWasReleased,
    /// No permit was held at the time of the call.
    NoPermitWasReleased,
}
167
impl RetryStrategy for StandardRetryStrategy {
    /// Decides whether the initial (non-retry) attempt may proceed.
    ///
    /// Only the adaptive-mode client rate limiter can delay the first attempt;
    /// without one, the answer is always an immediate `Yes`.
    fn should_attempt_initial_request(
        &self,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        if let Some(crl) = Self::adaptive_retry_rate_limiter(runtime_components, cfg) {
            let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
            if let Err(delay) = crl.acquire_permission_to_send_a_request(
                seconds_since_unix_epoch,
                RequestReason::InitialRequest,
            ) {
                return Ok(ShouldAttempt::YesAfterDelay(delay));
            }
        } else {
            debug!("no client rate limiter configured, so no token is required for the initial request.");
        }

        Ok(ShouldAttempt::Yes)
    }

    /// Decides whether a retry should happen after a completed attempt.
    ///
    /// Order of operations: bookkeeping (adaptive rate-limiter update and, on
    /// success, retry-quota release/refill), classification, max-attempts
    /// check, token-bucket permit acquisition, then backoff calculation.
    fn should_attempt_retry(
        &self,
        ctx: &InterceptorContext,
        runtime_components: &RuntimeComponents,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        let retry_cfg = cfg.load::<RetryConfig>().expect("retry config is required");

        // bookkeeping
        let token_bucket = cfg.load::<TokenBucket>().expect("token bucket is required");
        // run the classifier against the context to determine if we should retry
        let retry_classifiers = runtime_components.retry_classifiers();
        let classifier_result = run_classifiers_on_ctx(retry_classifiers, ctx);

        // (adaptive only): update fill rate
        // NOTE: SEP indicates doing bookkeeping before asking if we should retry. We need to know if
        // the error was a throttling error though to do adaptive retry bookkeeping so we take
        // advantage of that information being available via the classifier result
        let error_kind = error_kind(&classifier_result);
        let is_throttling_error = error_kind
            .map(|kind| kind == ErrorKind::ThrottlingError)
            .unwrap_or(false);
        update_rate_limiter_if_exists(runtime_components, cfg, is_throttling_error);

        // on success release any retry quota held by previous attempts
        if !ctx.is_failed() {
            if let NoPermitWasReleased = self.release_retry_permit() {
                // In the event that there was no retry permit to release, we generate new
                // permits from nothing. We do this to make up for permits we had to "forget".
                // Otherwise, repeated retries would empty the bucket and nothing could fill it
                // back up again.
                token_bucket.regenerate_a_token();
            }
        }
        // end bookkeeping

        let request_attempts = cfg
            .load::<RequestAttempts>()
            .expect("at least one request attempt is made before any retry is attempted")
            .attempts();

        // check if retry should be attempted
        if !classifier_result.should_retry() {
            debug!(
                "attempt #{request_attempts} classified as {:?}, not retrying",
                classifier_result
            );
            return Ok(ShouldAttempt::No);
        }

        // check if we're out of attempts
        if request_attempts >= retry_cfg.max_attempts() {
            debug!(
                attempts = request_attempts,
                max_attempts = retry_cfg.max_attempts(),
                "not retrying because we are out of attempts"
            );
            return Ok(ShouldAttempt::No);
        }

        //  acquire permit for retry
        let error_kind = error_kind.expect("result was classified retryable");
        match token_bucket.acquire(&error_kind) {
            Some(permit) => self.set_retry_permit(permit),
            None => {
                debug!("attempt #{request_attempts} failed with {error_kind:?}; However, not enough retry quota is available for another attempt so no retry will be attempted.");
                return Ok(ShouldAttempt::No);
            }
        }

        // calculate delay until next attempt
        let backoff =
            match self.calculate_backoff(runtime_components, cfg, retry_cfg, &classifier_result) {
                Ok(value) => value,
                // In some cases, backoff calculation will decide that we shouldn't retry at all.
                Err(value) => return Ok(value),
            };

        debug!(
            "attempt #{request_attempts} failed with {:?}; retrying after {:?}",
            classifier_result, backoff
        );
        Ok(ShouldAttempt::YesAfterDelay(backoff))
    }
}
274
275/// extract the error kind from the classifier result if available
276fn error_kind(classifier_result: &RetryAction) -> Option<ErrorKind> {
277    match classifier_result {
278        RetryAction::RetryIndicated(RetryReason::RetryableError { kind, .. }) => Some(*kind),
279        _ => None,
280    }
281}
282
283fn update_rate_limiter_if_exists(
284    runtime_components: &RuntimeComponents,
285    cfg: &ConfigBag,
286    is_throttling_error: bool,
287) {
288    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
289        let seconds_since_unix_epoch = get_seconds_since_unix_epoch(runtime_components);
290        crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error);
291    }
292}
293
294fn check_rate_limiter_for_delay(
295    runtime_components: &RuntimeComponents,
296    cfg: &ConfigBag,
297    kind: ErrorKind,
298) -> Option<Duration> {
299    if let Some(crl) = StandardRetryStrategy::adaptive_retry_rate_limiter(runtime_components, cfg) {
300        let retry_reason = if kind == ErrorKind::ThrottlingError {
301            RequestReason::RetryTimeout
302        } else {
303            RequestReason::Retry
304        };
305        if let Err(delay) = crl.acquire_permission_to_send_a_request(
306            get_seconds_since_unix_epoch(runtime_components),
307            retry_reason,
308        ) {
309            return Some(delay);
310        }
311    }
312
313    None
314}
315
316fn calculate_exponential_backoff(
317    base: f64,
318    initial_backoff: f64,
319    retry_attempts: u32,
320    max_backoff: Duration,
321) -> Duration {
322    let result = match 2_u32
323        .checked_pow(retry_attempts)
324        .map(|power| (power as f64) * initial_backoff)
325    {
326        Some(backoff) => match Duration::try_from_secs_f64(backoff) {
327            Ok(result) => result.min(max_backoff),
328            Err(e) => {
329                tracing::warn!("falling back to {max_backoff:?} as `Duration` could not be created for exponential backoff: {e}");
330                max_backoff
331            }
332        },
333        None => max_backoff,
334    };
335
336    // Apply jitter to `result`, and note that it can be applied to `max_backoff`.
337    // Won't panic because `base` is either in range 0..1 or a constant 1 in testing (if configured).
338    result.mul_f64(base)
339}
340
341fn get_seconds_since_unix_epoch(runtime_components: &RuntimeComponents) -> f64 {
342    let request_time = runtime_components
343        .time_source()
344        .expect("time source required for retries");
345    request_time
346        .now()
347        .duration_since(SystemTime::UNIX_EPOCH)
348        .unwrap()
349        .as_secs_f64()
350}
351
/// Interceptor registered in default retry plugin that ensures a token bucket exists in config
/// bag for every operation. Token bucket provided is partitioned by the retry partition **in the
/// config bag** at the time an operation is executed.
#[derive(Debug)]
pub(crate) struct TokenBucketProvider {
    // Partition the client was created with; used as a fast path to skip the
    // global partition-map lock for the common case.
    default_partition: RetryPartition,
    // Bucket for `default_partition`, resolved once at construction time.
    token_bucket: TokenBucket,
}

impl TokenBucketProvider {
    /// Create a new token bucket provider with the given default retry partition.
    ///
    /// NOTE: This partition should be the one used for every operation on a client
    /// unless config is overridden.
    pub(crate) fn new(default_partition: RetryPartition) -> Self {
        // Resolve (or lazily create) the bucket for the default partition up
        // front, so per-request interception can usually avoid the global map.
        let token_bucket = TOKEN_BUCKET.get_or_init_default(default_partition.clone());
        Self {
            default_partition,
            token_bucket,
        }
    }
}
374
impl Intercept for TokenBucketProvider {
    fn name(&self) -> &'static str {
        "TokenBucketProvider"
    }

    /// Before the retry loop begins, pushes a config-bag layer containing the
    /// `TokenBucket` for whatever retry partition is active in the config bag.
    fn modify_before_retry_loop(
        &self,
        _context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let retry_partition = cfg.load::<RetryPartition>().expect("set in default config");

        let tb = match &retry_partition.inner {
            RetryPartitionInner::Default(name) => {
                // we store the original retry partition configured and associated token bucket
                // for the client when created so that we can avoid locking on _every_ request
                // from _every_ client
                if name == self.default_partition.name() {
                    // avoid contention on the global lock
                    self.token_bucket.clone()
                } else {
                    TOKEN_BUCKET.get_or_init_default(retry_partition.clone())
                }
            }
            // Custom partitions bring their own bucket.
            RetryPartitionInner::Custom { token_bucket, .. } => token_bucket.clone(),
        };

        trace!("token bucket for {retry_partition:?} added to config bag");
        let mut layer = Layer::new("token_bucket_partition");
        layer.store_put(tb);
        cfg.push_layer(layer);
        Ok(())
    }
}
410
411#[cfg(test)]
412mod tests {
413    #[allow(unused_imports)] // will be unused with `--no-default-features --features client`
414    use std::fmt;
415    use std::sync::Mutex;
416    use std::time::Duration;
417
418    use aws_smithy_async::time::SystemTimeSource;
419    use aws_smithy_runtime_api::client::interceptors::context::{
420        Input, InterceptorContext, Output,
421    };
422    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
423    use aws_smithy_runtime_api::client::retries::classifiers::{
424        ClassifyRetry, RetryAction, SharedRetryClassifier,
425    };
426    use aws_smithy_runtime_api::client::retries::{
427        AlwaysRetry, RequestAttempts, RetryStrategy, ShouldAttempt,
428    };
429    use aws_smithy_runtime_api::client::runtime_components::{
430        RuntimeComponents, RuntimeComponentsBuilder,
431    };
432    use aws_smithy_types::config_bag::{ConfigBag, Layer};
433    use aws_smithy_types::retry::{ErrorKind, RetryConfig};
434
435    use super::{calculate_exponential_backoff, StandardRetryStrategy};
436    use crate::client::retries::{ClientRateLimiter, RetryPartition, TokenBucket};
437
438    #[test]
439    fn no_retry_necessary_for_ok_result() {
440        let cfg = ConfigBag::of_layers(vec![{
441            let mut layer = Layer::new("test");
442            layer.store_put(RetryConfig::standard());
443            layer.store_put(RequestAttempts::new(1));
444            layer.store_put(TokenBucket::default());
445            layer
446        }]);
447        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
448        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
449        let strategy = StandardRetryStrategy::default();
450        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
451
452        let actual = strategy
453            .should_attempt_retry(&ctx, &rc, &cfg)
454            .expect("method is infallible for this use");
455        assert_eq!(ShouldAttempt::No, actual);
456    }
457
458    fn set_up_cfg_and_context(
459        error_kind: ErrorKind,
460        current_request_attempts: u32,
461        retry_config: RetryConfig,
462    ) -> (InterceptorContext, RuntimeComponents, ConfigBag) {
463        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
464        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
465        let rc = RuntimeComponentsBuilder::for_tests()
466            .with_retry_classifier(SharedRetryClassifier::new(AlwaysRetry(error_kind)))
467            .build()
468            .unwrap();
469        let mut layer = Layer::new("test");
470        layer.store_put(RequestAttempts::new(current_request_attempts));
471        layer.store_put(retry_config);
472        layer.store_put(TokenBucket::default());
473        let cfg = ConfigBag::of_layers(vec![layer]);
474
475        (ctx, rc, cfg)
476    }
477
478    // Test that error kinds produce the correct "retry after X seconds" output.
479    // All error kinds are handled in the same way for the standard strategy.
480    fn test_should_retry_error_kind(error_kind: ErrorKind) {
481        let (ctx, rc, cfg) = set_up_cfg_and_context(
482            error_kind,
483            3,
484            RetryConfig::standard()
485                .with_use_static_exponential_base(true)
486                .with_max_attempts(4),
487        );
488        let strategy = StandardRetryStrategy::new();
489        let actual = strategy
490            .should_attempt_retry(&ctx, &rc, &cfg)
491            .expect("method is infallible for this use");
492        assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
493    }
494
495    #[test]
496    fn should_retry_transient_error_result_after_2s() {
497        test_should_retry_error_kind(ErrorKind::TransientError);
498    }
499
500    #[test]
501    fn should_retry_client_error_result_after_2s() {
502        test_should_retry_error_kind(ErrorKind::ClientError);
503    }
504
505    #[test]
506    fn should_retry_server_error_result_after_2s() {
507        test_should_retry_error_kind(ErrorKind::ServerError);
508    }
509
510    #[test]
511    fn should_retry_throttling_error_result_after_2s() {
512        test_should_retry_error_kind(ErrorKind::ThrottlingError);
513    }
514
515    #[test]
516    fn dont_retry_when_out_of_attempts() {
517        let current_attempts = 4;
518        let max_attempts = current_attempts;
519        let (ctx, rc, cfg) = set_up_cfg_and_context(
520            ErrorKind::TransientError,
521            current_attempts,
522            RetryConfig::standard()
523                .with_use_static_exponential_base(true)
524                .with_max_attempts(max_attempts),
525        );
526        let strategy = StandardRetryStrategy::new();
527        let actual = strategy
528            .should_attempt_retry(&ctx, &rc, &cfg)
529            .expect("method is infallible for this use");
530        assert_eq!(ShouldAttempt::No, actual);
531    }
532
    #[test]
    fn should_not_panic_when_exponential_backoff_duration_could_not_be_created() {
        let (ctx, rc, cfg) = set_up_cfg_and_context(
            ErrorKind::TransientError,
            // Greater than 32 when subtracted by 1 in `calculate_backoff`, causing overflow in `calculate_exponential_backoff`
            33,
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(100), // Any value greater than 33 will do
        );
        let strategy = StandardRetryStrategy::new();
        let actual = strategy
            .should_attempt_retry(&ctx, &rc, &cfg)
            .expect("method is infallible for this use");
        // The overflow path saturates to the maximum backoff. NOTE(review):
        // `MAX_BACKOFF` is imported/defined elsewhere in this module (not
        // visible in this chunk) — presumably the crate's max-backoff constant.
        assert_eq!(ShouldAttempt::YesAfterDelay(MAX_BACKOFF), actual);
    }
549
    #[test]
    fn should_yield_client_rate_limiter_from_custom_partition() {
        let expected = ClientRateLimiter::builder().token_refill_rate(3.14).build();
        let cfg = ConfigBag::of_layers(vec![
            // Emulate default config layer overridden by a user config layer
            {
                let mut layer = Layer::new("default");
                layer.store_put(RetryPartition::new("default"));
                layer
            },
            {
                let mut layer = Layer::new("user");
                layer.store_put(RetryConfig::adaptive());
                layer.store_put(
                    RetryPartition::custom("user")
                        .client_rate_limiter(expected.clone())
                        .build(),
                );
                layer
            },
        ]);
        let rc = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();
        let actual = StandardRetryStrategy::adaptive_retry_rate_limiter(&rc, &cfg)
            .expect("should yield client rate limiter from custom partition");
        // Pointer equality proves the custom partition's own limiter was
        // returned, not a freshly-created default-partition limiter.
        assert!(std::sync::Arc::ptr_eq(&expected.inner, &actual.inner));
    }
579
    #[allow(dead_code)] // will be unused with `--no-default-features --features client`
    #[derive(Debug)]
    struct PresetReasonRetryClassifier {
        // Actions to replay, stored reversed so `pop` yields them in the
        // originally supplied order.
        retry_actions: Mutex<Vec<RetryAction>>,
    }
585
586    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
587    impl PresetReasonRetryClassifier {
588        fn new(mut retry_reasons: Vec<RetryAction>) -> Self {
589            // We'll pop the retry_reasons in reverse order, so we reverse the list to fix that.
590            retry_reasons.reverse();
591            Self {
592                retry_actions: Mutex::new(retry_reasons),
593            }
594        }
595    }
596
597    impl ClassifyRetry for PresetReasonRetryClassifier {
598        fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
599            // Check for a result
600            let output_or_error = ctx.output_or_error();
601            // Check for an error
602            match output_or_error {
603                Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
604                _ => (),
605            };
606
607            let mut retry_actions = self.retry_actions.lock().unwrap();
608            if retry_actions.len() == 1 {
609                retry_actions.first().unwrap().clone()
610            } else {
611                retry_actions.pop().unwrap()
612            }
613        }
614
615        fn name(&self) -> &'static str {
616            "Always returns a preset retry reason"
617        }
618    }
619
620    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
621    fn setup_test(
622        retry_reasons: Vec<RetryAction>,
623        retry_config: RetryConfig,
624    ) -> (ConfigBag, RuntimeComponents, InterceptorContext) {
625        let rc = RuntimeComponentsBuilder::for_tests()
626            .with_retry_classifier(SharedRetryClassifier::new(
627                PresetReasonRetryClassifier::new(retry_reasons),
628            ))
629            .build()
630            .unwrap();
631        let mut layer = Layer::new("test");
632        layer.store_put(retry_config);
633        let cfg = ConfigBag::of_layers(vec![layer]);
634        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
635        // This type doesn't matter b/c the classifier will just return whatever we tell it to.
636        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
637
638        (cfg, rc, ctx)
639    }
640
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn eventual_success() {
        // Two retryable server errors followed by a success. Each retry holds
        // quota from the bucket (5 permits per retry, per the deltas asserted
        // below); the eventual success releases the last held permit.
        let (mut cfg, rc, mut ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(5),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        // Attempt #1 fails: 1s backoff (2^0), 5 permits held (500 -> 495).
        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        // Attempt #2 fails: 2s backoff (2^1), 5 more permits held (-> 490).
        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        ctx.set_output_or_error(Ok(Output::doesnt_matter()));

        // Attempt #3 succeeds: no retry, and the held permit is released
        // back to the bucket (490 -> 495).
        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 495);
    }
673
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn no_more_attempts() {
        // With max_attempts = 3, the first two failures retry (consuming
        // quota) and the third attempt is refused for being out of attempts.
        let (mut cfg, rc, ctx) = setup_test(
            vec![RetryAction::server_error()],
            RetryConfig::standard()
                .with_use_static_exponential_base(true)
                .with_max_attempts(3),
        );
        let strategy = StandardRetryStrategy::new();
        cfg.interceptor_state().store_put(TokenBucket::default());
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        // Attempt #1 fails: 1s backoff, 5 permits held (500 -> 495).
        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 495);

        // Attempt #2 fails: 2s backoff, 5 more permits held (-> 490).
        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(2));
        assert_eq!(token_bucket.available_permits(), 490);

        // Attempt #3 fails and is out of attempts: no retry, and no quota is
        // refunded since the request still failed.
        cfg.interceptor_state().store_put(RequestAttempts::new(3));
        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        assert_eq!(no_retry, ShouldAttempt::No);
        assert_eq!(token_bucket.available_permits(), 490);
    }
704
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn successful_request_and_deser_should_be_retryable() {
        // Models a waiter: the wire call succeeds, but the *deserialized
        // output* reports the long-running operation is still in progress, so
        // the classifier asks for a retry even though the result is `Ok`.
        #[derive(Clone, Copy, Debug)]
        enum LongRunningOperationStatus {
            Running,
            Complete,
        }

        #[derive(Debug)]
        struct LongRunningOperationOutput {
            status: Option<LongRunningOperationStatus>,
        }

        impl LongRunningOperationOutput {
            fn status(&self) -> Option<LongRunningOperationStatus> {
                self.status
            }
        }

        struct WaiterRetryClassifier {}

        impl WaiterRetryClassifier {
            fn new() -> Self {
                WaiterRetryClassifier {}
            }
        }

        impl fmt::Debug for WaiterRetryClassifier {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "WaiterRetryClassifier")
            }
        }
        impl ClassifyRetry for WaiterRetryClassifier {
            // Retries whenever the `Ok` output downcasts to an operation that
            // is still `Running`; everything else indicates no action.
            fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
                let status: Option<LongRunningOperationStatus> =
                    ctx.output_or_error().and_then(|res| {
                        res.ok().and_then(|output| {
                            output
                                .downcast_ref::<LongRunningOperationOutput>()
                                .and_then(|output| output.status())
                        })
                    });

                if let Some(LongRunningOperationStatus::Running) = status {
                    return RetryAction::server_error();
                };

                RetryAction::NoActionIndicated
            }

            fn name(&self) -> &'static str {
                "waiter retry classifier"
            }
        }

        let retry_config = RetryConfig::standard()
            .with_use_static_exponential_base(true)
            .with_max_attempts(5);

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_retry_classifier(SharedRetryClassifier::new(WaiterRetryClassifier::new()))
            .build()
            .unwrap();
        let mut layer = Layer::new("test");
        layer.store_put(retry_config);
        let mut cfg = ConfigBag::of_layers(vec![layer]);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let strategy = StandardRetryStrategy::new();

        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Running),
        })));

        // A 5-permit bucket: the first retry drains it entirely.
        cfg.interceptor_state().store_put(TokenBucket::new(5));
        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();

        cfg.interceptor_state().store_put(RequestAttempts::new(1));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        let dur = should_retry.expect_delay();
        assert_eq!(dur, Duration::from_secs(1));
        assert_eq!(token_bucket.available_permits(), 0);

        // Once the operation completes, no retry is indicated and the held
        // permit is released back to the bucket.
        ctx.set_output_or_error(Ok(Output::erase(LongRunningOperationOutput {
            status: Some(LongRunningOperationStatus::Complete),
        })));
        cfg.interceptor_state().store_put(RequestAttempts::new(2));
        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
        should_retry.expect_no();
        assert_eq!(token_bucket.available_permits(), 5);
    }
796
797    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
798    #[test]
799    fn no_quota() {
800        let (mut cfg, rc, ctx) = setup_test(
801            vec![RetryAction::server_error()],
802            RetryConfig::standard()
803                .with_use_static_exponential_base(true)
804                .with_max_attempts(5),
805        );
806        let strategy = StandardRetryStrategy::new();
807        cfg.interceptor_state().store_put(TokenBucket::new(5));
808        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
809
810        cfg.interceptor_state().store_put(RequestAttempts::new(1));
811        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
812        let dur = should_retry.expect_delay();
813        assert_eq!(dur, Duration::from_secs(1));
814        assert_eq!(token_bucket.available_permits(), 0);
815
816        cfg.interceptor_state().store_put(RequestAttempts::new(2));
817        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
818        assert_eq!(no_retry, ShouldAttempt::No);
819        assert_eq!(token_bucket.available_permits(), 0);
820    }
821
822    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
823    #[test]
824    fn quota_replenishes_on_success() {
825        let (mut cfg, rc, mut ctx) = setup_test(
826            vec![
827                RetryAction::transient_error(),
828                RetryAction::retryable_error_with_explicit_delay(
829                    ErrorKind::TransientError,
830                    Duration::from_secs(1),
831                ),
832            ],
833            RetryConfig::standard()
834                .with_use_static_exponential_base(true)
835                .with_max_attempts(5),
836        );
837        let strategy = StandardRetryStrategy::new();
838        cfg.interceptor_state().store_put(TokenBucket::new(100));
839        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
840
841        cfg.interceptor_state().store_put(RequestAttempts::new(1));
842        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
843        let dur = should_retry.expect_delay();
844        assert_eq!(dur, Duration::from_secs(1));
845        assert_eq!(token_bucket.available_permits(), 90);
846
847        cfg.interceptor_state().store_put(RequestAttempts::new(2));
848        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
849        let dur = should_retry.expect_delay();
850        assert_eq!(dur, Duration::from_secs(1));
851        assert_eq!(token_bucket.available_permits(), 80);
852
853        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
854
855        cfg.interceptor_state().store_put(RequestAttempts::new(3));
856        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
857        assert_eq!(no_retry, ShouldAttempt::No);
858
859        assert_eq!(token_bucket.available_permits(), 90);
860    }
861
862    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
863    #[test]
864    fn quota_replenishes_on_first_try_success() {
865        const PERMIT_COUNT: usize = 20;
866        let (mut cfg, rc, mut ctx) = setup_test(
867            vec![RetryAction::transient_error()],
868            RetryConfig::standard()
869                .with_use_static_exponential_base(true)
870                .with_max_attempts(u32::MAX),
871        );
872        let strategy = StandardRetryStrategy::new();
873        cfg.interceptor_state()
874            .store_put(TokenBucket::new(PERMIT_COUNT));
875        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
876
877        let mut attempt = 1;
878
879        // Drain all available permits with failed attempts
880        while token_bucket.available_permits() > 0 {
881            // Draining should complete in 2 attempts
882            if attempt > 2 {
883                panic!("This test should have completed by now (drain)");
884            }
885
886            cfg.interceptor_state()
887                .store_put(RequestAttempts::new(attempt));
888            let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
889            assert!(matches!(should_retry, ShouldAttempt::YesAfterDelay(_)));
890            attempt += 1;
891        }
892
893        // Forget the permit so that we can only refill by "success on first try".
894        let permit = strategy.retry_permit.lock().unwrap().take().unwrap();
895        permit.forget();
896
897        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
898
899        // Replenish permits until we get back to `PERMIT_COUNT`
900        while token_bucket.available_permits() < PERMIT_COUNT {
901            if attempt > 23 {
902                panic!("This test should have completed by now (fill-up)");
903            }
904
905            cfg.interceptor_state()
906                .store_put(RequestAttempts::new(attempt));
907            let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
908            assert_eq!(no_retry, ShouldAttempt::No);
909            attempt += 1;
910        }
911
912        assert_eq!(attempt, 23);
913        assert_eq!(token_bucket.available_permits(), PERMIT_COUNT);
914    }
915
916    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
917    #[test]
918    fn backoff_timing() {
919        let (mut cfg, rc, ctx) = setup_test(
920            vec![RetryAction::server_error()],
921            RetryConfig::standard()
922                .with_use_static_exponential_base(true)
923                .with_max_attempts(5),
924        );
925        let strategy = StandardRetryStrategy::new();
926        cfg.interceptor_state().store_put(TokenBucket::default());
927        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
928
929        cfg.interceptor_state().store_put(RequestAttempts::new(1));
930        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
931        let dur = should_retry.expect_delay();
932        assert_eq!(dur, Duration::from_secs(1));
933        assert_eq!(token_bucket.available_permits(), 495);
934
935        cfg.interceptor_state().store_put(RequestAttempts::new(2));
936        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
937        let dur = should_retry.expect_delay();
938        assert_eq!(dur, Duration::from_secs(2));
939        assert_eq!(token_bucket.available_permits(), 490);
940
941        cfg.interceptor_state().store_put(RequestAttempts::new(3));
942        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
943        let dur = should_retry.expect_delay();
944        assert_eq!(dur, Duration::from_secs(4));
945        assert_eq!(token_bucket.available_permits(), 485);
946
947        cfg.interceptor_state().store_put(RequestAttempts::new(4));
948        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
949        let dur = should_retry.expect_delay();
950        assert_eq!(dur, Duration::from_secs(8));
951        assert_eq!(token_bucket.available_permits(), 480);
952
953        cfg.interceptor_state().store_put(RequestAttempts::new(5));
954        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
955        assert_eq!(no_retry, ShouldAttempt::No);
956        assert_eq!(token_bucket.available_permits(), 480);
957    }
958
959    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
960    #[test]
961    fn max_backoff_time() {
962        let (mut cfg, rc, ctx) = setup_test(
963            vec![RetryAction::server_error()],
964            RetryConfig::standard()
965                .with_use_static_exponential_base(true)
966                .with_max_attempts(5)
967                .with_initial_backoff(Duration::from_secs(1))
968                .with_max_backoff(Duration::from_secs(3)),
969        );
970        let strategy = StandardRetryStrategy::new();
971        cfg.interceptor_state().store_put(TokenBucket::default());
972        let token_bucket = cfg.load::<TokenBucket>().unwrap().clone();
973
974        cfg.interceptor_state().store_put(RequestAttempts::new(1));
975        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
976        let dur = should_retry.expect_delay();
977        assert_eq!(dur, Duration::from_secs(1));
978        assert_eq!(token_bucket.available_permits(), 495);
979
980        cfg.interceptor_state().store_put(RequestAttempts::new(2));
981        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
982        let dur = should_retry.expect_delay();
983        assert_eq!(dur, Duration::from_secs(2));
984        assert_eq!(token_bucket.available_permits(), 490);
985
986        cfg.interceptor_state().store_put(RequestAttempts::new(3));
987        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
988        let dur = should_retry.expect_delay();
989        assert_eq!(dur, Duration::from_secs(3));
990        assert_eq!(token_bucket.available_permits(), 485);
991
992        cfg.interceptor_state().store_put(RequestAttempts::new(4));
993        let should_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
994        let dur = should_retry.expect_delay();
995        assert_eq!(dur, Duration::from_secs(3));
996        assert_eq!(token_bucket.available_permits(), 480);
997
998        cfg.interceptor_state().store_put(RequestAttempts::new(5));
999        let no_retry = strategy.should_attempt_retry(&ctx, &rc, &cfg).unwrap();
1000        assert_eq!(no_retry, ShouldAttempt::No);
1001        assert_eq!(token_bucket.available_permits(), 480);
1002    }
1003
1004    const MAX_BACKOFF: Duration = Duration::from_secs(20);
1005
1006    #[test]
1007    fn calculate_exponential_backoff_where_initial_backoff_is_one() {
1008        let initial_backoff = 1.0;
1009
1010        for (attempt, expected_backoff) in [initial_backoff, 2.0, 4.0].into_iter().enumerate() {
1011            let actual_backoff =
1012                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1013            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1014        }
1015    }
1016
1017    #[test]
1018    fn calculate_exponential_backoff_where_initial_backoff_is_greater_than_one() {
1019        let initial_backoff = 3.0;
1020
1021        for (attempt, expected_backoff) in [initial_backoff, 6.0, 12.0].into_iter().enumerate() {
1022            let actual_backoff =
1023                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1024            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1025        }
1026    }
1027
1028    #[test]
1029    fn calculate_exponential_backoff_where_initial_backoff_is_less_than_one() {
1030        let initial_backoff = 0.03;
1031
1032        for (attempt, expected_backoff) in [initial_backoff, 0.06, 0.12].into_iter().enumerate() {
1033            let actual_backoff =
1034                calculate_exponential_backoff(1.0, initial_backoff, attempt as u32, MAX_BACKOFF);
1035            assert_eq!(Duration::from_secs_f64(expected_backoff), actual_backoff);
1036        }
1037    }
1038
1039    #[test]
1040    fn calculate_backoff_overflow_should_gracefully_fallback_to_max_backoff() {
1041        // avoid overflow for a silly large amount of retry attempts
1042        assert_eq!(
1043            MAX_BACKOFF,
1044            calculate_exponential_backoff(1_f64, 10_f64, 100000, MAX_BACKOFF),
1045        );
1046    }
1047}