1use aws_smithy_async::time::SharedTimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Permits a bucket starts with (and is capped at) when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;
/// Largest capacity a bucket may have; `TokenBucketBuilder::capacity` clamps to this value and
/// `TokenBucket::unlimited` uses it.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Default permits charged for retrying any error other than a transient error.
const DEFAULT_RETRY_COST: u32 = 5;
/// Default permits charged for retrying a transient error: twice the regular retry cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Permits handed back to the bucket by `TokenBucket::regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Default fractional tokens credited per successful request (none).
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
29#[derive(Clone, Debug)]
31pub struct TokenBucket {
32 semaphore: Arc<Semaphore>,
33 max_permits: usize,
34 timeout_retry_cost: u32,
35 retry_cost: u32,
36 success_reward: f32,
37 fractional_tokens: Arc<AtomicF32>,
38 refill_rate: f32,
39 time_source: SharedTimeSource,
40 creation_time: SystemTime,
41 last_refill_age_secs: Arc<AtomicU32>,
42}
43
// Explicitly mark `AtomicF32` as unwind-safe.
// NOTE(review): std already implements `RefUnwindSafe` for `AtomicU32`, so these impls appear to
// restate what would be inferred automatically — kept to make the guarantee explicit.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

/// An `f32` that can be shared between threads.
///
/// The value is kept as its raw bit pattern inside an [`AtomicU32`]. Loads and stores are
/// individually atomic with [`Ordering::Relaxed`], but a `load` followed by a `store` is not one
/// atomic step.
struct AtomicF32 {
    storage: AtomicU32,
}

impl AtomicF32 {
    /// Creates a cell holding `value`.
    fn new(value: f32) -> Self {
        let storage = AtomicU32::new(f32::to_bits(value));
        AtomicF32 { storage }
    }

    /// Overwrites the stored value with `value`.
    fn store(&self, value: f32) {
        self.storage.store(f32::to_bits(value), Ordering::Relaxed)
    }

    /// Reads the stored value.
    fn load(&self) -> f32 {
        let bits = self.storage.load(Ordering::Relaxed);
        f32::from_bits(bits)
    }
}

impl fmt::Debug for AtomicF32 {
    /// Formats as `AtomicF32 { value: <current value> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}

impl Clone for AtomicF32 {
    /// Copies the current bit pattern into a new, independent cell.
    fn clone(&self) -> Self {
        let bits = self.storage.load(Ordering::Relaxed);
        Self {
            storage: AtomicU32::new(bits),
        }
    }
}
83
84impl Storable for TokenBucket {
85 type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89 fn default() -> Self {
90 let time_source = SharedTimeSource::default();
91 Self {
92 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
93 max_permits: DEFAULT_CAPACITY,
94 timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
95 retry_cost: DEFAULT_RETRY_COST,
96 success_reward: DEFAULT_SUCCESS_REWARD,
97 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
98 refill_rate: 0.0,
99 time_source: time_source.clone(),
100 creation_time: time_source.now(),
101 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
102 }
103 }
104}
105
106impl TokenBucket {
107 pub fn new(initial_quota: usize) -> Self {
109 let time_source = SharedTimeSource::default();
110 Self {
111 semaphore: Arc::new(Semaphore::new(initial_quota)),
112 max_permits: initial_quota,
113 time_source: time_source.clone(),
114 creation_time: time_source.now(),
115 ..Default::default()
116 }
117 }
118
119 pub fn unlimited() -> Self {
121 let time_source = SharedTimeSource::default();
122 Self {
123 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
124 max_permits: MAXIMUM_CAPACITY,
125 timeout_retry_cost: 0,
126 retry_cost: 0,
127 success_reward: 0.0,
128 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
129 refill_rate: 0.0,
130 time_source: time_source.clone(),
131 creation_time: time_source.now(),
132 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
133 }
134 }
135
136 pub fn builder() -> TokenBucketBuilder {
138 TokenBucketBuilder::default()
139 }
140
141 pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
142 self.refill_tokens_based_on_time();
144 self.convert_fractional_tokens();
146
147 let retry_cost = if err == &ErrorKind::TransientError {
148 self.timeout_retry_cost
149 } else {
150 self.retry_cost
151 };
152
153 self.semaphore
154 .clone()
155 .try_acquire_many_owned(retry_cost)
156 .ok()
157 }
158
159 pub(crate) fn success_reward(&self) -> f32 {
160 self.success_reward
161 }
162
163 pub(crate) fn regenerate_a_token(&self) {
164 self.add_permits(PERMIT_REGENERATION_AMOUNT);
165 }
166
167 #[inline]
171 fn convert_fractional_tokens(&self) {
172 let mut calc_fractional_tokens = self.fractional_tokens.load();
173 if !calc_fractional_tokens.is_finite() {
175 tracing::error!(
176 "Fractional tokens corrupted to: {}, resetting to 0.0",
177 calc_fractional_tokens
178 );
179 self.fractional_tokens.store(0.0);
180 return;
181 }
182
183 let full_tokens_accumulated = calc_fractional_tokens.floor();
184 if full_tokens_accumulated >= 1.0 {
185 self.add_permits(full_tokens_accumulated as usize);
186 calc_fractional_tokens -= full_tokens_accumulated;
187 }
188 self.fractional_tokens.store(calc_fractional_tokens);
190 }
191
192 #[inline]
196 fn refill_tokens_based_on_time(&self) {
197 if self.refill_rate > 0.0 {
198 let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);
199
200 let current_time = self.time_source.now();
202 let current_age_secs = current_time
203 .duration_since(self.creation_time)
204 .unwrap_or(Duration::ZERO)
205 .as_secs() as u32;
206
207 if current_age_secs == last_refill_secs {
209 return;
210 }
211 if self
214 .last_refill_age_secs
215 .compare_exchange(
216 last_refill_secs,
217 current_age_secs,
218 Ordering::Relaxed,
219 Ordering::Relaxed,
220 )
221 .is_err()
222 {
223 return;
225 }
226
227 let current_fractional = self.fractional_tokens.load();
229 let max_fractional = self.max_permits as f32;
230
231 if current_fractional >= max_fractional {
233 return;
234 }
235
236 let elapsed_secs = current_age_secs - last_refill_secs;
237 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
238
239 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
241 self.fractional_tokens.store(new_fractional);
242 }
243 }
244
245 #[inline]
246 pub(crate) fn reward_success(&self) {
247 if self.success_reward > 0.0 {
248 let current = self.fractional_tokens.load();
249 let max_fractional = self.max_permits as f32;
250 if current >= max_fractional {
252 return;
253 }
254 let new_fractional = (current + self.success_reward).min(max_fractional);
256 self.fractional_tokens.store(new_fractional);
257 }
258 }
259
260 pub(crate) fn add_permits(&self, amount: usize) {
261 let available = self.semaphore.available_permits();
262 if available >= self.max_permits {
263 return;
264 }
265 self.semaphore
266 .add_permits(amount.min(self.max_permits - available));
267 }
268
269 pub fn is_full(&self) -> bool {
271 self.semaphore.available_permits() >= self.max_permits
272 }
273
274 pub fn is_empty(&self) -> bool {
276 self.semaphore.available_permits() == 0
277 }
278
279 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
281 pub(crate) fn available_permits(&self) -> usize {
282 self.semaphore.available_permits()
283 }
284
285 pub(crate) fn update_time_source(&mut self, new_time_source: SharedTimeSource) {
287 self.time_source = new_time_source;
288 }
289
290 #[allow(dead_code)]
291 #[doc(hidden)]
292 #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
293 pub fn time_source(&self) -> &SharedTimeSource {
295 &self.time_source
296 }
297}
298
299#[derive(Clone, Debug, Default)]
301pub struct TokenBucketBuilder {
302 capacity: Option<usize>,
303 retry_cost: Option<u32>,
304 timeout_retry_cost: Option<u32>,
305 success_reward: Option<f32>,
306 refill_rate: Option<f32>,
307 time_source: Option<SharedTimeSource>,
308}
309
310impl TokenBucketBuilder {
311 pub fn new() -> Self {
313 Self::default()
314 }
315
316 pub fn capacity(mut self, mut capacity: usize) -> Self {
318 if capacity > MAXIMUM_CAPACITY {
319 capacity = MAXIMUM_CAPACITY;
320 }
321 self.capacity = Some(capacity);
322 self
323 }
324
325 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
327 self.retry_cost = Some(retry_cost);
328 self
329 }
330
331 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
333 self.timeout_retry_cost = Some(timeout_retry_cost);
334 self
335 }
336
337 pub fn success_reward(mut self, reward: f32) -> Self {
339 self.success_reward = Some(reward);
340 self
341 }
342
343 pub fn refill_rate(mut self, rate: f32) -> Self {
348 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
349 self.refill_rate = Some(validated_rate);
350 self
351 }
352
353 pub fn time_source(
357 mut self,
358 time_source: impl aws_smithy_async::time::TimeSource + 'static,
359 ) -> Self {
360 self.time_source = Some(SharedTimeSource::new(time_source));
361 self
362 }
363
364 pub fn build(self) -> TokenBucket {
366 let time_source = self.time_source.unwrap_or_default();
367 TokenBucket {
368 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
369 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
370 retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
371 timeout_retry_cost: self
372 .timeout_retry_cost
373 .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
374 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
375 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
376 refill_rate: self.refill_rate.unwrap_or(0.0),
377 time_source: time_source.clone(),
378 creation_time: time_source.now(),
379 last_refill_age_secs: Arc::new(AtomicU32::new(0)),
380 }
381 }
382}
383
#[cfg(test)]
mod tests {
    //! Unit tests for `TokenBucket`, its builder, and the `AtomicF32` helper.

    use super::*;
    use aws_smithy_async::time::TimeSource;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Both error kinds succeed because an unlimited bucket charges nothing.
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Holding many permits at once never drains the semaphore.
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);

        // Acquire until the bucket refuses (at most 100 attempts), keeping every permit alive.
        let permits: Vec<_> = std::iter::from_fn(|| bucket.acquire(&ErrorKind::ThrottlingError))
            .take(100)
            .collect();

        // 10 permits at DEFAULT_RETRY_COST (5) per throttling retry allows exactly two retries.
        assert_eq!(permits.len(), 2);
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }

    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // A transient-error retry costs DEFAULT_RETRY_TIMEOUT_COST (10), which drains the bucket.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // Balances of 0.4 and 0.8 are not yet a whole permit; 1.2 is.
        for expected_permits in [0, 0, 1] {
            bucket.reward_success();
            bucket.convert_fractional_tokens();
            assert_eq!(bucket.semaphore.available_permits(), expected_permits);
        }
    }

    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        // 20 rewards of 2.0 would be 40 tokens, but the balance is capped at max_permits.
        for _ in 0..20 {
            bucket.reward_success();
        }

        assert_eq!(bucket.fractional_tokens.load(), 10.0);
        assert_eq!(bucket.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_convert_fractional_tokens() {
        // (fractional balance, permits expected to be added, remainder expected afterwards)
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            // Non-finite balances are treated as corruption and reset without adding permits.
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            // Drain the bucket so that converted permits are not lost to the max_permits cap.
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    #[test]
    fn test_builder_refill_rate_validation() {
        // Negative rates are clamped to zero.
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        // Non-finite rates are rejected in favour of zero.
        let bucket = TokenBucket::builder().refill_rate(f32::NAN).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(f32::INFINITY).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .time_source(manual_time.clone())
            .build();

        // Ages are measured from the configured clock's reading at build time.
        assert_eq!(bucket.creation_time, UNIX_EPOCH);

        // Drain the bucket so that refilled permits become observable.
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        manual_time.advance(Duration::from_secs(5));

        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        // 5 seconds at 1 token per second.
        assert_eq!(bucket.available_permits(), 5);
    }

    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = [
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];

        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();

            // NaN != NaN, so check NaN-ness explicitly; the bit comparison covers every value.
            if expected.is_nan() {
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
            }
            assert_eq!(expected.to_bits(), actual.to_bits());
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        let critical_bit_patterns = [
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // quiet NaN
            0x7FA00000u32, // signaling NaN
            0x00000001u32, // smallest positive subnormal
            0x007FFFFFu32, // largest subnormal
            0x00800000u32, // smallest positive normal
        ];

        for &expected_bits in &critical_bit_patterns {
            atomic.store(f32::from_bits(expected_bits));
            let actual_bits = atomic.load().to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        // One writer thread per value, each storing its value repeatedly.
        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        // A reader that samples concurrently with the writers.
        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                readings.push(atomic_reader.load());
            }
            readings
        });

        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        for &reading in &readings {
            // Every observed value must be one that was actually written, or the initial 0.0.
            assert!(test_values.contains(&reading) || reading == 0.0);
            assert!(reading.is_finite(), "Corrupted reading detected");
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        // Start all ten threads together to maximise contention.
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();

        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait();
                for _ in 0..10000 {
                    atomic_clone.store(i as f32);
                    let loaded = atomic_clone.load();
                    assert!((0.0..=9.0).contains(&loaded));
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        // Accumulate rewards the way reward_success does (load, add, store).
        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;

        // Split off whole tokens the way convert_fractional_tokens does, keeping the remainder.
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        assert_eq!(full_tokens, expected_total.floor());
        assert!((0.0..1.0).contains(&remaining));
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        assert_eq!(original.load(), cloned.load());

        // Writing to the original must not affect the clone.
        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        // Two successes at 0.5 give 1.0 token.
        bucket.reward_success();
        bucket.reward_success();

        // Two seconds at 1.0 per second give 2.0 more tokens.
        time_source.advance(Duration::from_secs(2));
        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
        let test_cases = [
            (10.0, 2, 20, 0.0),      // 20 whole tokens
            (0.001, 1100, 1, 0.1),   // 1.1 tokens
            (0.0001, 11000, 1, 0.1), // 1.1 tokens at a tenth of the rate
            (0.001, 1200, 1, 0.2),   // 1.2 tokens
            (0.0001, 10000, 1, 0.0), // exactly one token
            (0.001, 500, 0, 0.5),    // not yet a whole token
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let bucket = TokenBucket {
                refill_rate,
                time_source: time_source.clone().into(),
                creation_time: time_source.now(),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time();
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        // 50 rewards of 2.0 saturate the fractional balance at max_permits.
        for _ in 0..50 {
            bucket.reward_success();
        }
        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        // A large time-based refill must not push the balance past the cap either.
        time_source.advance(Duration::from_secs(100));
        bucket.refill_tokens_based_on_time();
        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );

        // Converting 10 tokens onto 5 available permits stops at max_permits.
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 10);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        time_source.advance(Duration::from_secs(10));

        // 200 threads race to refill the same 10-second window. The barrier is reusable, so
        // they are released in waves of 100.
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);

            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();
                bucket_clone1.refill_tokens_based_on_time();
            });

            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();
                bucket_clone2.refill_tokens_based_on_time();
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        bucket.convert_fractional_tokens();

        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }
}