1use aws_smithy_async::time::SharedTimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Permits a bucket starts with (and is capped at) when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;

/// Largest capacity accepted by [`TokenBucketBuilder::capacity`]; larger requests are
/// clamped down to this value. Also the size of an unlimited bucket.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;

/// Permits consumed by a retry of any error other than `ErrorKind::TransientError`.
const DEFAULT_RETRY_COST: u32 = 5;

/// Permits consumed by a retry of an `ErrorKind::TransientError`: twice the ordinary cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;

/// Permits handed back by each call to `TokenBucket::regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;

/// Fractional tokens credited per `reward_success` call by default; `0.0` disables the reward.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
29#[derive(Clone, Debug)]
31pub struct TokenBucket {
32 semaphore: Arc<Semaphore>,
33 max_permits: usize,
34 timeout_retry_cost: u32,
35 retry_cost: u32,
36 success_reward: f32,
37 fractional_tokens: Arc<AtomicF32>,
38 refill_rate: f32,
39 time_source: SharedTimeSource,
40 creation_time: SystemTime,
41 last_refill_age_secs: Arc<AtomicU32>,
42}
43
/// An `f32` that can be shared between threads, kept as its IEEE-754 bit pattern inside an
/// [`AtomicU32`].
///
/// Every access uses `Ordering::Relaxed`. Each individual load or store is atomic, but a load
/// followed by a store is *not* an atomic read-modify-write.
struct AtomicF32 {
    storage: AtomicU32,
}

// Explicit panic-safety markers (the wrapped `AtomicU32` already satisfies both).
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

impl AtomicF32 {
    /// Creates an atomic holding `value`.
    fn new(value: f32) -> Self {
        AtomicF32 {
            storage: AtomicU32::new(f32::to_bits(value)),
        }
    }

    /// Replaces the stored value (relaxed ordering).
    fn store(&self, value: f32) {
        self.storage.store(f32::to_bits(value), Ordering::Relaxed);
    }

    /// Reads the stored value (relaxed ordering) by reinterpreting the stored bits.
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}

impl fmt::Debug for AtomicF32 {
    /// Renders as `AtomicF32 { value: <current value> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}

impl Clone for AtomicF32 {
    /// Snapshots the current bits into a fresh atomic that shares nothing with `self`.
    fn clone(&self) -> Self {
        let bits = self.storage.load(Ordering::Relaxed);
        Self {
            storage: AtomicU32::new(bits),
        }
    }
}
83
84impl Storable for TokenBucket {
85 type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89 fn default() -> Self {
90 let time_source = SharedTimeSource::default();
91 Self {
92 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
93 max_permits: DEFAULT_CAPACITY,
94 timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
95 retry_cost: DEFAULT_RETRY_COST,
96 success_reward: DEFAULT_SUCCESS_REWARD,
97 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
98 refill_rate: 0.0,
99 time_source: time_source.clone(),
100 creation_time: time_source.now(),
101 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
102 }
103 }
104}
105
106impl TokenBucket {
107 pub fn new(initial_quota: usize) -> Self {
109 let time_source = SharedTimeSource::default();
110 Self {
111 semaphore: Arc::new(Semaphore::new(initial_quota)),
112 max_permits: initial_quota,
113 time_source: time_source.clone(),
114 creation_time: time_source.now(),
115 ..Default::default()
116 }
117 }
118
119 pub fn unlimited() -> Self {
121 let time_source = SharedTimeSource::default();
122 Self {
123 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
124 max_permits: MAXIMUM_CAPACITY,
125 timeout_retry_cost: 0,
126 retry_cost: 0,
127 success_reward: 0.0,
128 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
129 refill_rate: 0.0,
130 time_source: time_source.clone(),
131 creation_time: time_source.now(),
132 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
133 }
134 }
135
136 pub fn builder() -> TokenBucketBuilder {
138 TokenBucketBuilder::default()
139 }
140
141 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
142 self.refill_tokens_based_on_time();
144 self.convert_fractional_tokens();
146
147 let retry_cost = if err == &ErrorKind::TransientError {
148 self.timeout_retry_cost
149 } else {
150 self.retry_cost
151 };
152
153 self.semaphore
154 .clone()
155 .try_acquire_many_owned(retry_cost)
156 .ok()
157 }
158
159 pub(crate) fn success_reward(&self) -> f32 {
160 self.success_reward
161 }
162
163 pub(crate) fn regenerate_a_token(&self) {
164 self.add_permits(PERMIT_REGENERATION_AMOUNT);
165 }
166
167 #[inline]
171 fn convert_fractional_tokens(&self) {
172 let mut calc_fractional_tokens = self.fractional_tokens.load();
173 if !calc_fractional_tokens.is_finite() {
175 tracing::error!(
176 "Fractional tokens corrupted to: {}, resetting to 0.0",
177 calc_fractional_tokens
178 );
179 self.fractional_tokens.store(0.0);
180 return;
181 }
182
183 let full_tokens_accumulated = calc_fractional_tokens.floor();
184 if full_tokens_accumulated >= 1.0 {
185 self.add_permits(full_tokens_accumulated as usize);
186 calc_fractional_tokens -= full_tokens_accumulated;
187 }
188 self.fractional_tokens.store(calc_fractional_tokens);
190 }
191
192 #[inline]
196 fn refill_tokens_based_on_time(&self) {
197 if self.refill_rate > 0.0 {
198 let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);
199
200 let current_time = self.time_source.now();
202 let current_age_secs = current_time
203 .duration_since(self.creation_time)
204 .unwrap_or(Duration::ZERO)
205 .as_secs() as u32;
206
207 if current_age_secs == last_refill_secs {
209 return;
210 }
211 if self
214 .last_refill_age_secs
215 .compare_exchange(
216 last_refill_secs,
217 current_age_secs,
218 Ordering::Relaxed,
219 Ordering::Relaxed,
220 )
221 .is_err()
222 {
223 return;
225 }
226
227 let current_fractional = self.fractional_tokens.load();
229 let max_fractional = self.max_permits as f32;
230
231 if current_fractional >= max_fractional {
233 return;
234 }
235
236 let elapsed_secs = current_age_secs - last_refill_secs;
237 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
238
239 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
241 self.fractional_tokens.store(new_fractional);
242 }
243 }
244
245 #[inline]
246 pub(crate) fn reward_success(&self) {
247 if self.success_reward > 0.0 {
248 let current = self.fractional_tokens.load();
249 let max_fractional = self.max_permits as f32;
250 if current >= max_fractional {
252 return;
253 }
254 let new_fractional = (current + self.success_reward).min(max_fractional);
256 self.fractional_tokens.store(new_fractional);
257 }
258 }
259
260 pub(crate) fn add_permits(&self, amount: usize) {
261 let available = self.semaphore.available_permits();
262 if available >= self.max_permits {
263 return;
264 }
265 self.semaphore
266 .add_permits(amount.min(self.max_permits - available));
267 }
268
269 pub fn is_full(&self) -> bool {
271 self.semaphore.available_permits() >= self.max_permits
272 }
273
274 pub fn is_empty(&self) -> bool {
276 self.semaphore.available_permits() == 0
277 }
278
279 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
281 pub(crate) fn available_permits(&self) -> usize {
282 self.semaphore.available_permits()
283 }
284}
285
286#[derive(Clone, Debug, Default)]
288pub struct TokenBucketBuilder {
289 capacity: Option<usize>,
290 retry_cost: Option<u32>,
291 timeout_retry_cost: Option<u32>,
292 success_reward: Option<f32>,
293 refill_rate: Option<f32>,
294 time_source: Option<SharedTimeSource>,
295}
296
297impl TokenBucketBuilder {
298 pub fn new() -> Self {
300 Self::default()
301 }
302
303 pub fn capacity(mut self, mut capacity: usize) -> Self {
305 if capacity > MAXIMUM_CAPACITY {
306 capacity = MAXIMUM_CAPACITY;
307 }
308 self.capacity = Some(capacity);
309 self
310 }
311
312 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
314 self.retry_cost = Some(retry_cost);
315 self
316 }
317
318 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
320 self.timeout_retry_cost = Some(timeout_retry_cost);
321 self
322 }
323
324 pub fn success_reward(mut self, reward: f32) -> Self {
326 self.success_reward = Some(reward);
327 self
328 }
329
330 pub fn refill_rate(mut self, rate: f32) -> Self {
335 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
336 self.refill_rate = Some(validated_rate);
337 self
338 }
339
340 pub fn time_source(
344 mut self,
345 time_source: impl aws_smithy_async::time::TimeSource + 'static,
346 ) -> Self {
347 self.time_source = Some(SharedTimeSource::new(time_source));
348 self
349 }
350
351 pub fn build(self) -> TokenBucket {
353 let time_source = self.time_source.unwrap_or_default();
354 TokenBucket {
355 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
356 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
357 retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
358 timeout_retry_cost: self
359 .timeout_retry_cost
360 .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
361 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
362 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
363 refill_rate: self.refill_rate.unwrap_or(0.0),
364 time_source: time_source.clone(),
365 creation_time: time_source.now(),
366 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
367 }
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `TimeSource::now` into scope for the `ManualTimeSource` values used below.
    use aws_smithy_async::time::TimeSource;

    /// `unlimited()` buckets have zero retry costs, so acquiring never drains the semaphore.
    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);

        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Even while holding 100 zero-permit guards, the available count never moves.
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    /// A 10-permit bucket with the default retry cost of 5 allows exactly two retries.
    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }

    /// Rewards of 0.4 need three rounds (0.4 -> 0.8 -> 1.2) before a whole permit appears.
    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // A transient-error retry costs 10 (the default timeout retry cost), draining the bucket.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success(); // accumulator: 0.4
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success(); // accumulator: 0.8
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success(); // accumulator: 1.2 -> one permit, 0.2 left
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    /// Repeated large rewards must not push the bucket past its capacity.
    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        for _ in 0..20 {
            bucket.reward_success();
        }

        // NOTE(review): `reward_success` only updates `fractional_tokens`, so this checks the
        // untouched (full) semaphore. Asserting `fractional_tokens.load() == 10.0` would
        // actually exercise the cap.
        assert!(bucket.semaphore.available_permits() == 10);
    }

    /// Table-driven check of `convert_fractional_tokens`. Each row is (stored fraction, permits
    /// expected to be added, fraction expected to remain). Non-finite inputs are reset to 0.0
    /// without adding permits.
    #[test]
    fn test_convert_fractional_tokens() {
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            // Drain the bucket so every converted permit is observable under the cap.
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    /// Every builder setter is reflected in the built bucket.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    /// Negative refill rates are clamped to 0.0; non-negative finite rates pass through.
    #[test]
    fn test_builder_refill_rate_validation() {
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    /// A custom clock drives both the creation time and time-based refill.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .time_source(manual_time.clone())
            .build();

        assert_eq!(bucket.creation_time, UNIX_EPOCH);

        // Drain every permit directly from the semaphore.
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        manual_time.advance(Duration::from_secs(5));

        // 5 s at 1 token/s -> 5 whole permits.
        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 5);
    }

    /// `AtomicF32::new` followed by `load` must round-trip the exact bit pattern.
    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = vec![
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];

        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();

            if expected.is_nan() {
                // NaN != NaN, so the bit patterns are compared instead of the values.
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
                assert_eq!(expected.to_bits(), actual.to_bits());
            } else {
                assert_eq!(expected.to_bits(), actual.to_bits());
            }
        }
    }

    /// `store` then `load` must preserve edge-case bit patterns exactly.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        let critical_bit_patterns = vec![
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // quiet NaN
            0x7FA00000u32, // NaN with the quiet bit clear (signaling NaN)
            0x00000001u32, // smallest positive subnormal
            0x007FFFFFu32, // largest subnormal
            0x00800000u32, // smallest positive normal
        ];

        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let loaded_f32 = atomic.load();
            let actual_bits = loaded_f32.to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    /// Five writers store distinct values while a reader samples. Every sample must be one of
    /// the written values or the initial 0.0, i.e. there are no torn reads.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                let value = atomic_reader.load();
                readings.push(value);
            }
            readings
        });

        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        for &reading in &readings {
            assert!(test_values.contains(&reading) || reading == 0.0);

            // NOTE(review): implied by the assertion above; `|| reading == 0.0` is redundant
            // because 0.0 is finite.
            assert!(
                reading.is_finite() || reading == 0.0,
                "Corrupted reading detected"
            );
        }
    }

    /// Ten threads, released together by a barrier, each repeatedly store and reload their own
    /// index. Any loaded value must be one of the indices 0..=9.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();

        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                for _ in 0..10000 {
                    let value = i as f32;
                    atomic_clone.store(value);
                    let loaded = atomic_clone.load();
                    // Another thread may store in between, so only membership is checked.
                    assert!(loaded >= 0.0 && loaded <= 9.0);
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    /// Mirrors the bucket's reward/convert cycle on a bare `AtomicF32`. It accumulates 5 x 0.3,
    /// splits off the whole tokens, and checks the remainder.
    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        assert_eq!(full_tokens, expected_total.floor());
        assert!(remaining >= 0.0 && remaining < 1.0);
        // NOTE(review): exact float comparison of a repeated sum against a single product.
        // It holds for these inputs, but is fragile if they change.
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    /// Cloning snapshots the value; later stores to the original must not affect the clone.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        assert_eq!(original.load(), cloned.load());

        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    /// Success rewards and time-based refill feed the same accumulator:
    /// 2 x 0.5 (rewards) + 2 s x 1.0/s (refill) = 3 whole permits with nothing left over.
    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        bucket.reward_success();
        bucket.reward_success();

        time_source.advance(Duration::from_secs(2));

        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    /// Table of (refill_rate, elapsed_secs, expected whole permits, expected leftover fraction).
    /// It includes very small rates, where only the fraction changes for a long time.
    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        let test_cases = [
            (10.0, 2, 20, 0.0),
            (0.001, 1100, 1, 0.1),
            (0.0001, 11000, 1, 0.1),
            (0.001, 1200, 1, 0.2),
            (0.0001, 10000, 1, 0.0),
            (0.001, 500, 0, 0.5),
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let bucket = TokenBucket {
                refill_rate,
                time_source: time_source.clone().into(),
                creation_time: time_source.now(),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time();
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    /// Both reward paths cap the accumulator at `max_permits`, and conversion never lifts the
    /// semaphore above `max_permits`.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        // 50 x 2.0 would be 100; the accumulator is capped at max_permits.
        for _ in 0..50 {
            bucket.reward_success();
        }

        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        time_source.advance(Duration::from_secs(100));

        // Already at the cap, so 100 s x 50/s adds nothing.
        bucket.refill_tokens_based_on_time();

        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        // 5 available + 10 converted, trimmed by add_permits to max_permits.
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 10);
    }

    /// 200 threads (two per loop iteration, released in two waves by a 100-party barrier) race
    /// to refill the same 10 s interval. The CAS on `last_refill_age_secs` must let only one of
    /// them credit it.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        time_source.advance(Duration::from_secs(10));

        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);

            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();

                bucket_clone1.refill_tokens_based_on_time();
            });

            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();

                bucket_clone2.refill_tokens_based_on_time();
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        bucket.convert_fractional_tokens();

        // 10 s x 1 token/s credited exactly once.
        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );

        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }
}