aws_smithy_runtime/client/retries/token_bucket.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use aws_smithy_types::retry::ErrorKind;
8use std::sync::Arc;
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tracing::trace;
11
/// Number of permits a default bucket starts with; also the ceiling it refills up to.
const DEFAULT_CAPACITY: usize = 500;
/// Default number of permits charged for retrying an error that is not a transient error.
const RETRY_COST: u32 = 5;
/// Default number of permits charged for retrying a transient error (double the regular cost).
const RETRY_TIMEOUT_COST: u32 = 2 * RETRY_COST;
/// Number of permits restored by a single call to `TokenBucket::regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
16
17/// Token bucket used for standard and adaptive retry.
18#[derive(Clone, Debug)]
19pub struct TokenBucket {
20    semaphore: Arc<Semaphore>,
21    max_permits: usize,
22    timeout_retry_cost: u32,
23    retry_cost: u32,
24}
25
26impl Storable for TokenBucket {
27    type Storer = StoreReplace<Self>;
28}
29
30impl Default for TokenBucket {
31    fn default() -> Self {
32        Self {
33            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
34            max_permits: DEFAULT_CAPACITY,
35            timeout_retry_cost: RETRY_TIMEOUT_COST,
36            retry_cost: RETRY_COST,
37        }
38    }
39}
40
41impl TokenBucket {
42    /// Creates a new `TokenBucket` with the given initial quota.
43    pub fn new(initial_quota: usize) -> Self {
44        Self {
45            semaphore: Arc::new(Semaphore::new(initial_quota)),
46            max_permits: initial_quota,
47            ..Default::default()
48        }
49    }
50
51    /// A token bucket with unlimited capacity that allows retries at no cost.
52    pub fn unlimited() -> Self {
53        Self {
54            semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
55            max_permits: Semaphore::MAX_PERMITS,
56            timeout_retry_cost: 0,
57            retry_cost: 0,
58        }
59    }
60
61    /// Creates a builder for constructing a `TokenBucket`.
62    pub fn builder() -> TokenBucketBuilder {
63        TokenBucketBuilder::default()
64    }
65
66    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
67        let retry_cost = if err == &ErrorKind::TransientError {
68            self.timeout_retry_cost
69        } else {
70            self.retry_cost
71        };
72
73        self.semaphore
74            .clone()
75            .try_acquire_many_owned(retry_cost)
76            .ok()
77    }
78
79    pub(crate) fn regenerate_a_token(&self) {
80        if self.semaphore.available_permits() < self.max_permits {
81            trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
82            self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
83        }
84    }
85
86    #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
87    pub(crate) fn available_permits(&self) -> usize {
88        self.semaphore.available_permits()
89    }
90}
91
/// Builder for constructing a `TokenBucket`.
///
/// Every option starts unset (`None`); unset options fall back to the bucket defaults when the
/// bucket is built.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Total permits: both the starting amount and the refill ceiling.
    capacity: Option<usize>,
    /// Permits charged for retrying an error that is not a transient error.
    retry_cost: Option<u32>,
    /// Permits charged for retrying a transient error.
    timeout_retry_cost: Option<u32>,
}
99
100impl TokenBucketBuilder {
101    /// Creates a new `TokenBucketBuilder` with default values.
102    pub fn new() -> Self {
103        Self::default()
104    }
105
106    /// Sets the maximum bucket capacity for the builder.
107    pub fn capacity(mut self, capacity: usize) -> Self {
108        self.capacity = Some(capacity);
109        self
110    }
111
112    /// Sets the specified retry cost for the builder.
113    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
114        self.retry_cost = Some(retry_cost);
115        self
116    }
117
118    /// Sets the specified timeout retry cost for the builder.
119    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
120        self.timeout_retry_cost = Some(timeout_retry_cost);
121        self
122    }
123
124    /// Builds a `TokenBucket`.
125    pub fn build(self) -> TokenBucket {
126        TokenBucket {
127            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
128            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
129            retry_cost: self.retry_cost.unwrap_or(RETRY_COST),
130            timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST),
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Should always acquire permits regardless of error type
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        // Should have maximum capacity
        assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS);

        // Should have zero retry costs
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // The loop count is arbitrary; should obtain permits without limit
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
            // Available permits should stay constant
            assert_eq!(
                tokio::sync::Semaphore::MAX_PERMITS,
                bucket.semaphore.available_permits()
            );
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits

        // Verify next acquisition fails
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }

    #[test]
    fn test_transient_errors_use_timeout_retry_cost() {
        let bucket = TokenBucket::new(10);

        // A transient error is charged the timeout cost (10), draining the bucket in one go.
        let permit = bucket.acquire(&ErrorKind::TransientError);
        assert!(permit.is_some());
        assert_eq!(bucket.semaphore.available_permits(), 0);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());

        // Dropping the permit hands its permits back to the bucket.
        drop(permit);
        assert_eq!(bucket.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_builder_defaults_match_default_bucket() {
        let built = TokenBucket::builder().build();
        let reference = TokenBucket::default();

        assert_eq!(built.max_permits, reference.max_permits);
        assert_eq!(built.retry_cost, reference.retry_cost);
        assert_eq!(built.timeout_retry_cost, reference.timeout_retry_cost);
        assert_eq!(
            built.semaphore.available_permits(),
            reference.semaphore.available_permits()
        );
    }

    #[test]
    fn test_builder_overrides() {
        let bucket = TokenBucket::builder()
            .capacity(20)
            .retry_cost(3)
            .timeout_retry_cost(7)
            .build();

        assert_eq!(bucket.max_permits, 20);
        assert_eq!(bucket.semaphore.available_permits(), 20);
        assert_eq!(bucket.retry_cost, 3);
        assert_eq!(bucket.timeout_retry_cost, 7);
    }

    #[test]
    fn test_regenerate_a_token_stops_at_max_permits() {
        let bucket = TokenBucket::new(10);

        // A full bucket is not topped up any further.
        bucket.regenerate_a_token();
        assert_eq!(bucket.semaphore.available_permits(), 10);

        // Consume permits permanently so that only regeneration can restore them.
        bucket
            .acquire(&ErrorKind::ThrottlingError)
            .expect("a full bucket has room for one retry")
            .forget();
        assert_eq!(bucket.semaphore.available_permits(), 5);

        for expected in 6..=10 {
            bucket.regenerate_a_token();
            assert_eq!(bucket.semaphore.available_permits(), expected);
        }

        // Back at capacity: further regeneration is a no-op.
        bucket.regenerate_a_token();
        assert_eq!(bucket.semaphore.available_permits(), 10);
    }
}