aws_smithy_runtime/client/retries/
token_bucket.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_async::time::SharedTimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Default number of permits a bucket holds when no capacity is specified.
const DEFAULT_CAPACITY: usize = 500;
// On a 32 bit architecture, the value of Semaphore::MAX_PERMITS is 536,870,911.
// A lower cap is enforced so that behavior is identical across platforms. The gap also
// leaves room for slight bucket overfill when a bucket is at maximum capacity and another
// thread drops a permit it was holding.
/// The maximum number of permits a token bucket can have.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
/// Default permits charged for retrying a non-transient error.
const DEFAULT_RETRY_COST: u32 = 5;
/// Default permits charged for retrying a transient error: double the ordinary cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Permits handed back to the bucket by a single regeneration.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Default fractional-token reward per successful request (none).
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
29/// Token bucket used for standard and adaptive retry.
30#[derive(Clone, Debug)]
31pub struct TokenBucket {
32    semaphore: Arc<Semaphore>,
33    max_permits: usize,
34    timeout_retry_cost: u32,
35    retry_cost: u32,
36    success_reward: f32,
37    fractional_tokens: Arc<AtomicF32>,
38    refill_rate: f32,
39    time_source: SharedTimeSource,
40    creation_time: SystemTime,
41    last_refill_age_secs: Arc<AtomicU32>,
42}
43
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

/// An `f32` that can be read and written through a shared reference.
///
/// The float is kept as its raw bit pattern (`f32::to_bits`) inside an `AtomicU32`, so special
/// values such as `-0.0`, infinities and NaN are preserved. Every access uses
/// `Ordering::Relaxed`. Each individual `load` or `store` is atomic, but a `load` followed by a
/// `store` is not a single atomic update.
struct AtomicF32 {
    storage: AtomicU32,
}

impl AtomicF32 {
    /// Creates an atomic initialised to `value`.
    fn new(value: f32) -> Self {
        Self {
            storage: AtomicU32::new(value.to_bits()),
        }
    }

    /// Replaces the stored value with `value` (relaxed ordering).
    fn store(&self, value: f32) {
        self.storage.store(value.to_bits(), Ordering::Relaxed);
    }

    /// Reads the stored value (relaxed ordering).
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}
65
66impl fmt::Debug for AtomicF32 {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        // Use debug_struct, debug_tuple, or write! for formatting
69        f.debug_struct("AtomicF32")
70            .field("value", &self.load())
71            .finish()
72    }
73}
74
75impl Clone for AtomicF32 {
76    fn clone(&self) -> Self {
77        // Manually clone each field
78        AtomicF32 {
79            storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
80        }
81    }
82}
83
84impl Storable for TokenBucket {
85    type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89    fn default() -> Self {
90        let time_source = SharedTimeSource::default();
91        Self {
92            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
93            max_permits: DEFAULT_CAPACITY,
94            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
95            retry_cost: DEFAULT_RETRY_COST,
96            success_reward: DEFAULT_SUCCESS_REWARD,
97            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
98            refill_rate: 0.0,
99            time_source: time_source.clone(),
100            creation_time: time_source.now(),
101            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
102        }
103    }
104}
105
106impl TokenBucket {
107    /// Creates a new `TokenBucket` with the given initial quota.
108    pub fn new(initial_quota: usize) -> Self {
109        let time_source = SharedTimeSource::default();
110        Self {
111            semaphore: Arc::new(Semaphore::new(initial_quota)),
112            max_permits: initial_quota,
113            time_source: time_source.clone(),
114            creation_time: time_source.now(),
115            ..Default::default()
116        }
117    }
118
119    /// A token bucket with unlimited capacity that allows retries at no cost.
120    pub fn unlimited() -> Self {
121        let time_source = SharedTimeSource::default();
122        Self {
123            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
124            max_permits: MAXIMUM_CAPACITY,
125            timeout_retry_cost: 0,
126            retry_cost: 0,
127            success_reward: 0.0,
128            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
129            refill_rate: 0.0,
130            time_source: time_source.clone(),
131            creation_time: time_source.now(),
132            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
133        }
134    }
135
136    /// Creates a builder for constructing a `TokenBucket`.
137    pub fn builder() -> TokenBucketBuilder {
138        TokenBucketBuilder::default()
139    }
140
141    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
142        // Add time-based tokens to fractional accumulator
143        self.refill_tokens_based_on_time();
144        // Convert accumulated fractional tokens to whole tokens
145        self.convert_fractional_tokens();
146
147        let retry_cost = if err == &ErrorKind::TransientError {
148            self.timeout_retry_cost
149        } else {
150            self.retry_cost
151        };
152
153        self.semaphore
154            .clone()
155            .try_acquire_many_owned(retry_cost)
156            .ok()
157    }
158
159    pub(crate) fn success_reward(&self) -> f32 {
160        self.success_reward
161    }
162
163    pub(crate) fn regenerate_a_token(&self) {
164        self.add_permits(PERMIT_REGENERATION_AMOUNT);
165    }
166
167    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
168    /// Stores the remaining fractional amount back.
169    /// This is shared by both time-based refill and success rewards.
170    #[inline]
171    fn convert_fractional_tokens(&self) {
172        let mut calc_fractional_tokens = self.fractional_tokens.load();
173        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
174        if !calc_fractional_tokens.is_finite() {
175            tracing::error!(
176                "Fractional tokens corrupted to: {}, resetting to 0.0",
177                calc_fractional_tokens
178            );
179            self.fractional_tokens.store(0.0);
180            return;
181        }
182
183        let full_tokens_accumulated = calc_fractional_tokens.floor();
184        if full_tokens_accumulated >= 1.0 {
185            self.add_permits(full_tokens_accumulated as usize);
186            calc_fractional_tokens -= full_tokens_accumulated;
187        }
188        // Always store the updated fractional tokens back, even if no conversion happened
189        self.fractional_tokens.store(calc_fractional_tokens);
190    }
191
192    /// Refills tokens based on elapsed time since last refill.
193    /// This method implements lazy evaluation - tokens are only calculated when accessed.
194    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
195    #[inline]
196    fn refill_tokens_based_on_time(&self) {
197        if self.refill_rate > 0.0 {
198            let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);
199
200            // Get current time from TimeSource and calculate current age
201            let current_time = self.time_source.now();
202            let current_age_secs = current_time
203                .duration_since(self.creation_time)
204                .unwrap_or(Duration::ZERO)
205                .as_secs() as u32;
206
207            // Early exit if no time elapsed - most threads take this path
208            if current_age_secs == last_refill_secs {
209                return;
210            }
211            // Try to atomically claim this time window with a single CAS
212            // If we lose, another thread is handling the refill, so we can exit
213            if self
214                .last_refill_age_secs
215                .compare_exchange(
216                    last_refill_secs,
217                    current_age_secs,
218                    Ordering::Relaxed,
219                    Ordering::Relaxed,
220                )
221                .is_err()
222            {
223                // Another thread claimed this time window, we're done
224                return;
225            }
226
227            // We won the CAS - we're responsible for adding tokens for this time window
228            let current_fractional = self.fractional_tokens.load();
229            let max_fractional = self.max_permits as f32;
230
231            // Skip token addition if already at cap
232            if current_fractional >= max_fractional {
233                return;
234            }
235
236            let elapsed_secs = current_age_secs - last_refill_secs;
237            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
238
239            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
240            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
241            self.fractional_tokens.store(new_fractional);
242        }
243    }
244
245    #[inline]
246    pub(crate) fn reward_success(&self) {
247        if self.success_reward > 0.0 {
248            let current = self.fractional_tokens.load();
249            let max_fractional = self.max_permits as f32;
250            // Early exit if already at cap - no point calculating
251            if current >= max_fractional {
252                return;
253            }
254            // Cap fractional tokens at max_permits to prevent unbounded growth
255            let new_fractional = (current + self.success_reward).min(max_fractional);
256            self.fractional_tokens.store(new_fractional);
257        }
258    }
259
260    pub(crate) fn add_permits(&self, amount: usize) {
261        let available = self.semaphore.available_permits();
262        if available >= self.max_permits {
263            return;
264        }
265        self.semaphore
266            .add_permits(amount.min(self.max_permits - available));
267    }
268
269    /// Returns true if the token bucket is full, false otherwise
270    pub fn is_full(&self) -> bool {
271        self.semaphore.available_permits() >= self.max_permits
272    }
273
274    /// Returns true if the token bucket is empty, false otherwise
275    pub fn is_empty(&self) -> bool {
276        self.semaphore.available_permits() == 0
277    }
278
279    #[allow(dead_code)] // only used in tests
280    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
281    pub(crate) fn available_permits(&self) -> usize {
282        self.semaphore.available_permits()
283    }
284}
285
286/// Builder for constructing a `TokenBucket`.
287#[derive(Clone, Debug, Default)]
288pub struct TokenBucketBuilder {
289    capacity: Option<usize>,
290    retry_cost: Option<u32>,
291    timeout_retry_cost: Option<u32>,
292    success_reward: Option<f32>,
293    refill_rate: Option<f32>,
294    time_source: Option<SharedTimeSource>,
295}
296
297impl TokenBucketBuilder {
298    /// Creates a new `TokenBucketBuilder` with default values.
299    pub fn new() -> Self {
300        Self::default()
301    }
302
303    /// Sets the maximum bucket capacity for the builder.
304    pub fn capacity(mut self, mut capacity: usize) -> Self {
305        if capacity > MAXIMUM_CAPACITY {
306            capacity = MAXIMUM_CAPACITY;
307        }
308        self.capacity = Some(capacity);
309        self
310    }
311
312    /// Sets the specified retry cost for the builder.
313    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
314        self.retry_cost = Some(retry_cost);
315        self
316    }
317
318    /// Sets the specified timeout retry cost for the builder.
319    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
320        self.timeout_retry_cost = Some(timeout_retry_cost);
321        self
322    }
323
324    /// Sets the reward for any successful request for the builder.
325    pub fn success_reward(mut self, reward: f32) -> Self {
326        self.success_reward = Some(reward);
327        self
328    }
329
330    /// Sets the refill rate (tokens per second) for time-based token regeneration.
331    ///
332    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
333    /// Non-finite values (NaN, infinity) are treated as 0.0.
334    pub fn refill_rate(mut self, rate: f32) -> Self {
335        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
336        self.refill_rate = Some(validated_rate);
337        self
338    }
339
340    /// Sets the time source for the token bucket.
341    ///
342    /// If not set, defaults to `SystemTimeSource`.
343    pub fn time_source(
344        mut self,
345        time_source: impl aws_smithy_async::time::TimeSource + 'static,
346    ) -> Self {
347        self.time_source = Some(SharedTimeSource::new(time_source));
348        self
349    }
350
351    /// Builds a `TokenBucket`.
352    pub fn build(self) -> TokenBucket {
353        let time_source = self.time_source.unwrap_or_default();
354        TokenBucket {
355            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
356            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
357            retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
358            timeout_retry_cost: self
359                .timeout_retry_cost
360                .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
361            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
362            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
363            refill_rate: self.refill_rate.unwrap_or(0.0),
364            time_source: time_source.clone(),
365            creation_time: time_source.now(),
366            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
367        }
368    }
369}
370
371#[cfg(test)]
372mod tests {
373    use super::*;
374    use aws_smithy_async::time::TimeSource;
375
376    #[test]
377    fn test_unlimited_token_bucket() {
378        let bucket = TokenBucket::unlimited();
379
380        // Should always acquire permits regardless of error type
381        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
382        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());
383
384        // Should have maximum capacity
385        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
386
387        // Should have zero retry costs
388        assert_eq!(bucket.retry_cost, 0);
389        assert_eq!(bucket.timeout_retry_cost, 0);
390
391        // The loop count is arbitrary; should obtain permits without limit
392        let mut permits = Vec::new();
393        for _ in 0..100 {
394            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
395            assert!(permit.is_some());
396            permits.push(permit);
397            // Available permits should stay constant
398            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
399        }
400    }
401
402    #[test]
403    fn test_bounded_permits_exhaustion() {
404        let bucket = TokenBucket::new(10);
405        let mut permits = Vec::new();
406
407        for _ in 0..100 {
408            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
409            if let Some(p) = permit {
410                permits.push(p);
411            } else {
412                break;
413            }
414        }
415
416        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits
417
418        // Verify next acquisition fails
419        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
420    }
421
422    #[test]
423    fn test_fractional_tokens_accumulate_and_convert() {
424        let bucket = TokenBucket::builder()
425            .capacity(10)
426            .success_reward(0.4)
427            .build();
428
429        // acquire 10 tokens to bring capacity below max so we can test accumulation
430        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
431        assert_eq!(bucket.semaphore.available_permits(), 0);
432
433        // First success: 0.4 fractional tokens
434        bucket.reward_success();
435        bucket.convert_fractional_tokens();
436        assert_eq!(bucket.semaphore.available_permits(), 0);
437
438        // Second success: 0.8 fractional tokens
439        bucket.reward_success();
440        bucket.convert_fractional_tokens();
441        assert_eq!(bucket.semaphore.available_permits(), 0);
442
443        // Third success: 1.2 fractional tokens -> 1 full token added
444        bucket.reward_success();
445        bucket.convert_fractional_tokens();
446        assert_eq!(bucket.semaphore.available_permits(), 1);
447    }
448
449    #[test]
450    fn test_fractional_tokens_respect_max_capacity() {
451        let bucket = TokenBucket::builder()
452            .capacity(10)
453            .success_reward(2.0)
454            .build();
455
456        for _ in 0..20 {
457            bucket.reward_success();
458        }
459
460        assert!(bucket.semaphore.available_permits() == 10);
461    }
462
463    #[test]
464    fn test_convert_fractional_tokens() {
465        // (input, expected_permits_added, expected_remaining)
466        let test_cases = [
467            (0.7, 0, 0.7),
468            (1.0, 1, 0.0),
469            (2.3, 2, 0.3),
470            (5.8, 5, 0.8),
471            (10.0, 10, 0.0),
472            // verify that if fractional permits are corrupted, we reset to 0 gracefully
473            (f32::NAN, 0, 0.0),
474            (f32::INFINITY, 0, 0.0),
475        ];
476
477        for (input, expected_permits, expected_remaining) in test_cases {
478            let bucket = TokenBucket::builder().capacity(10).build();
479            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
480            let initial = bucket.semaphore.available_permits();
481
482            bucket.fractional_tokens.store(input);
483            bucket.convert_fractional_tokens();
484
485            assert_eq!(
486                bucket.semaphore.available_permits() - initial,
487                expected_permits
488            );
489            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
490        }
491    }
492
493    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
494    #[test]
495    fn test_builder_with_custom_values() {
496        let bucket = TokenBucket::builder()
497            .capacity(100)
498            .retry_cost(10)
499            .timeout_retry_cost(20)
500            .success_reward(0.5)
501            .refill_rate(2.5)
502            .build();
503
504        assert_eq!(bucket.max_permits, 100);
505        assert_eq!(bucket.retry_cost, 10);
506        assert_eq!(bucket.timeout_retry_cost, 20);
507        assert_eq!(bucket.success_reward, 0.5);
508        assert_eq!(bucket.refill_rate, 2.5);
509    }
510
511    #[test]
512    fn test_builder_refill_rate_validation() {
513        // Test negative values are clamped to 0.0
514        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
515        assert_eq!(bucket.refill_rate, 0.0);
516
517        // Test valid positive value
518        let bucket = TokenBucket::builder().refill_rate(1.5).build();
519        assert_eq!(bucket.refill_rate, 1.5);
520
521        // Test zero is valid
522        let bucket = TokenBucket::builder().refill_rate(0.0).build();
523        assert_eq!(bucket.refill_rate, 0.0);
524    }
525
526    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
527    #[test]
528    fn test_builder_custom_time_source() {
529        use aws_smithy_async::test_util::ManualTimeSource;
530        use std::time::UNIX_EPOCH;
531
532        // Test that TokenBucket uses provided TimeSource when specified via builder
533        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
534        let bucket = TokenBucket::builder()
535            .capacity(100)
536            .refill_rate(1.0)
537            .time_source(manual_time.clone())
538            .build();
539
540        // Verify the bucket uses the manual time source
541        assert_eq!(bucket.creation_time, UNIX_EPOCH);
542
543        // Consume all tokens to test refill from empty state
544        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
545        assert_eq!(bucket.available_permits(), 0);
546
547        // Advance time and verify tokens are added based on manual time
548        manual_time.advance(Duration::from_secs(5));
549
550        bucket.refill_tokens_based_on_time();
551        bucket.convert_fractional_tokens();
552
553        // Should have 5 tokens (5 seconds * 1 token/sec)
554        assert_eq!(bucket.available_permits(), 5);
555    }
556
557    #[test]
558    fn test_atomicf32_f32_to_bits_conversion_correctness() {
559        // This is the core functionality
560        let test_values = vec![
561            0.0,
562            -0.0,
563            1.0,
564            -1.0,
565            f32::INFINITY,
566            f32::NEG_INFINITY,
567            f32::NAN,
568            f32::MIN,
569            f32::MAX,
570            f32::MIN_POSITIVE,
571            f32::EPSILON,
572            std::f32::consts::PI,
573            std::f32::consts::E,
574            // Test values that could expose bit manipulation bugs
575            1.23456789e-38, // Very small normal number
576            1.23456789e38,  // Very large number (within f32 range)
577            1.1754944e-38,  // Near MIN_POSITIVE for f32
578        ];
579
580        for &expected in &test_values {
581            let atomic = AtomicF32::new(expected);
582            let actual = atomic.load();
583
584            // For NaN, we can't use == but must check bit patterns
585            if expected.is_nan() {
586                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
587                // Different NaN bit patterns should be preserved exactly
588                assert_eq!(expected.to_bits(), actual.to_bits());
589            } else {
590                assert_eq!(expected.to_bits(), actual.to_bits());
591            }
592        }
593    }
594
595    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
596    #[test]
597    fn test_atomicf32_store_load_preserves_exact_bits() {
598        let atomic = AtomicF32::new(0.0);
599
600        // Test that store/load cycle preserves EXACT bit patterns
601        // This would catch bugs in the to_bits/from_bits conversion
602        let critical_bit_patterns = vec![
603            0x00000000u32, // +0.0
604            0x80000000u32, // -0.0
605            0x7F800000u32, // +infinity
606            0xFF800000u32, // -infinity
607            0x7FC00000u32, // Quiet NaN
608            0x7FA00000u32, // Signaling NaN
609            0x00000001u32, // Smallest positive subnormal
610            0x007FFFFFu32, // Largest subnormal
611            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
612        ];
613
614        for &expected_bits in &critical_bit_patterns {
615            let expected_f32 = f32::from_bits(expected_bits);
616            atomic.store(expected_f32);
617            let loaded_f32 = atomic.load();
618            let actual_bits = loaded_f32.to_bits();
619
620            assert_eq!(expected_bits, actual_bits);
621        }
622    }
623
624    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
625    #[test]
626    fn test_atomicf32_concurrent_store_load_safety() {
627        use std::sync::Arc;
628        use std::thread;
629
630        let atomic = Arc::new(AtomicF32::new(0.0));
631        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
632        let mut handles = Vec::new();
633
634        // Start multiple threads that continuously write different values
635        for &value in &test_values {
636            let atomic_clone = Arc::clone(&atomic);
637            let handle = thread::spawn(move || {
638                for _ in 0..1000 {
639                    atomic_clone.store(value);
640                }
641            });
642            handles.push(handle);
643        }
644
645        // Start a reader thread that continuously reads
646        let atomic_reader = Arc::clone(&atomic);
647        let reader_handle = thread::spawn(move || {
648            let mut readings = Vec::new();
649            for _ in 0..5000 {
650                let value = atomic_reader.load();
651                readings.push(value);
652            }
653            readings
654        });
655
656        // Wait for all writers to complete
657        for handle in handles {
658            handle.join().expect("Writer thread panicked");
659        }
660
661        let readings = reader_handle.join().expect("Reader thread panicked");
662
663        // Verify that all read values are valid (one of the written values)
664        // This tests that there's no data corruption from concurrent access
665        for &reading in &readings {
666            assert!(test_values.contains(&reading) || reading == 0.0);
667
668            // More importantly, verify the reading is a valid f32
669            // (not corrupted bits that happen to parse as valid)
670            assert!(
671                reading.is_finite() || reading == 0.0,
672                "Corrupted reading detected"
673            );
674        }
675    }
676
677    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
678    #[test]
679    fn test_atomicf32_stress_concurrent_access() {
680        use std::sync::{Arc, Barrier};
681        use std::thread;
682
683        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
684        let atomic = Arc::new(AtomicF32::new(0.0));
685        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
686        let mut handles = Vec::new();
687
688        // Launch threads that all start simultaneously
689        for i in 0..10 {
690            let atomic_clone = Arc::clone(&atomic);
691            let barrier_clone = Arc::clone(&barrier);
692            let handle = thread::spawn(move || {
693                barrier_clone.wait(); // All threads start at same time
694
695                // Tight loop increases chance of race conditions
696                for _ in 0..10000 {
697                    let value = i as f32;
698                    atomic_clone.store(value);
699                    let loaded = atomic_clone.load();
700                    // Verify no corruption occurred
701                    assert!(loaded >= 0.0 && loaded <= 9.0);
702                    assert!(
703                        expected_values.contains(&loaded),
704                        "Got unexpected value: {}, expected one of {:?}",
705                        loaded,
706                        expected_values
707                    );
708                }
709            });
710            handles.push(handle);
711        }
712
713        for handle in handles {
714            handle.join().unwrap();
715        }
716    }
717
718    #[test]
719    fn test_atomicf32_integration_with_token_bucket_usage() {
720        let atomic = AtomicF32::new(0.0);
721        let success_reward = 0.3;
722        let iterations = 5;
723
724        // Accumulate fractional tokens
725        for _ in 1..=iterations {
726            let current = atomic.load();
727            atomic.store(current + success_reward);
728        }
729
730        let accumulated = atomic.load();
731        let expected_total = iterations as f32 * success_reward; // 1.5
732
733        // Test the floor() operation pattern
734        let full_tokens = accumulated.floor();
735        atomic.store(accumulated - full_tokens);
736        let remaining = atomic.load();
737
738        // These assertions should be general:
739        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
740        assert!(remaining >= 0.0 && remaining < 1.0);
741        assert_eq!(remaining, expected_total - expected_total.floor());
742    }
743
744    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
745    #[test]
746    fn test_atomicf32_clone_creates_independent_copy() {
747        let original = AtomicF32::new(123.456);
748        let cloned = original.clone();
749
750        // Verify they start with the same value
751        assert_eq!(original.load(), cloned.load());
752
753        // Verify they're independent - modifying one doesn't affect the other
754        original.store(999.0);
755        assert_eq!(
756            cloned.load(),
757            123.456,
758            "Clone should be unaffected by original changes"
759        );
760        assert_eq!(original.load(), 999.0, "Original should have new value");
761    }
762
763    #[test]
764    fn test_combined_time_and_success_rewards() {
765        use aws_smithy_async::test_util::ManualTimeSource;
766        use std::time::UNIX_EPOCH;
767
768        let time_source = ManualTimeSource::new(UNIX_EPOCH);
769        let bucket = TokenBucket {
770            refill_rate: 1.0,
771            success_reward: 0.5,
772            time_source: time_source.clone().into(),
773            creation_time: time_source.now(),
774            semaphore: Arc::new(Semaphore::new(0)),
775            max_permits: 100,
776            ..Default::default()
777        };
778
779        // Add success rewards: 2 * 0.5 = 1.0 token
780        bucket.reward_success();
781        bucket.reward_success();
782
783        // Advance time by 2 seconds
784        time_source.advance(Duration::from_secs(2));
785
786        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
787        // Total: 1.0 + 2.0 = 3.0 tokens
788        bucket.refill_tokens_based_on_time();
789        bucket.convert_fractional_tokens();
790
791        assert_eq!(bucket.available_permits(), 3);
792        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
793    }
794
795    #[test]
796    fn test_refill_rates() {
797        use aws_smithy_async::test_util::ManualTimeSource;
798        use std::time::UNIX_EPOCH;
799        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
800        let test_cases = [
801            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
802            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
803            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
804            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
805            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
806            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
807        ];
808
809        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
810            let time_source = ManualTimeSource::new(UNIX_EPOCH);
811            let bucket = TokenBucket {
812                refill_rate,
813                time_source: time_source.clone().into(),
814                creation_time: time_source.now(),
815                semaphore: Arc::new(Semaphore::new(0)),
816                max_permits: 100,
817                ..Default::default()
818            };
819
820            // Advance time by the specified duration
821            time_source.advance(Duration::from_secs(elapsed_secs));
822
823            bucket.refill_tokens_based_on_time();
824            bucket.convert_fractional_tokens();
825
826            assert_eq!(
827                bucket.available_permits(),
828                expected_permits,
829                "Rate {}: After {}s expected {} permits",
830                refill_rate,
831                elapsed_secs,
832                expected_permits
833            );
834            assert!(
835                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
836                "Rate {}: After {}s expected {} fractional, got {}",
837                refill_rate,
838                elapsed_secs,
839                expected_fractional,
840                bucket.fractional_tokens.load()
841            );
842        }
843    }
844
845    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
846    #[test]
847    fn test_rewards_capped_at_max_capacity() {
848        use aws_smithy_async::test_util::ManualTimeSource;
849        use std::time::UNIX_EPOCH;
850
851        let time_source = ManualTimeSource::new(UNIX_EPOCH);
852        let bucket = TokenBucket {
853            refill_rate: 50.0,
854            success_reward: 2.0,
855            time_source: time_source.clone().into(),
856            creation_time: time_source.now(),
857            semaphore: Arc::new(Semaphore::new(5)),
858            max_permits: 10,
859            ..Default::default()
860        };
861
862        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
863        for _ in 0..50 {
864            bucket.reward_success();
865        }
866
867        // Fractional tokens capped at 10 from success rewards
868        assert_eq!(bucket.fractional_tokens.load(), 10.0);
869
870        // Advance time by 100 seconds
871        time_source.advance(Duration::from_secs(100));
872
873        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
874        // But fractional is already at 10, so it stays at 10
875        bucket.refill_tokens_based_on_time();
876
877        // Fractional tokens should be capped at max_permits (10)
878        assert_eq!(
879            bucket.fractional_tokens.load(),
880            10.0,
881            "Fractional tokens should be capped at max_permits"
882        );
883        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
884        bucket.convert_fractional_tokens();
885        assert_eq!(bucket.available_permits(), 10);
886    }
887
888    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
889    #[test]
890    fn test_concurrent_time_based_refill_no_over_generation() {
891        use aws_smithy_async::test_util::ManualTimeSource;
892        use std::sync::{Arc, Barrier};
893        use std::thread;
894        use std::time::UNIX_EPOCH;
895
896        let time_source = ManualTimeSource::new(UNIX_EPOCH);
897        // Create bucket with 1 token/sec refill
898        let bucket = Arc::new(TokenBucket {
899            refill_rate: 1.0,
900            time_source: time_source.clone().into(),
901            creation_time: time_source.now(),
902            semaphore: Arc::new(Semaphore::new(0)),
903            max_permits: 100,
904            ..Default::default()
905        });
906
907        // Advance time by 10 seconds
908        time_source.advance(Duration::from_secs(10));
909
910        // Launch 100 threads that all try to refill simultaneously
911        let barrier = Arc::new(Barrier::new(100));
912        let mut handles = Vec::new();
913
914        for _ in 0..100 {
915            let bucket_clone1 = Arc::clone(&bucket);
916            let barrier_clone1 = Arc::clone(&barrier);
917            let bucket_clone2 = Arc::clone(&bucket);
918            let barrier_clone2 = Arc::clone(&barrier);
919
920            let handle1 = thread::spawn(move || {
921                // Wait for all threads to be ready
922                barrier_clone1.wait();
923
924                // All threads call refill at the same time
925                bucket_clone1.refill_tokens_based_on_time();
926            });
927
928            let handle2 = thread::spawn(move || {
929                // Wait for all threads to be ready
930                barrier_clone2.wait();
931
932                // All threads call refill at the same time
933                bucket_clone2.refill_tokens_based_on_time();
934            });
935            handles.push(handle1);
936            handles.push(handle2);
937        }
938
939        // Wait for all threads to complete
940        for handle in handles {
941            handle.join().unwrap();
942        }
943
944        // Convert fractional tokens to whole tokens
945        bucket.convert_fractional_tokens();
946
947        // Should have exactly 10 tokens (10 seconds * 1 token/sec)
948        // Not 1000 tokens (100 threads * 10 tokens each)
949        assert_eq!(
950            bucket.available_permits(),
951            10,
952            "Only one thread should have added tokens, not all 100"
953        );
954
955        // Fractional should be 0 after conversion
956        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
957    }
958}