aws_smithy_runtime/client/retries/
token_bucket.rs1use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use aws_smithy_types::retry::ErrorKind;
8use std::sync::Arc;
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tracing::trace;
11
/// Number of permits a bucket starts with (and refills towards) when no capacity is given;
/// used by `TokenBucket::default()` and as the fallback in `TokenBucketBuilder::build`.
const DEFAULT_CAPACITY: usize = 500;
/// Default number of permits `TokenBucket::acquire` requests for a retry whose error kind is
/// anything other than `ErrorKind::TransientError`.
const RETRY_COST: u32 = 5;
/// Default number of permits `TokenBucket::acquire` requests for a retry of an
/// `ErrorKind::TransientError`; twice `RETRY_COST`.
const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2;
/// Number of permits a single `TokenBucket::regenerate_a_token` call adds back.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
16
/// A pool of retry permits backed by a tokio [`Semaphore`].
///
/// `acquire` withdraws permits (the amount depends on the error kind) and
/// `regenerate_a_token` adds permits back while the pool is below `max_permits`.
/// Because the semaphore is held in an [`Arc`], clones of a `TokenBucket` share the same
/// underlying permit pool.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// The permits currently available for retries.
    semaphore: Arc<Semaphore>,
    /// Level below which `regenerate_a_token` adds permits; also the threshold for `is_full`.
    max_permits: usize,
    /// Permits requested by `acquire` for `ErrorKind::TransientError`.
    timeout_retry_cost: u32,
    /// Permits requested by `acquire` for every other error kind.
    retry_cost: u32,
}
25
// Registers `TokenBucket` as a config-bag item. `StoreReplace` is selected as the storer;
// NOTE(review): presumably this means a newly stored bucket replaces any earlier one —
// confirm against the `aws_smithy_types::config_bag` documentation.
impl Storable for TokenBucket {
    type Storer = StoreReplace<Self>;
}
29
30impl Default for TokenBucket {
31 fn default() -> Self {
32 Self {
33 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
34 max_permits: DEFAULT_CAPACITY,
35 timeout_retry_cost: RETRY_TIMEOUT_COST,
36 retry_cost: RETRY_COST,
37 }
38 }
39}
40
41impl TokenBucket {
42 pub fn new(initial_quota: usize) -> Self {
44 Self {
45 semaphore: Arc::new(Semaphore::new(initial_quota)),
46 max_permits: initial_quota,
47 ..Default::default()
48 }
49 }
50
51 pub fn unlimited() -> Self {
53 Self {
54 semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
55 max_permits: Semaphore::MAX_PERMITS,
56 timeout_retry_cost: 0,
57 retry_cost: 0,
58 }
59 }
60
61 pub fn builder() -> TokenBucketBuilder {
63 TokenBucketBuilder::default()
64 }
65
66 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
67 let retry_cost = if err == &ErrorKind::TransientError {
68 self.timeout_retry_cost
69 } else {
70 self.retry_cost
71 };
72
73 self.semaphore
74 .clone()
75 .try_acquire_many_owned(retry_cost)
76 .ok()
77 }
78
79 pub(crate) fn regenerate_a_token(&self) {
80 if self.semaphore.available_permits() < self.max_permits {
81 trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
82 self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
83 }
84 }
85
86 #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
87 pub(crate) fn available_permits(&self) -> usize {
88 self.semaphore.available_permits()
89 }
90
91 pub fn is_full(&self) -> bool {
93 self.semaphore.available_permits() >= self.max_permits
94 }
95
96 pub fn is_empty(&self) -> bool {
98 self.semaphore.available_permits() == 0
99 }
100}
101
/// Builder for [`TokenBucket`]; any option left unset falls back to a module default in
/// `build`.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Initial and maximum permit count; `DEFAULT_CAPACITY` when `None`.
    capacity: Option<usize>,
    /// Permits charged per retry; `RETRY_COST` when `None`.
    retry_cost: Option<u32>,
    /// Permits charged per retry of a transient error; `RETRY_TIMEOUT_COST` when `None`.
    timeout_retry_cost: Option<u32>,
}
109
110impl TokenBucketBuilder {
111 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn capacity(mut self, capacity: usize) -> Self {
118 self.capacity = Some(capacity);
119 self
120 }
121
122 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
124 self.retry_cost = Some(retry_cost);
125 self
126 }
127
128 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
130 self.timeout_retry_cost = Some(timeout_retry_cost);
131 self
132 }
133
134 pub fn build(self) -> TokenBucket {
136 TokenBucket {
137 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
138 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
139 retry_cost: self.retry_cost.unwrap_or(RETRY_COST),
140 timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST),
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Either error kind can acquire from an unlimited bucket.
        for kind in [ErrorKind::ThrottlingError, ErrorKind::TransientError] {
            assert!(bucket.acquire(&kind).is_some());
        }

        // The bucket is sized to the semaphore's maximum.
        assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS);

        // Retries are free.
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Holding on to many permits at once never lowers the available count.
        let mut held = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            held.push(permit);
            assert_eq!(
                tokio::sync::Semaphore::MAX_PERMITS,
                bucket.semaphore.available_permits()
            );
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);

        // Keep acquiring (and holding) permits until the bucket refuses.
        let permits: Vec<_> = (0..100)
            .map_while(|_| bucket.acquire(&ErrorKind::ThrottlingError))
            .collect();

        // 10 permits at the default cost of 5 per retry allows exactly two acquisitions.
        assert_eq!(permits.len(), 2);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }
}