1use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Number of permits a bucket created via `Default` starts with (and may hold at most).
const DEFAULT_CAPACITY: usize = 500;
/// Largest capacity the builder accepts; larger requested capacities are clamped to this value.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Default number of permits consumed when retrying an error that is not a transient error.
const DEFAULT_RETRY_COST: u32 = 5;
/// Default number of permits consumed when retrying a transient error: double the regular cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Permits handed back to the bucket by a single `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Default fractional tokens credited per success; 0.0 means success rewards are disabled.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
/// Retry quota: every retry must acquire permits from this bucket, and successes and elapsed
/// time can return permits to it.
///
/// Cloning yields a handle that shares the same permit pool, fractional-token balance and
/// refill clock (all held behind `Arc`), while the cost/reward/rate settings are copied.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    /// Pool of currently available permits.
    semaphore: Arc<Semaphore>,
    /// Upper bound that `add_permits` will not fill the pool beyond.
    max_permits: usize,
    /// Permits consumed when retrying an `ErrorKind::TransientError`.
    timeout_retry_cost: u32,
    /// Permits consumed when retrying any other kind of error.
    retry_cost: u32,
    /// Fractional tokens credited per success by `reward_success` (ignored unless > 0.0).
    success_reward: f32,
    /// Accumulated partial tokens; whole units are moved into `semaphore` on `acquire`.
    fractional_tokens: Arc<AtomicF32>,
    /// Fractional tokens credited per elapsed second (time-based refill is off unless > 0.0).
    refill_rate: f32,
    /// Unix time, in whole seconds, at which time-based refill was last applied.
    last_refill_time_secs: Arc<AtomicU32>,
}
43
// `AtomicF32` wraps only an `AtomicU32`, for which std already provides `UnwindSafe` and
// `RefUnwindSafe`, so these impls restate what the auto traits would derive anyway.
// NOTE(review): presumably kept explicit so that `TokenBucket` (which holds `Arc<AtomicF32>`)
// visibly stays unwind-safe as part of its public API — confirm before removing.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}
/// An `f32` that can be shared and updated through `&self`, implemented by keeping the
/// float's raw bit pattern (`f32::to_bits`) inside an `AtomicU32`.
struct AtomicF32 {
    /// Bit pattern of the current value; decode with `f32::from_bits`.
    storage: AtomicU32,
}
49impl AtomicF32 {
50 fn new(value: f32) -> Self {
51 let as_u32 = value.to_bits();
52 Self {
53 storage: AtomicU32::new(as_u32),
54 }
55 }
56 fn store(&self, value: f32) {
57 let as_u32 = value.to_bits();
58 self.storage.store(as_u32, Ordering::Relaxed)
59 }
60 fn load(&self) -> f32 {
61 let as_u32 = self.storage.load(Ordering::Relaxed);
62 f32::from_bits(as_u32)
63 }
64}
65
66impl fmt::Debug for AtomicF32 {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.debug_struct("AtomicF32")
70 .field("value", &self.load())
71 .finish()
72 }
73}
74
75impl Clone for AtomicF32 {
76 fn clone(&self) -> Self {
77 AtomicF32 {
79 storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
80 }
81 }
82}
83
84impl Storable for TokenBucket {
85 type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89 fn default() -> Self {
90 Self {
91 semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
92 max_permits: DEFAULT_CAPACITY,
93 timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
94 retry_cost: DEFAULT_RETRY_COST,
95 success_reward: DEFAULT_SUCCESS_REWARD,
96 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
97 refill_rate: 0.0,
98 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
99 }
100 }
101}
102
103impl TokenBucket {
104 pub fn new(initial_quota: usize) -> Self {
106 Self {
107 semaphore: Arc::new(Semaphore::new(initial_quota)),
108 max_permits: initial_quota,
109 ..Default::default()
110 }
111 }
112
113 pub fn unlimited() -> Self {
115 Self {
116 semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
117 max_permits: MAXIMUM_CAPACITY,
118 timeout_retry_cost: 0,
119 retry_cost: 0,
120 success_reward: 0.0,
121 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
122 refill_rate: 0.0,
123 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
124 }
125 }
126
127 pub fn builder() -> TokenBucketBuilder {
129 TokenBucketBuilder::default()
130 }
131
132 pub(crate) fn acquire(
133 &self,
134 err: &ErrorKind,
135 time_source: &impl TimeSource,
136 ) -> Option<OwnedSemaphorePermit> {
137 self.refill_tokens_based_on_time(time_source);
139 self.convert_fractional_tokens();
141
142 let retry_cost = if err == &ErrorKind::TransientError {
143 self.timeout_retry_cost
144 } else {
145 self.retry_cost
146 };
147
148 self.semaphore
149 .clone()
150 .try_acquire_many_owned(retry_cost)
151 .ok()
152 }
153
154 pub(crate) fn success_reward(&self) -> f32 {
155 self.success_reward
156 }
157
158 pub(crate) fn regenerate_a_token(&self) {
159 self.add_permits(PERMIT_REGENERATION_AMOUNT);
160 }
161
162 #[inline]
166 fn convert_fractional_tokens(&self) {
167 let mut calc_fractional_tokens = self.fractional_tokens.load();
168 if !calc_fractional_tokens.is_finite() {
170 tracing::error!(
171 "Fractional tokens corrupted to: {}, resetting to 0.0",
172 calc_fractional_tokens
173 );
174 self.fractional_tokens.store(0.0);
175 return;
176 }
177
178 let full_tokens_accumulated = calc_fractional_tokens.floor();
179 if full_tokens_accumulated >= 1.0 {
180 self.add_permits(full_tokens_accumulated as usize);
181 calc_fractional_tokens -= full_tokens_accumulated;
182 }
183 self.fractional_tokens.store(calc_fractional_tokens);
185 }
186
187 #[inline]
191 fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
192 if self.refill_rate > 0.0 {
193 let current_time_secs = time_source
195 .now()
196 .duration_since(SystemTime::UNIX_EPOCH)
197 .unwrap_or(Duration::ZERO)
198 .as_secs() as u32;
199
200 let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
201
202 if current_time_secs == last_refill_secs {
204 return;
205 }
206
207 if self
210 .last_refill_time_secs
211 .compare_exchange(
212 last_refill_secs,
213 current_time_secs,
214 Ordering::Relaxed,
215 Ordering::Relaxed,
216 )
217 .is_err()
218 {
219 return;
221 }
222
223 let current_fractional = self.fractional_tokens.load();
225 let max_fractional = self.max_permits as f32;
226
227 if current_fractional >= max_fractional {
229 return;
230 }
231
232 let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
233 let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
234
235 let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
237 self.fractional_tokens.store(new_fractional);
238 }
239 }
240
241 #[inline]
242 pub(crate) fn reward_success(&self) {
243 if self.success_reward > 0.0 {
244 let current = self.fractional_tokens.load();
245 let max_fractional = self.max_permits as f32;
246 if current >= max_fractional {
248 return;
249 }
250 let new_fractional = (current + self.success_reward).min(max_fractional);
252 self.fractional_tokens.store(new_fractional);
253 }
254 }
255
256 pub(crate) fn add_permits(&self, amount: usize) {
257 let available = self.semaphore.available_permits();
258 if available >= self.max_permits {
259 return;
260 }
261 self.semaphore
262 .add_permits(amount.min(self.max_permits - available));
263 }
264
265 pub fn is_full(&self) -> bool {
267 self.semaphore.available_permits() >= self.max_permits
268 }
269
270 pub fn is_empty(&self) -> bool {
272 self.semaphore.available_permits() == 0
273 }
274
275 #[allow(dead_code)] #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
277 pub(crate) fn available_permits(&self) -> usize {
278 self.semaphore.available_permits()
279 }
280
281 #[allow(dead_code)]
283 #[doc(hidden)]
284 #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
285 pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
286 self.last_refill_time_secs.clone()
287 }
288}
289
/// Builder for [`TokenBucket`]; `build` substitutes a default for any setting left unset.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Initial and maximum permit count; the setter clamps it to `MAXIMUM_CAPACITY`.
    capacity: Option<usize>,
    /// Permits consumed when retrying an error that is not a transient error.
    retry_cost: Option<u32>,
    /// Permits consumed when retrying an `ErrorKind::TransientError`.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per success (stored as given, not validated).
    success_reward: Option<f32>,
    /// Fractional tokens credited per elapsed second; the setter maps NaN, infinities and
    /// negative values to 0.0.
    refill_rate: Option<f32>,
}
299
300impl TokenBucketBuilder {
301 pub fn new() -> Self {
303 Self::default()
304 }
305
306 pub fn capacity(mut self, mut capacity: usize) -> Self {
308 if capacity > MAXIMUM_CAPACITY {
309 capacity = MAXIMUM_CAPACITY;
310 }
311 self.capacity = Some(capacity);
312 self
313 }
314
315 pub fn retry_cost(mut self, retry_cost: u32) -> Self {
317 self.retry_cost = Some(retry_cost);
318 self
319 }
320
321 pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
323 self.timeout_retry_cost = Some(timeout_retry_cost);
324 self
325 }
326
327 pub fn success_reward(mut self, reward: f32) -> Self {
329 self.success_reward = Some(reward);
330 self
331 }
332
333 pub fn refill_rate(mut self, rate: f32) -> Self {
338 let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
339 self.refill_rate = Some(validated_rate);
340 self
341 }
342
343 pub fn build(self) -> TokenBucket {
345 TokenBucket {
346 semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
347 max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
348 retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
349 timeout_retry_cost: self
350 .timeout_retry_cost
351 .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
352 success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
353 fractional_tokens: Arc::new(AtomicF32::new(0.0)),
354 refill_rate: self.refill_rate.unwrap_or(0.0),
355 last_refill_time_secs: Arc::new(AtomicU32::new(0)),
356 }
357 }
358}
359
/// Unit tests for `TokenBucket`, its builder, and the `AtomicF32` helper.
#[cfg(test)]
mod tests {

    use super::*;
    use aws_smithy_async::test_util::ManualTimeSource;
    use std::{sync::LazyLock, time::UNIX_EPOCH};

    // Fixed clock shared by tests that need a `TimeSource` but do not enable time-based refill.
    static TIME_SOURCE: LazyLock<ManualTimeSource> =
        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));

    // An unlimited bucket charges zero permits per retry, so acquisitions never drain it.
    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_some());
        assert!(bucket
            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
            .is_some());

        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);

        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // Hold every permit so none are returned to the pool; the count must still not move.
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
            assert!(permit.is_some());
            permits.push(permit);
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    // 10 permits at the default retry cost of 5 allow exactly two held acquisitions.
    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2);
        assert!(bucket
            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
            .is_none());
    }

    // Rewards of 0.4 only yield a permit once the balance reaches 1.2 (third success).
    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // A transient error costs 10 (the default timeout cost), draining all 10 permits.
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    // NOTE(review): the bucket starts full and `reward_success` never converts tokens, so this
    // assertion holds regardless of the fractional cap; it does not exercise the cap itself.
    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        for _ in 0..20 {
            bucket.reward_success();
        }

        assert!(bucket.semaphore.available_permits() == 10);
    }

    // Each case is (stored balance, permits expected to be added, balance expected afterwards).
    // Non-finite balances are reset to 0.0 without adding permits.
    #[test]
    fn test_convert_fractional_tokens() {
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            // Drain the pool so converted permits are not swallowed by the max_permits cap.
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    // Every builder setter is reflected in the built bucket.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    // Negative refill rates are clamped to 0.0; valid rates pass through unchanged.
    #[test]
    fn test_builder_refill_rate_validation() {
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    // The last refill time starts at 0 (the epoch), so advancing a clock that starts at the
    // epoch by 5s at 1 token/s yields 5 permits.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .build();

        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        manual_time.advance(Duration::from_secs(5));

        bucket.refill_tokens_based_on_time(&manual_time);
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 5);
    }

    // `AtomicF32::new` followed by `load` must preserve the exact bit pattern of edge values.
    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        let test_values = vec![
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            1.23456789e-38,
            1.23456789e38,
            1.1754944e-38,
        ];

        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();

            if expected.is_nan() {
                // NaN != NaN, so compare the NaN-ness and the raw bits instead.
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
                assert_eq!(expected.to_bits(), actual.to_bits());
            } else {
                assert_eq!(expected.to_bits(), actual.to_bits());
            }
        }
    }

    // `store` followed by `load` must preserve IEEE-754 special bit patterns exactly.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        let critical_bit_patterns = vec![
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // quiet NaN
            0x7FA00000u32, // signaling NaN
            0x00000001u32, // smallest positive subnormal
            0x007FFFFFu32, // largest subnormal
            0x00800000u32, // smallest positive normal (f32::MIN_POSITIVE)
        ];

        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let loaded_f32 = atomic.load();
            let actual_bits = loaded_f32.to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    // Concurrent writers and a reader: every value read must be one that was actually written
    // (or the initial 0.0), i.e. no torn reads.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        // One writer thread per value, each storing it repeatedly.
        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                let value = atomic_reader.load();
                readings.push(value);
            }
            readings
        });

        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        for &reading in &readings {
            assert!(test_values.contains(&reading) || reading == 0.0);

            assert!(
                reading.is_finite() || reading == 0.0,
                "Corrupted reading detected"
            );
        }
    }

    // Ten threads released together hammer store/load; every load must be one of the values
    // some thread stored.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        let barrier = Arc::new(Barrier::new(10));
        let mut handles = Vec::new();

        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                // Start all threads at once to maximise contention.
                barrier_clone.wait();
                for _ in 0..10000 {
                    let value = i as f32;
                    atomic_clone.store(value);
                    let loaded = atomic_clone.load();
                    assert!(loaded >= 0.0 && loaded <= 9.0);
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Mirrors the reward-then-convert pattern on a bare AtomicF32: five 0.3 rewards accumulate
    // to 1.5 in f32, giving one whole token and a 0.5 remainder.
    // NOTE(review): the exact assert_eq on `remaining` relies on f32 rounding of this specific
    // summation order.
    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward;
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        assert_eq!(full_tokens, expected_total.floor());
        assert!(remaining >= 0.0 && remaining < 1.0);
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    // Cloning snapshots the value; later stores to the original do not affect the clone.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        assert_eq!(original.load(), cloned.load());

        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    // Two 0.5 success rewards (1.0) plus 2s at 1 token/s (2.0) convert to exactly 3 permits.
    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        bucket.reward_success();
        bucket.reward_success();

        time_source.advance(Duration::from_secs(2));

        bucket.refill_tokens_based_on_time(&time_source);
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    // Time-based refill at a range of rates, including very slow ones that need long intervals.
    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
        let test_cases = [
            (10.0, 2, 20, 0.0),
            (0.001, 1100, 1, 0.1),
            (0.0001, 11000, 1, 0.1),
            (0.001, 1200, 1, 0.2),
            (0.0001, 10000, 1, 0.0),
            (0.001, 500, 0, 0.5),
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let current_time_secs = UNIX_EPOCH
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs() as u32;

            let bucket = TokenBucket {
                refill_rate,
                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time(&time_source);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    // Both success rewards and time-based refill stop at max_permits, and converting the
    // capped balance cannot push available permits past max_permits either.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        for _ in 0..50 {
            bucket.reward_success();
        }

        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        time_source.advance(Duration::from_secs(100));

        bucket.refill_tokens_based_on_time(&time_source);

        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        bucket.convert_fractional_tokens();
        // 5 existing permits + at most 5 more to reach max_permits = 10.
        assert_eq!(bucket.available_permits(), 10);
    }

    // 200 threads race to refill after the same 10s advance; only the one that wins the
    // compare-exchange on the refill timestamp may credit the 10 tokens.
    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let current_time_secs = UNIX_EPOCH
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs() as u32;

        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        time_source.advance(Duration::from_secs(10));
        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);

        // 200 threads wait on a 100-party barrier, so they are released in two groups of 100.
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let time_source_clone1 = shared_time_source.clone();
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);
            let time_source_clone2 = shared_time_source.clone();

            let handle1 = thread::spawn(move || {
                barrier_clone1.wait();

                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
            });

            let handle2 = thread::spawn(move || {
                barrier_clone2.wait();

                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        bucket.convert_fractional_tokens();

        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens, not all 100"
        );

        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }
}