1use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Number of permits a bucket holds when no capacity is configured.
pub(crate) const DEFAULT_CAPACITY: usize = 500;
/// Largest capacity a bucket may be given; `TokenBucketBuilder::capacity` clamps larger
/// requests to this value, and `TokenBucket::unlimited` uses it as its size.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;

// The non-legacy costs below are not referenced by the defaults visible in this file (those use
// the `LEGACY_*` costs), so `dead_code` is allowed to keep them compiling warning-free.
/// Non-legacy permit cost for retrying an ordinary error.
#[allow(dead_code)]
pub(crate) const DEFAULT_RETRY_COST: u32 = 14;
/// Non-legacy permit cost for retrying a timeout/transient error.
#[allow(dead_code)]
pub(crate) const DEFAULT_RETRY_TIMEOUT_COST: u32 = 14;
/// Non-legacy permit cost for retrying a throttling error.
#[allow(dead_code)]
pub(crate) const THROTTLING_RETRY_COST: u32 = 5;

/// Permit cost per retry used by the legacy defaults (`Default` and unset builder fields).
const LEGACY_RETRY_COST: u32 = 5;
/// Legacy permit cost for transient errors: double the ordinary legacy cost.
const LEGACY_RETRY_TIMEOUT_COST: u32 = 2 * LEGACY_RETRY_COST;
/// Permits handed back by a single `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Success reward used when none is configured; `0.0` disables rewards.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
36
/// Retry token bucket: each retry spends permits from a semaphore, and permits come back through
/// success rewards, explicit regeneration, and an optional time-based refill.
///
/// Clones share state: the semaphore, the fractional-token counter and the last-refill timestamp
/// are all held behind `Arc`s, so `#[derive(Clone)]` only bumps reference counts.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Currently available permits.
    semaphore: Arc<Semaphore>,
    /// Upper bound on permits that `add_permits` will top the semaphore up to; also the cap on
    /// accumulated fractional tokens.
    max_permits: usize,
    /// Permits charged to retry an `ErrorKind::TransientError`.
    timeout_retry_cost: u32,
    /// Permits charged to retry any error kind without a dedicated cost.
    retry_cost: u32,
    /// Permits charged to retry an `ErrorKind::ThrottlingError`.
    throttling_retry_cost: u32,
    /// Fractional tokens credited per successful request; values `<= 0.0` disable the reward.
    success_reward: f32,
    /// Partial tokens accumulated from rewards and time-based refill; whole tokens are moved into
    /// the semaphore by `convert_fractional_tokens`.
    fractional_tokens: Arc<AtomicF32>,
    /// Tokens added per elapsed second; values `<= 0.0` disable time-based refill.
    refill_rate: f32,
    /// Whole seconds since the UNIX epoch (truncated to `u32`) at which the last refill window
    /// was claimed.
    last_refill_time_secs: Arc<AtomicU32>,
}
52
// Explicitly mark `AtomicF32` as unwind-safe.
// NOTE(review): presumably redundant with the auto-trait impls inherited from `AtomicU32`;
// confirm before removing.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

/// An `f32` shareable between threads, kept as its raw IEEE-754 bit pattern in an
/// [`AtomicU32`].
///
/// Every access uses `Ordering::Relaxed`: operations are atomic for this value alone and impose
/// no ordering on other memory.
struct AtomicF32 {
    storage: AtomicU32,
}

impl AtomicF32 {
    /// Creates an atomic initialised to `value`.
    fn new(value: f32) -> Self {
        AtomicF32 {
            storage: value.to_bits().into(),
        }
    }

    /// Overwrites the stored value with `value`.
    fn store(&self, value: f32) {
        self.storage.store(value.to_bits(), Ordering::Relaxed);
    }

    /// Reads the stored value.
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}

impl fmt::Debug for AtomicF32 {
    /// Renders as `AtomicF32 { value: <current value> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self.load();
        let mut out = f.debug_struct("AtomicF32");
        out.field("value", &current);
        out.finish()
    }
}

impl Clone for AtomicF32 {
    /// Produces an independent atomic seeded with a snapshot of the current bits; later stores
    /// to either copy are not visible in the other.
    fn clone(&self) -> Self {
        let snapshot_bits = self.storage.load(Ordering::Relaxed);
        Self {
            storage: snapshot_bits.into(),
        }
    }
}
92
93impl Storable for TokenBucket {
94 type Storer = StoreReplace<Self>;
95}
96
97impl Default for TokenBucket {
98 fn default() -> Self {
99 Self {
100 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
101 max_permits: DEFAULT_CAPACITY,
102 timeout_retry_cost: LEGACY_RETRY_TIMEOUT_COST,
103 retry_cost: LEGACY_RETRY_COST,
104 throttling_retry_cost: LEGACY_RETRY_COST,
105 success_reward: DEFAULT_SUCCESS_REWARD,
106 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
107 refill_rate: 0.0,
108 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
109 }
110 }
111}
112
113impl TokenBucket {
114 pub fn new(initial_quota: usize) -> Self {
116 Self {
117 semaphore: Arc::new(Semaphore::new(initial_quota)),
118 max_permits: initial_quota,
119 ..Default::default()
120 }
121 }
122
123 pub fn unlimited() -> Self {
125 Self {
126 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
127 max_permits: MAXIMUM_CAPACITY,
128 timeout_retry_cost: 0,
129 retry_cost: 0,
130 throttling_retry_cost: 0,
131 success_reward: 0.0,
132 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
133 refill_rate: 0.0,
134 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
135 }
136 }
137
138 pub fn builder() -> TokenBucketBuilder {
140 TokenBucketBuilder::default()
141 }
142
143 pub(crate) fn acquire(
144 &self,
145 err: &ErrorKind,
146 time_source: &impl TimeSource,
147 ) -> Option<OwnedSemaphorePermit> {
148 self.refill_tokens_based_on_time(time_source);
150 self.convert_fractional_tokens();
152
153 let retry_cost = match err {
154 ErrorKind::TransientError => self.timeout_retry_cost,
155 ErrorKind::ThrottlingError => self.throttling_retry_cost,
156 _ => self.retry_cost,
157 };
158
159 self.semaphore
160 .clone()
161 .try_acquire_many_owned(retry_cost)
162 .ok()
163 }
164
165 pub(crate) fn success_reward(&self) -> f32 {
166 self.success_reward
167 }
168
169 pub(crate) fn regenerate_a_token(&self) {
170 self.add_permits(PERMIT_REGENERATION_AMOUNT);
171 }
172
173 #[inline]
177 fn convert_fractional_tokens(&self) {
178 let mut calc_fractional_tokens = self.fractional_tokens.load();
179 if !calc_fractional_tokens.is_finite() {
181 tracing::error!(
182 "Fractional tokens corrupted to: {}, resetting to 0.0",
183 calc_fractional_tokens
184 );
185 self.fractional_tokens.store(0.0);
186 return;
187 }
188
189 let full_tokens_accumulated = calc_fractional_tokens.floor();
190 if full_tokens_accumulated >= 1.0 {
191 self.add_permits(full_tokens_accumulated as usize);
192 calc_fractional_tokens -= full_tokens_accumulated;
193 }
194 self.fractional_tokens.store(calc_fractional_tokens);
196 }
197
198 #[inline]
202 fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
203 if self.refill_rate > 0.0 {
204 let current_time_secs = time_source
206 .now()
207 .duration_since(SystemTime::UNIX_EPOCH)
208 .unwrap_or(Duration::ZERO)
209 .as_secs() as u32;
210
211 let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
212
213 if current_time_secs == last_refill_secs {
215 return;
216 }
217
218 if self
221 .last_refill_time_secs
222 .compare_exchange(
223 last_refill_secs,
224 current_time_secs,
225 Ordering::Relaxed,
226 Ordering::Relaxed,
227 )
228 .is_err()
229 {
230 return;
232 }
233
234 let current_fractional = self.fractional_tokens.load();
236 let max_fractional = self.max_permits as f32;
237
238 if current_fractional >= max_fractional {
240 return;
241 }
242
243 let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
244 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
245
246 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
248 self.fractional_tokens.store(new_fractional);
249 }
250 }
251
252 #[inline]
253 pub(crate) fn reward_success(&self) {
254 if self.success_reward > 0.0 {
255 let current = self.fractional_tokens.load();
256 let max_fractional = self.max_permits as f32;
257 if current >= max_fractional {
259 return;
260 }
261 let new_fractional = (current + self.success_reward).min(max_fractional);
263 self.fractional_tokens.store(new_fractional);
264 }
265 }
266
267 pub(crate) fn add_permits(&self, amount: usize) {
268 let available = self.semaphore.available_permits();
269 if available >= self.max_permits {
270 return;
271 }
272 self.semaphore
273 .add_permits(amount.min(self.max_permits - available));
274 }
275
276 pub fn is_full(&self) -> bool {
278 self.convert_fractional_tokens();
279 self.semaphore.available_permits() >= self.max_permits
280 }
281
282 pub fn is_empty(&self) -> bool {
284 self.convert_fractional_tokens();
285 self.semaphore.available_permits() == 0
286 }
287
288 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
290 pub(crate) fn available_permits(&self) -> usize {
291 self.semaphore.available_permits()
292 }
293
294 #[allow(dead_code)]
296 #[doc(hidden)]
297 #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
298 pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
299 self.last_refill_time_secs.clone()
300 }
301}
302
/// Builder for [`TokenBucket`]. Each field left as `None` falls back to a default in `build`.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Initial and maximum number of permits (already clamped to `MAXIMUM_CAPACITY`).
    capacity: Option<usize>,
    /// Permit cost for retrying error kinds without a dedicated cost.
    retry_cost: Option<u32>,
    /// Permit cost for retrying throttling errors.
    throttling_retry_cost: Option<u32>,
    /// Permit cost for retrying transient errors.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    /// Tokens per second for time-based refill (already validated to be finite and `>= 0.0`).
    refill_rate: Option<f32>,
}
313
314impl TokenBucketBuilder {
315 pub fn new() -> Self {
317 Self::default()
318 }
319
320 pub fn capacity(mut self, mut capacity: usize) -> Self {
322 if capacity > MAXIMUM_CAPACITY {
323 capacity = MAXIMUM_CAPACITY;
324 }
325 self.capacity = Some(capacity);
326 self
327 }
328
329 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
331 self.retry_cost = Some(retry_cost);
332 self
333 }
334
335 pub fn throttling_retry_cost(mut self, throttling_retry_cost: u32) -> Self {
337 self.throttling_retry_cost = Some(throttling_retry_cost);
338 self
339 }
340
341 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
343 self.timeout_retry_cost = Some(timeout_retry_cost);
344 self
345 }
346
347 pub fn success_reward(mut self, reward: f32) -> Self {
349 self.success_reward = Some(reward);
350 self
351 }
352
353 pub fn refill_rate(mut self, rate: f32) -> Self {
358 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
359 self.refill_rate = Some(validated_rate);
360 self
361 }
362
363 pub fn build(self) -> TokenBucket {
365 TokenBucket {
366 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
367 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
368 retry_cost: self.retry_cost.unwrap_or(LEGACY_RETRY_COST),
369 throttling_retry_cost: self.throttling_retry_cost.unwrap_or(LEGACY_RETRY_COST),
370 timeout_retry_cost: self.timeout_retry_cost.unwrap_or(LEGACY_RETRY_TIMEOUT_COST),
371 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
372 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
373 refill_rate: self.refill_rate.unwrap_or(0.0),
374 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_async::test_util::ManualTimeSource;
    use std::{sync::LazyLock, time::UNIX_EPOCH};

    /// Time source shared by tests that do not exercise time-based refill; no test advances it.
    static TIME_SOURCE: LazyLock<ManualTimeSource> =
        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_some());
        assert!(bucket
            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
            .is_some());

        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);
        assert_eq!(bucket.throttling_retry_cost, 0);

        // Holding any number of zero-cost permits never reduces the available count.
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
            assert!(permit.is_some());
            permits.push(permit);
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        // 10 permits at the legacy throttling cost of 5 allow exactly two retries.
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            match bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE) {
                Some(p) => permits.push(p),
                None => break,
            }
        }

        assert_eq!(permits.len(), 2);
        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_none());
    }

    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // A transient-error retry costs all 10 permits, leaving the bucket empty.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // 0.4, then 0.8 accumulated: not yet a whole token.
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // 1.2 accumulated: one whole permit is released.
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        // Drain the bucket first; a full bucket would make the assertions below trivially true.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // 20 rewards of 2.0 would be 40 tokens, but accumulation stops at max_permits.
        for _ in 0..20 {
            bucket.reward_success();
        }
        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_convert_fractional_tokens() {
        // (stored fractional tokens, expected permits added, expected remainder)
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            // Corrupted (non-finite) values are reset to zero without adding permits.
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    #[test]
    fn test_builder_refill_rate_validation() {
        // Negative and non-finite rates are sanitized to 0.0 (refill disabled).
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(f32::NAN).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(f32::INFINITY).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .build();

        // Empty the bucket so refilled tokens become observable.
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        manual_time.advance(Duration::from_secs(5));

        bucket.refill_tokens_based_on_time(&manual_time);
        bucket.convert_fractional_tokens();

        // 5 seconds at 1 token/second.
        assert_eq!(bucket.available_permits(), 5);
    }

    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = [
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];

        for &expected in &test_values {
            let actual = AtomicF32::new(expected).load();
            if expected.is_nan() {
                // NaN != NaN, so check NaN-ness explicitly in addition to the bit pattern.
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
            }
            assert_eq!(expected.to_bits(), actual.to_bits());
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        let critical_bit_patterns = [
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // quiet NaN
            0x7FA00000u32, // NaN with a non-default payload
            0x00000001u32, // smallest subnormal
            0x007FFFFFu32, // largest subnormal
            0x00800000u32, // smallest normal
        ];

        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let actual_bits = atomic.load().to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                readings.push(atomic_reader.load());
            }
            readings
        });

        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        for &reading in &readings {
            // Every observed value must be one that was actually written, or the initial 0.0;
            // a torn write would show up as some other bit pattern.
            assert!(
                test_values.contains(&reading) || reading == 0.0,
                "Corrupted reading detected: {reading}"
            );
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        // Release all ten writers at once to maximize contention.
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();

        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                for _ in 0..10000 {
                    atomic_clone.store(i as f32);
                    let loaded = atomic_clone.load();
                    assert!((0.0..=9.0).contains(&loaded));
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        // Accumulate rewards the way the bucket does.
        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;

        // Split off whole tokens, keeping the fractional remainder.
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        assert_eq!(full_tokens, expected_total.floor());
        assert!((0.0..1.0).contains(&remaining));
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        assert_eq!(original.load(), cloned.load());

        original.store(999.0);

        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        // Two rewards of 0.5 contribute one token.
        bucket.reward_success();
        bucket.reward_success();

        // Two seconds at 1 token/second contribute two more.
        time_source.advance(Duration::from_secs(2));

        bucket.refill_tokens_based_on_time(&time_source);
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
        let test_cases = [
            (10.0, 2, 20, 0.0),
            (0.001, 1100, 1, 0.1),
            (0.0001, 11000, 1, 0.1),
            (0.001, 1200, 1, 0.2),
            (0.0001, 10000, 1, 0.0),
            (0.001, 500, 0, 0.5),
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let current_time_secs = UNIX_EPOCH
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs() as u32;

            let bucket = TokenBucket {
                refill_rate,
                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time(&time_source);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        // 50 rewards of 2.0 would be 100 tokens; accumulation stops at max_permits.
        for _ in 0..50 {
            bucket.reward_success();
        }

        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        // 100 seconds at 50 tokens/second must not push past the cap either.
        time_source.advance(Duration::from_secs(100));

        bucket.refill_tokens_based_on_time(&time_source);

        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        bucket.convert_fractional_tokens();
        // 5 initial permits topped up to max_permits; the excess is discarded.
        assert_eq!(bucket.available_permits(), 10);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        time_source.advance(Duration::from_secs(10));
        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);

        // Only the first 100 of the 200 threads rendezvous at the barrier together; the
        // remaining threads form later barrier generations.
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let time_source_clone1 = shared_time_source.clone();
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);
            let time_source_clone2 = shared_time_source.clone();

            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();
                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
            });

            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();
                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        bucket.convert_fractional_tokens();

        // 10 seconds at 1 token/second, credited exactly once.
        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );

        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    #[test]
    fn test_is_full_accounts_for_fractional_tokens() {
        let bucket = TokenBucket::builder()
            .capacity(2)
            .retry_cost(1)
            .success_reward(0.9)
            .build();

        assert!(bucket.is_full());

        let _p1 = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();
        let _p2 = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();

        assert!(bucket.is_empty());

        // 2.7 fractional tokens: is_full must convert them before checking.
        bucket.reward_success();
        bucket.reward_success();
        bucket.reward_success();

        assert!(bucket.is_full());
        assert!(!bucket.is_empty());
    }

    #[test]
    fn test_is_empty_accounts_for_fractional_tokens() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .retry_cost(10)
            .success_reward(0.5)
            .build();

        let _p = bucket
            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
            .unwrap();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // 0.5 tokens: still empty.
        bucket.reward_success();
        assert!(bucket.is_empty());

        // 1.0 tokens: is_empty converts it into a permit.
        bucket.reward_success();
        assert!(!bucket.is_empty());
    }
}