aws_smithy_runtime/client/retries/
token_bucket.rs1use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use aws_smithy_types::retry::ErrorKind;
8use std::sync::Arc;
9use tokio::sync::{OwnedSemaphorePermit, Semaphore};
10use tracing::trace;
11
/// Tokens a bucket starts with (and refills up to) when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;
/// Tokens charged for retrying an ordinary (non-transient) error.
const RETRY_COST: u32 = 5;
/// Tokens charged for retrying a transient error: double the ordinary cost.
const RETRY_TIMEOUT_COST: u32 = 2 * RETRY_COST;
/// Tokens handed back to the bucket by each `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
16
17#[derive(Clone, Debug)]
19pub struct TokenBucket {
20 semaphore: Arc<Semaphore>,
21 max_permits: usize,
22 timeout_retry_cost: u32,
23 retry_cost: u32,
24}
25
26impl Storable for TokenBucket {
27 type Storer = StoreReplace<Self>;
28}
29
30impl Default for TokenBucket {
31 fn default() -> Self {
32 Self {
33 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
34 max_permits: DEFAULT_CAPACITY,
35 timeout_retry_cost: RETRY_TIMEOUT_COST,
36 retry_cost: RETRY_COST,
37 }
38 }
39}
40
41impl TokenBucket {
42 pub fn new(initial_quota: usize) -> Self {
44 Self {
45 semaphore: Arc::new(Semaphore::new(initial_quota)),
46 max_permits: initial_quota,
47 ..Default::default()
48 }
49 }
50
51 pub fn unlimited() -> Self {
53 Self {
54 semaphore: Arc::new(Semaphore::new(Semaphore::MAX_PERMITS)),
55 max_permits: Semaphore::MAX_PERMITS,
56 timeout_retry_cost: 0,
57 retry_cost: 0,
58 }
59 }
60
61 pub fn builder() -> TokenBucketBuilder {
63 TokenBucketBuilder::default()
64 }
65
66 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
67 let retry_cost = if err == &ErrorKind::TransientError {
68 self.timeout_retry_cost
69 } else {
70 self.retry_cost
71 };
72
73 self.semaphore
74 .clone()
75 .try_acquire_many_owned(retry_cost)
76 .ok()
77 }
78
79 pub(crate) fn regenerate_a_token(&self) {
80 if self.semaphore.available_permits() < self.max_permits {
81 trace!("adding {PERMIT_REGENERATION_AMOUNT} back into the bucket");
82 self.semaphore.add_permits(PERMIT_REGENERATION_AMOUNT)
83 }
84 }
85
86 #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
87 pub(crate) fn available_permits(&self) -> usize {
88 self.semaphore.available_permits()
89 }
90}
91
/// Builder for [`TokenBucket`]; any setting left unset falls back to the module default.
#[derive(Debug, Clone, Default)]
pub struct TokenBucketBuilder {
    /// Starting (and refill-ceiling) token count; `DEFAULT_CAPACITY` when unset.
    capacity: Option<usize>,
    /// Cost of retrying a non-transient error; `RETRY_COST` when unset.
    retry_cost: Option<u32>,
    /// Cost of retrying a transient error; `RETRY_TIMEOUT_COST` when unset.
    timeout_retry_cost: Option<u32>,
}
99
100impl TokenBucketBuilder {
101 pub fn new() -> Self {
103 Self::default()
104 }
105
106 pub fn capacity(mut self, capacity: usize) -> Self {
108 self.capacity = Some(capacity);
109 self
110 }
111
112 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
114 self.retry_cost = Some(retry_cost);
115 self
116 }
117
118 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
120 self.timeout_retry_cost = Some(timeout_retry_cost);
121 self
122 }
123
124 pub fn build(self) -> TokenBucket {
126 TokenBucket {
127 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
128 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
129 retry_cost: self.retry_cost.unwrap_or(RETRY_COST),
130 timeout_retry_cost: self.timeout_retry_cost.unwrap_or(RETRY_TIMEOUT_COST),
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Zero-cost acquisitions succeed regardless of the error kind.
        for kind in [ErrorKind::ThrottlingError, ErrorKind::TransientError] {
            assert!(bucket.acquire(&kind).is_some());
        }

        assert_eq!(bucket.max_permits, Semaphore::MAX_PERMITS);
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Holding many permits at once never drains an unlimited bucket.
        let _held: Vec<_> = (0..100)
            .map(|_| {
                let permit = bucket.acquire(&ErrorKind::ThrottlingError);
                assert!(permit.is_some());
                assert_eq!(
                    tokio::sync::Semaphore::MAX_PERMITS,
                    bucket.semaphore.available_permits()
                );
                permit
            })
            .collect();
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);

        // Keep every successful permit alive so the bucket actually drains; stop at the
        // first refusal.
        let held: Vec<_> = (0..100)
            .map_while(|_| bucket.acquire(&ErrorKind::ThrottlingError))
            .collect();

        // 10 tokens at 5 per ordinary retry leaves room for exactly two retries.
        assert_eq!(held.len(), 2);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }
}