1use aws_smithy_async::time::SharedTimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Permits a bucket starts with (and may hold at most) unless configured otherwise.
const DEFAULT_CAPACITY: usize = 500;
/// Largest capacity the builder accepts; larger requests are clamped down to this value.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Permits charged for retrying any error kind other than `TransientError`.
const DEFAULT_RETRY_COST: u32 = 5;
/// Permits charged for retrying a `TransientError`: double the ordinary retry cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Whole permits given back by a single `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Fractional tokens credited per successful request by default (none).
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
/// Retry token bucket: `acquire` charges permits for each retry according to the error kind,
/// and permits are replenished by success rewards, explicit regeneration, and (optionally) a
/// time-based refill.
///
/// The semaphore, fractional-token counter and refill marker are behind `Arc`s, so clones of a
/// bucket share the same token state.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Whole permits currently available to retries.
    semaphore: Arc<Semaphore>,
    /// Cap enforced by `add_permits` when permits are credited back.
    max_permits: usize,
    /// Permits charged to retry a `TransientError`.
    timeout_retry_cost: u32,
    /// Permits charged to retry any other error kind.
    retry_cost: u32,
    /// Fractional tokens credited by `reward_success`.
    success_reward: f32,
    /// Accumulated partial tokens; whole units are moved into `semaphore` on conversion.
    fractional_tokens: Arc<AtomicF32>,
    /// Tokens added per elapsed whole second; a non-positive rate disables time-based refill.
    refill_rate: f32,
    /// Clock consulted for time-based refill.
    time_source: SharedTimeSource,
    /// Instant from which the bucket's age (in whole seconds) is measured.
    creation_time: SystemTime,
    /// Bucket age, in whole seconds, at which the most recent time-based refill was claimed.
    last_refill_age_secs: Arc<AtomicU32>,
}
43
// std's atomics are already unwind-safe; these impls simply spell that out for the wrapper.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

/// An `f32` that can be shared and updated through `&self`.
///
/// The float is kept as its raw IEEE-754 bit pattern inside an `AtomicU32`. All accesses use
/// `Ordering::Relaxed`: each load/store is atomic on its own, but no ordering relative to other
/// memory operations is implied.
struct AtomicF32 {
    storage: AtomicU32,
}

impl AtomicF32 {
    /// Wraps `value`, preserving its exact bit pattern.
    fn new(value: f32) -> Self {
        AtomicF32 {
            storage: AtomicU32::new(f32::to_bits(value)),
        }
    }

    /// Overwrites the stored float with `value`.
    fn store(&self, value: f32) {
        let bits = f32::to_bits(value);
        self.storage.store(bits, Ordering::Relaxed);
    }

    /// Reads the stored float back from its bit pattern.
    fn load(&self) -> f32 {
        let bits = self.storage.load(Ordering::Relaxed);
        f32::from_bits(bits)
    }
}

impl fmt::Debug for AtomicF32 {
    /// Renders as `AtomicF32 { value: <current float> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current = self.load();
        let mut builder = f.debug_struct("AtomicF32");
        builder.field("value", &current);
        builder.finish()
    }
}

impl Clone for AtomicF32 {
    /// Copies the current bits into a brand-new atomic; later writes to either copy are not
    /// seen by the other.
    fn clone(&self) -> Self {
        let snapshot: u32 = self.storage.load(Ordering::Relaxed);
        AtomicF32 {
            storage: snapshot.into(),
        }
    }
}
83
84impl Storable for TokenBucket {
85 type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89 fn default() -> Self {
90 let time_source = SharedTimeSource::default();
91 Self {
92 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
93 max_permits: DEFAULT_CAPACITY,
94 timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
95 retry_cost: DEFAULT_RETRY_COST,
96 success_reward: DEFAULT_SUCCESS_REWARD,
97 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
98 refill_rate: 0.0,
99 time_source: time_source.clone(),
100 creation_time: time_source.now(),
101 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
102 }
103 }
104}
105
106impl TokenBucket {
107 pub fn new(initial_quota: usize) -> Self {
109 let time_source = SharedTimeSource::default();
110 Self {
111 semaphore: Arc::new(Semaphore::new(initial_quota)),
112 max_permits: initial_quota,
113 time_source: time_source.clone(),
114 creation_time: time_source.now(),
115 ..Default::default()
116 }
117 }
118
119 pub fn unlimited() -> Self {
121 let time_source = SharedTimeSource::default();
122 Self {
123 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
124 max_permits: MAXIMUM_CAPACITY,
125 timeout_retry_cost: 0,
126 retry_cost: 0,
127 success_reward: 0.0,
128 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
129 refill_rate: 0.0,
130 time_source: time_source.clone(),
131 creation_time: time_source.now(),
132 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
133 }
134 }
135
136 pub fn builder() -> TokenBucketBuilder {
138 TokenBucketBuilder::default()
139 }
140
141 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
142 self.refill_tokens_based_on_time();
144 self.convert_fractional_tokens();
146
147 let retry_cost = if err == &ErrorKind::TransientError {
148 self.timeout_retry_cost
149 } else {
150 self.retry_cost
151 };
152
153 self.semaphore
154 .clone()
155 .try_acquire_many_owned(retry_cost)
156 .ok()
157 }
158
159 pub(crate) fn success_reward(&self) -> f32 {
160 self.success_reward
161 }
162
163 pub(crate) fn regenerate_a_token(&self) {
164 self.add_permits(PERMIT_REGENERATION_AMOUNT);
165 }
166
167 #[inline]
171 fn convert_fractional_tokens(&self) {
172 let mut calc_fractional_tokens = self.fractional_tokens.load();
173 if !calc_fractional_tokens.is_finite() {
175 tracing::error!(
176 "Fractional tokens corrupted to: {}, resetting to 0.0",
177 calc_fractional_tokens
178 );
179 self.fractional_tokens.store(0.0);
180 return;
181 }
182
183 let full_tokens_accumulated = calc_fractional_tokens.floor();
184 if full_tokens_accumulated >= 1.0 {
185 self.add_permits(full_tokens_accumulated as usize);
186 calc_fractional_tokens -= full_tokens_accumulated;
187 }
188 self.fractional_tokens.store(calc_fractional_tokens);
190 }
191
192 #[inline]
196 fn refill_tokens_based_on_time(&self) {
197 if self.refill_rate > 0.0 {
198 let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);
199
200 let current_time = self.time_source.now();
202 let current_age_secs = current_time
203 .duration_since(self.creation_time)
204 .unwrap_or(Duration::ZERO)
205 .as_secs() as u32;
206
207 if current_age_secs == last_refill_secs {
209 return;
210 }
211 if self
214 .last_refill_age_secs
215 .compare_exchange(
216 last_refill_secs,
217 current_age_secs,
218 Ordering::Relaxed,
219 Ordering::Relaxed,
220 )
221 .is_err()
222 {
223 return;
225 }
226
227 let current_fractional = self.fractional_tokens.load();
229 let max_fractional = self.max_permits as f32;
230
231 if current_fractional >= max_fractional {
233 return;
234 }
235
236 let elapsed_secs = current_age_secs - last_refill_secs;
237 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
238
239 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
241 self.fractional_tokens.store(new_fractional);
242 }
243 }
244
245 #[inline]
246 pub(crate) fn reward_success(&self) {
247 if self.success_reward > 0.0 {
248 let current = self.fractional_tokens.load();
249 let max_fractional = self.max_permits as f32;
250 if current >= max_fractional {
252 return;
253 }
254 let new_fractional = (current + self.success_reward).min(max_fractional);
256 self.fractional_tokens.store(new_fractional);
257 }
258 }
259
260 pub(crate) fn add_permits(&self, amount: usize) {
261 let available = self.semaphore.available_permits();
262 if available >= self.max_permits {
263 return;
264 }
265 self.semaphore
266 .add_permits(amount.min(self.max_permits - available));
267 }
268
269 pub fn is_full(&self) -> bool {
271 self.semaphore.available_permits() >= self.max_permits
272 }
273
274 pub fn is_empty(&self) -> bool {
276 self.semaphore.available_permits() == 0
277 }
278
279 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
281 pub(crate) fn available_permits(&self) -> usize {
282 self.semaphore.available_permits()
283 }
284
285 pub(crate) fn update_time_source(&mut self, new_time_source: SharedTimeSource) {
287 self.time_source = new_time_source;
288 }
289}
290
/// Builder for [`TokenBucket`]; any setting left unset falls back to a default in `build`.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Initial and maximum number of permits (the `capacity` setter clamps to `MAXIMUM_CAPACITY`).
    capacity: Option<usize>,
    /// Permits charged to retry an error kind other than `TransientError`.
    retry_cost: Option<u32>,
    /// Permits charged to retry a `TransientError`.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    /// Tokens added per elapsed second (the setter coerces negative or non-finite rates to 0.0).
    refill_rate: Option<f32>,
    /// Clock used for time-based refill.
    time_source: Option<SharedTimeSource>,
}
301
302impl TokenBucketBuilder {
303 pub fn new() -> Self {
305 Self::default()
306 }
307
308 pub fn capacity(mut self, mut capacity: usize) -> Self {
310 if capacity > MAXIMUM_CAPACITY {
311 capacity = MAXIMUM_CAPACITY;
312 }
313 self.capacity = Some(capacity);
314 self
315 }
316
317 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
319 self.retry_cost = Some(retry_cost);
320 self
321 }
322
323 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
325 self.timeout_retry_cost = Some(timeout_retry_cost);
326 self
327 }
328
329 pub fn success_reward(mut self, reward: f32) -> Self {
331 self.success_reward = Some(reward);
332 self
333 }
334
335 pub fn refill_rate(mut self, rate: f32) -> Self {
340 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
341 self.refill_rate = Some(validated_rate);
342 self
343 }
344
345 pub fn time_source(
349 mut self,
350 time_source: impl aws_smithy_async::time::TimeSource + 'static,
351 ) -> Self {
352 self.time_source = Some(SharedTimeSource::new(time_source));
353 self
354 }
355
356 pub fn build(self) -> TokenBucket {
358 let time_source = self.time_source.unwrap_or_default();
359 TokenBucket {
360 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
361 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
362 retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
363 timeout_retry_cost: self
364 .timeout_retry_cost
365 .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
366 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
367 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
368 refill_rate: self.refill_rate.unwrap_or(0.0),
369 time_source: time_source.clone(),
370 creation_time: time_source.now(),
371 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
372 }
373 }
374}
375
// Unit tests for `TokenBucket`, `TokenBucketBuilder`, and the `AtomicF32` helper.
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `TimeSource::now` into scope for the `ManualTimeSource` calls below.
    use aws_smithy_async::time::TimeSource;

    /// An unlimited bucket charges zero permits, so retries always succeed and never reduce
    /// the available permit count.
    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);

        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Keep every permit alive; availability must still stay at the maximum.
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    /// With 10 permits and the default retry cost of 5, exactly two throttling retries fit.
    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }

    /// Success rewards accumulate as fractional tokens and only become a permit once a whole
    /// token has built up (0.4 + 0.4 + 0.4 = 1.2 gives one permit).
    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // A transient-error retry costs the default timeout cost (10), draining the bucket.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    /// Repeated rewards (without conversion) leave a full bucket at its capacity.
    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        for _ in 0..20 {
            bucket.reward_success();
        }

        assert!(bucket.semaphore.available_permits() == 10);
    }

    /// Table test for conversion: (stored fractional value, permits expected to be added,
    /// fractional remainder expected afterwards). Non-finite values are reset to 0.0.
    #[test]
    fn test_convert_fractional_tokens() {
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            // Drain the bucket so converted permits are not hidden by the `max_permits` cap.
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    /// Every builder setter is reflected in the built bucket.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    /// Negative refill rates are coerced to 0.0; valid rates are kept as given.
    #[test]
    fn test_builder_refill_rate_validation() {
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    /// A custom clock anchors `creation_time` and drives refill: 5s at 1 token/s yields 5 permits.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .time_source(manual_time.clone())
            .build();

        assert_eq!(bucket.creation_time, UNIX_EPOCH);

        // Take all 100 permits and hold them for the rest of the test.
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        manual_time.advance(Duration::from_secs(5));

        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 5);
    }

    /// `AtomicF32::new` followed by `load` preserves the exact bit pattern, NaN included.
    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = vec![
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];

        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();

            if expected.is_nan() {
                // NaN != NaN, so compare bit patterns instead of values.
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
                assert_eq!(expected.to_bits(), actual.to_bits());
            } else {
                assert_eq!(expected.to_bits(), actual.to_bits());
            }
        }
    }

    /// `store` followed by `load` round-trips IEEE-754 edge-case bit patterns exactly.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        let critical_bit_patterns = vec![
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // NaN (quiet bit set)
            0x7FA00000u32, // NaN (quiet bit clear)
            0x00000001u32, // smallest positive subnormal
            0x007FFFFFu32, // largest subnormal
            0x00800000u32, // smallest positive normal
        ];

        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let loaded_f32 = atomic.load();
            let actual_bits = loaded_f32.to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    /// Concurrent writers and a reader only ever observe values that were actually written
    /// (or the initial 0.0), never a torn mix of bits.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        // One writer thread per value, each storing its value repeatedly.
        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        // A reader sampling the value while the writers run.
        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                let value = atomic_reader.load();
                readings.push(value);
            }
            readings
        });

        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        for &reading in &readings {
            assert!(test_values.contains(&reading) || reading == 0.0);

            assert!(
                reading.is_finite() || reading == 0.0,
                "Corrupted reading detected"
            );
        }
    }

    /// Ten threads released together hammer store/load; every load must be one of the stored
    /// values 0.0..=9.0.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        // One barrier slot per thread so they all start at the same time.
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();

        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                for _ in 0..10000 {
                    let value = i as f32;
                    atomic_clone.store(value);
                    let loaded = atomic_clone.load();
                    assert!(loaded >= 0.0 && loaded <= 9.0);
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    /// Simulates the accumulate-then-split-off-whole-tokens pattern the bucket performs with
    /// an `AtomicF32`.
    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        assert_eq!(full_tokens, expected_total.floor());
        assert!(remaining >= 0.0 && remaining < 1.0);
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    /// A cloned `AtomicF32` is an independent snapshot that later stores do not affect.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        assert_eq!(original.load(), cloned.load());

        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    /// Success rewards and time-based refill add up: 2 x 0.5 reward + 2s x 1.0/s = 3 permits.
    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        bucket.reward_success();
        bucket.reward_success();

        time_source.advance(Duration::from_secs(2));

        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    /// Table test for refill: (rate, elapsed seconds, expected whole permits, expected
    /// fractional remainder). Tokens added = rate x elapsed seconds.
    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        let test_cases = [
            (10.0, 2, 20, 0.0),      // 10.0 x 2 = 20.0
            (0.001, 1100, 1, 0.1),   // 0.001 x 1100 = 1.1
            (0.0001, 11000, 1, 0.1), // 0.0001 x 11000 = 1.1
            (0.001, 1200, 1, 0.2),   // 0.001 x 1200 = 1.2
            (0.0001, 10000, 1, 0.0), // 0.0001 x 10000 = 1.0
            (0.001, 500, 0, 0.5),    // 0.001 x 500 = 0.5
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let bucket = TokenBucket {
                refill_rate,
                time_source: time_source.clone().into(),
                creation_time: time_source.now(),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time();
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    /// Neither rewards nor time-based refill push fractional tokens past `max_permits`, and
    /// converting them cannot raise availability above `max_permits`.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        // 50 rewards of 2.0 would be 100.0 uncapped.
        for _ in 0..50 {
            bucket.reward_success();
        }

        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        time_source.advance(Duration::from_secs(100));

        bucket.refill_tokens_based_on_time();

        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 10);
    }

    /// Many threads racing to refill the same elapsed interval must credit it only once:
    /// 10s at 1 token/s gives 10 permits, not 10 per thread.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        time_source.advance(Duration::from_secs(10));

        // 200 threads wait on a 100-slot barrier, which releases them in two waves of 100.
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);

            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();

                bucket_clone1.refill_tokens_based_on_time();
            });

            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();

                bucket_clone2.refill_tokens_based_on_time();
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        bucket.convert_fractional_tokens();

        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );

        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }
}