aws_smithy_runtime/client/retries/token_bucket.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use aws_smithy_types::retry::ErrorKind;
8use std::sync::Arc;
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tracing::trace;
11
/// Number of permits a bucket holds (and starts with) when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;
/// Default number of permits charged for retrying a non-transient error.
const RETRY_COST: u32 = 5;
/// Default number of permits charged for retrying a transient (timeout) error: double `RETRY_COST`.
const RETRY_TIMEOUT_COST: u32 = 2 * RETRY_COST;
/// Number of permits put back into the bucket by each `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
16
17/// Token bucket used for standard and adaptive retry.
18#[derive(Clone, Debug)]
19pub struct TokenBucket {
20    semaphore: Arc<Semaphore>,
21    max_permits: usize,
22    timeout_retry_cost: u32,
23    retry_cost: u32,
24}
25
26impl Storable for TokenBucket {
27    type Storer = StoreReplace<Self>;
28}
29
30impl Default for TokenBucket {
31    fn default() -> Self {
32        Self {
33            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
34            max_permits: DEFAULT_CAPACITY,
35            timeout_retry_cost: RETRY_TIMEOUT_COST,
36            retry_cost: RETRY_COST,
37        }
38    }
39}
40
41impl TokenBucket {
42    /// Creates a new `TokenBucket` with the given initial quota.
43    pub fn new(initial_quota: usize) -> Self {
44        Self {
45            semaphore: Arc::new(Semaphore::new(initial_quota)),
46            max_permits: initial_quota,
47            ..Default::default()
48        }
49    }
50
51    /// A token bucket with unlimited capacity that allows retries at no cost.
52    pub fn unlimited() -> Self {
53        Self {
54            semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
55            max_permits: Semaphore::MAX_PERMITS,
56            timeout_retry_cost: 0,
57            retry_cost: 0,
58        }
59    }
60
61    /// Creates a builder for constructing a `TokenBucket`.
62    pub fn builder() -> TokenBucketBuilder {
63        TokenBucketBuilder::default()
64    }
65
66    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
67        let retry_cost = if err == &ErrorKind::TransientError {
68            self.timeout_retry_cost
69        } else {
70            self.retry_cost
71        };
72
73        self.semaphore
74            .clone()
75            .try_acquire_many_owned(retry_cost)
76            .ok()
77    }
78
79    pub(crate) fn regenerate_a_token(&self) {
80        if self.semaphore.available_permits() < self.max_permits {
81            trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
82            self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
83        }
84    }
85
86    #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
87    pub(crate) fn available_permits(&self) -> usize {
88        self.semaphore.available_permits()
89    }
90
91    /// Returns true if the token bucket is full, false otherwise
92    pub fn is_full(&self) -> bool {
93        self.semaphore.available_permits() >= self.max_permits
94    }
95
96    /// Returns true if the token bucket is empty, false otherwise
97    pub fn is_empty(&self) -> bool {
98        self.semaphore.available_permits() == 0
99    }
100}
101
/// Builder for constructing a `TokenBucket`.
///
/// Every setting is optional; `build` substitutes a default for each one left unset.
#[derive(Debug, Clone, Default)]
pub struct TokenBucketBuilder {
    /// Bucket capacity, which is also its starting permit count, when set.
    capacity: Option<usize>,
    /// Permits charged for retrying a non-transient error, when set.
    retry_cost: Option<u32>,
    /// Permits charged for retrying an `ErrorKind::TransientError`, when set.
    timeout_retry_cost: Option<u32>,
}
109
110impl TokenBucketBuilder {
111    /// Creates a new `TokenBucketBuilder` with default values.
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Sets the maximum bucket capacity for the builder.
117    pub fn capacity(mut self, capacity: usize) -> Self {
118        self.capacity = Some(capacity);
119        self
120    }
121
122    /// Sets the specified retry cost for the builder.
123    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
124        self.retry_cost = Some(retry_cost);
125        self
126    }
127
128    /// Sets the specified timeout retry cost for the builder.
129    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
130        self.timeout_retry_cost = Some(timeout_retry_cost);
131        self
132    }
133
134    /// Builds a `TokenBucket`.
135    pub fn build(self) -> TokenBucket {
136        TokenBucket {
137            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
138            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
139            retry_cost: self.retry_cost.unwrap_or(RETRY_COST),
140            timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST),
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Acquisition succeeds whatever the error kind, since every retry is free.
        for kind in [ErrorKind::ThrottlingError, ErrorKind::TransientError] {
            assert!(bucket.acquire(&kind).is_some());
        }

        // Capacity is the semaphore maximum and both costs are zero.
        assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS);
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // The iteration count is arbitrary: even with every permit kept alive, zero-cost
        // acquisitions never reduce the number of available permits.
        let _held: Vec<_> = (0..100)
            .map(|_| {
                let permit = bucket.acquire(&ErrorKind::ThrottlingError);
                assert!(permit.is_some());
                assert_eq!(
                    tokio::sync::Semaphore::MAX_PERMITS,
                    bucket.semaphore.available_permits()
                );
                permit
            })
            .collect();
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);

        // Keep acquiring (at most 100 attempts) until the bucket first refuses.
        let permits: Vec<_> = std::iter::from_fn(|| bucket.acquire(&ErrorKind::ThrottlingError))
            .take(100)
            .collect();

        // 10 capacity / 5 retry cost = 2 permits
        assert_eq!(permits.len(), 2);

        // With both permits still held, one more acquisition must fail.
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }
}