aws_smithy_runtime/client/retries/token_bucket.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_async::time::SharedTimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Number of permits a bucket holds when no capacity is configured.
const DEFAULT_CAPACITY: usize = 500;

// tokio's `Semaphore::MAX_PERMITS` is only 536,870,911 on 32-bit targets, so the cap is
// deliberately set below that value to keep behavior identical on every platform. The gap
// between this cap and tokio's limit also absorbs slight overfill, which can occur when a
// bucket is already at maximum capacity and another thread drops a permit it was holding.
/// The maximum number of permits a token bucket can have.
pub const MAXIMUM_CAPACITY: usize = 500 * 1_000_000;

/// Permits charged for retrying an error that is not an `ErrorKind::TransientError`.
const DEFAULT_RETRY_COST: u32 = 5;
/// Permits charged for retrying an `ErrorKind::TransientError`: twice the normal cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = 2 * DEFAULT_RETRY_COST;
/// Permits handed back to the bucket by one `regenerate_a_token` call.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Fractional tokens credited per successful request unless configured otherwise.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
29/// Token bucket used for standard and adaptive retry.
30#[derive(Clone, Debug)]
31pub struct TokenBucket {
32    semaphore: Arc<Semaphore>,
33    max_permits: usize,
34    timeout_retry_cost: u32,
35    retry_cost: u32,
36    success_reward: f32,
37    fractional_tokens: Arc<AtomicF32>,
38    refill_rate: f32,
39    time_source: SharedTimeSource,
40    creation_time: SystemTime,
41    last_refill_age_secs: Arc<AtomicU32>,
42}
43
/// An `f32` that can be shared between threads. It is stored as its raw IEEE-754 bit
/// pattern inside an [`AtomicU32`].
///
/// Every access uses `Ordering::Relaxed`. Each individual `load` or `store` is atomic,
/// but the type offers no read-modify-write operation, so a `load` followed by a `store`
/// is not atomic as a whole.
struct AtomicF32 {
    storage: AtomicU32,
}

// Explicit unwind-safety markers; the only field is an `AtomicU32`.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

impl AtomicF32 {
    /// Creates an atomic initialised to `value`.
    fn new(value: f32) -> Self {
        Self {
            storage: AtomicU32::new(value.to_bits()),
        }
    }

    /// Overwrites the stored value with `value` (relaxed ordering).
    fn store(&self, value: f32) {
        self.storage.store(value.to_bits(), Ordering::Relaxed)
    }

    /// Reads the stored value back as an `f32` (relaxed ordering).
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}

impl fmt::Debug for AtomicF32 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print the decoded float rather than the raw bit pattern.
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}

impl Clone for AtomicF32 {
    /// Returns an independent atomic holding the same raw bits. Later stores to either
    /// copy do not affect the other.
    fn clone(&self) -> Self {
        let bits = self.storage.load(Ordering::Relaxed);
        Self {
            storage: AtomicU32::new(bits),
        }
    }
}
83
84impl Storable for TokenBucket {
85    type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89    fn default() -> Self {
90        let time_source = SharedTimeSource::default();
91        Self {
92            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
93            max_permits: DEFAULT_CAPACITY,
94            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
95            retry_cost: DEFAULT_RETRY_COST,
96            success_reward: DEFAULT_SUCCESS_REWARD,
97            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
98            refill_rate: 0.0,
99            time_source: time_source.clone(),
100            creation_time: time_source.now(),
101            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
102        }
103    }
104}
105
106impl TokenBucket {
107    /// Creates a new `TokenBucket` with the given initial quota.
108    pub fn new(initial_quota: usize) -> Self {
109        let time_source = SharedTimeSource::default();
110        Self {
111            semaphore: Arc::new(Semaphore::new(initial_quota)),
112            max_permits: initial_quota,
113            time_source: time_source.clone(),
114            creation_time: time_source.now(),
115            ..Default::default()
116        }
117    }
118
119    /// A token bucket with unlimited capacity that allows retries at no cost.
120    pub fn unlimited() -> Self {
121        let time_source = SharedTimeSource::default();
122        Self {
123            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
124            max_permits: MAXIMUM_CAPACITY,
125            timeout_retry_cost: 0,
126            retry_cost: 0,
127            success_reward: 0.0,
128            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
129            refill_rate: 0.0,
130            time_source: time_source.clone(),
131            creation_time: time_source.now(),
132            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
133        }
134    }
135
136    /// Creates a builder for constructing a `TokenBucket`.
137    pub fn builder() -> TokenBucketBuilder {
138        TokenBucketBuilder::default()
139    }
140
141    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
142        // Add time-based tokens to fractional accumulator
143        self.refill_tokens_based_on_time();
144        // Convert accumulated fractional tokens to whole tokens
145        self.convert_fractional_tokens();
146
147        let retry_cost = if err == &ErrorKind::TransientError {
148            self.timeout_retry_cost
149        } else {
150            self.retry_cost
151        };
152
153        self.semaphore
154            .clone()
155            .try_acquire_many_owned(retry_cost)
156            .ok()
157    }
158
159    pub(crate) fn success_reward(&self) -> f32 {
160        self.success_reward
161    }
162
163    pub(crate) fn regenerate_a_token(&self) {
164        self.add_permits(PERMIT_REGENERATION_AMOUNT);
165    }
166
167    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
168    /// Stores the remaining fractional amount back.
169    /// This is shared by both time-based refill and success rewards.
170    #[inline]
171    fn convert_fractional_tokens(&self) {
172        let mut calc_fractional_tokens = self.fractional_tokens.load();
173        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
174        if !calc_fractional_tokens.is_finite() {
175            tracing::error!(
176                "Fractional tokens corrupted to: {}, resetting to 0.0",
177                calc_fractional_tokens
178            );
179            self.fractional_tokens.store(0.0);
180            return;
181        }
182
183        let full_tokens_accumulated = calc_fractional_tokens.floor();
184        if full_tokens_accumulated >= 1.0 {
185            self.add_permits(full_tokens_accumulated as usize);
186            calc_fractional_tokens -= full_tokens_accumulated;
187        }
188        // Always store the updated fractional tokens back, even if no conversion happened
189        self.fractional_tokens.store(calc_fractional_tokens);
190    }
191
192    /// Refills tokens based on elapsed time since last refill.
193    /// This method implements lazy evaluation - tokens are only calculated when accessed.
194    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
195    #[inline]
196    fn refill_tokens_based_on_time(&self) {
197        if self.refill_rate > 0.0 {
198            let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);
199
200            // Get current time from TimeSource and calculate current age
201            let current_time = self.time_source.now();
202            let current_age_secs = current_time
203                .duration_since(self.creation_time)
204                .unwrap_or(Duration::ZERO)
205                .as_secs() as u32;
206
207            // Early exit if no time elapsed - most threads take this path
208            if current_age_secs == last_refill_secs {
209                return;
210            }
211            // Try to atomically claim this time window with a single CAS
212            // If we lose, another thread is handling the refill, so we can exit
213            if self
214                .last_refill_age_secs
215                .compare_exchange(
216                    last_refill_secs,
217                    current_age_secs,
218                    Ordering::Relaxed,
219                    Ordering::Relaxed,
220                )
221                .is_err()
222            {
223                // Another thread claimed this time window, we're done
224                return;
225            }
226
227            // We won the CAS - we're responsible for adding tokens for this time window
228            let current_fractional = self.fractional_tokens.load();
229            let max_fractional = self.max_permits as f32;
230
231            // Skip token addition if already at cap
232            if current_fractional >= max_fractional {
233                return;
234            }
235
236            let elapsed_secs = current_age_secs - last_refill_secs;
237            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
238
239            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
240            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
241            self.fractional_tokens.store(new_fractional);
242        }
243    }
244
245    #[inline]
246    pub(crate) fn reward_success(&self) {
247        if self.success_reward > 0.0 {
248            let current = self.fractional_tokens.load();
249            let max_fractional = self.max_permits as f32;
250            // Early exit if already at cap - no point calculating
251            if current >= max_fractional {
252                return;
253            }
254            // Cap fractional tokens at max_permits to prevent unbounded growth
255            let new_fractional = (current + self.success_reward).min(max_fractional);
256            self.fractional_tokens.store(new_fractional);
257        }
258    }
259
260    pub(crate) fn add_permits(&self, amount: usize) {
261        let available = self.semaphore.available_permits();
262        if available >= self.max_permits {
263            return;
264        }
265        self.semaphore
266            .add_permits(amount.min(self.max_permits - available));
267    }
268
269    /// Returns true if the token bucket is full, false otherwise
270    pub fn is_full(&self) -> bool {
271        self.semaphore.available_permits() >= self.max_permits
272    }
273
274    /// Returns true if the token bucket is empty, false otherwise
275    pub fn is_empty(&self) -> bool {
276        self.semaphore.available_permits() == 0
277    }
278
279    #[allow(dead_code)] // only used in tests
280    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
281    pub(crate) fn available_permits(&self) -> usize {
282        self.semaphore.available_permits()
283    }
284
285    // Allows us to create a default client but still update the time_source
286    pub(crate) fn update_time_source(&mut self, new_time_source: SharedTimeSource) {
287        self.time_source = new_time_source;
288    }
289
290    #[allow(dead_code)]
291    #[doc(hidden)]
292    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
293    /// This method should only be used for internal testing
294    pub fn time_source(&self) -> &SharedTimeSource {
295        &self.time_source
296    }
297}
298
299/// Builder for constructing a `TokenBucket`.
300#[derive(Clone, Debug, Default)]
301pub struct TokenBucketBuilder {
302    capacity: Option<usize>,
303    retry_cost: Option<u32>,
304    timeout_retry_cost: Option<u32>,
305    success_reward: Option<f32>,
306    refill_rate: Option<f32>,
307    time_source: Option<SharedTimeSource>,
308}
309
310impl TokenBucketBuilder {
311    /// Creates a new `TokenBucketBuilder` with default values.
312    pub fn new() -> Self {
313        Self::default()
314    }
315
316    /// Sets the maximum bucket capacity for the builder.
317    pub fn capacity(mut self, mut capacity: usize) -> Self {
318        if capacity > MAXIMUM_CAPACITY {
319            capacity = MAXIMUM_CAPACITY;
320        }
321        self.capacity = Some(capacity);
322        self
323    }
324
325    /// Sets the specified retry cost for the builder.
326    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
327        self.retry_cost = Some(retry_cost);
328        self
329    }
330
331    /// Sets the specified timeout retry cost for the builder.
332    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
333        self.timeout_retry_cost = Some(timeout_retry_cost);
334        self
335    }
336
337    /// Sets the reward for any successful request for the builder.
338    pub fn success_reward(mut self, reward: f32) -> Self {
339        self.success_reward = Some(reward);
340        self
341    }
342
343    /// Sets the refill rate (tokens per second) for time-based token regeneration.
344    ///
345    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
346    /// Non-finite values (NaN, infinity) are treated as 0.0.
347    pub fn refill_rate(mut self, rate: f32) -> Self {
348        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
349        self.refill_rate = Some(validated_rate);
350        self
351    }
352
353    /// Sets the time source for the token bucket.
354    ///
355    /// If not set, defaults to `SystemTimeSource`.
356    pub fn time_source(
357        mut self,
358        time_source: impl aws_smithy_async::time::TimeSource + 'static,
359    ) -> Self {
360        self.time_source = Some(SharedTimeSource::new(time_source));
361        self
362    }
363
364    /// Builds a `TokenBucket`.
365    pub fn build(self) -> TokenBucket {
366        let time_source = self.time_source.unwrap_or_default();
367        TokenBucket {
368            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
369            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
370            retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
371            timeout_retry_cost: self
372                .timeout_retry_cost
373                .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
374            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
375            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
376            refill_rate: self.refill_rate.unwrap_or(0.0),
377            time_source: time_source.clone(),
378            creation_time: time_source.now(),
379            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
380        }
381    }
382}
383
384#[cfg(test)]
385mod tests {
386    use super::*;
387    use aws_smithy_async::time::TimeSource;
388
389    #[test]
390    fn test_unlimited_token_bucket() {
391        let bucket = TokenBucket::unlimited();
392
393        // Should always acquire permits regardless of error type
394        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
395        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());
396
397        // Should have maximum capacity
398        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
399
400        // Should have zero retry costs
401        assert_eq!(bucket.retry_cost, 0);
402        assert_eq!(bucket.timeout_retry_cost, 0);
403
404        // The loop count is arbitrary; should obtain permits without limit
405        let mut permits = Vec::new();
406        for _ in 0..100 {
407            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
408            assert!(permit.is_some());
409            permits.push(permit);
410            // Available permits should stay constant
411            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
412        }
413    }
414
415    #[test]
416    fn test_bounded_permits_exhaustion() {
417        let bucket = TokenBucket::new(10);
418        let mut permits = Vec::new();
419
420        for _ in 0..100 {
421            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
422            if let Some(p) = permit {
423                permits.push(p);
424            } else {
425                break;
426            }
427        }
428
429        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits
430
431        // Verify next acquisition fails
432        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
433    }
434
435    #[test]
436    fn test_fractional_tokens_accumulate_and_convert() {
437        let bucket = TokenBucket::builder()
438            .capacity(10)
439            .success_reward(0.4)
440            .build();
441
442        // acquire 10 tokens to bring capacity below max so we can test accumulation
443        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
444        assert_eq!(bucket.semaphore.available_permits(), 0);
445
446        // First success: 0.4 fractional tokens
447        bucket.reward_success();
448        bucket.convert_fractional_tokens();
449        assert_eq!(bucket.semaphore.available_permits(), 0);
450
451        // Second success: 0.8 fractional tokens
452        bucket.reward_success();
453        bucket.convert_fractional_tokens();
454        assert_eq!(bucket.semaphore.available_permits(), 0);
455
456        // Third success: 1.2 fractional tokens -> 1 full token added
457        bucket.reward_success();
458        bucket.convert_fractional_tokens();
459        assert_eq!(bucket.semaphore.available_permits(), 1);
460    }
461
462    #[test]
463    fn test_fractional_tokens_respect_max_capacity() {
464        let bucket = TokenBucket::builder()
465            .capacity(10)
466            .success_reward(2.0)
467            .build();
468
469        for _ in 0..20 {
470            bucket.reward_success();
471        }
472
473        assert!(bucket.semaphore.available_permits() == 10);
474    }
475
476    #[test]
477    fn test_convert_fractional_tokens() {
478        // (input, expected_permits_added, expected_remaining)
479        let test_cases = [
480            (0.7, 0, 0.7),
481            (1.0, 1, 0.0),
482            (2.3, 2, 0.3),
483            (5.8, 5, 0.8),
484            (10.0, 10, 0.0),
485            // verify that if fractional permits are corrupted, we reset to 0 gracefully
486            (f32::NAN, 0, 0.0),
487            (f32::INFINITY, 0, 0.0),
488        ];
489
490        for (input, expected_permits, expected_remaining) in test_cases {
491            let bucket = TokenBucket::builder().capacity(10).build();
492            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
493            let initial = bucket.semaphore.available_permits();
494
495            bucket.fractional_tokens.store(input);
496            bucket.convert_fractional_tokens();
497
498            assert_eq!(
499                bucket.semaphore.available_permits() - initial,
500                expected_permits
501            );
502            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
503        }
504    }
505
506    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
507    #[test]
508    fn test_builder_with_custom_values() {
509        let bucket = TokenBucket::builder()
510            .capacity(100)
511            .retry_cost(10)
512            .timeout_retry_cost(20)
513            .success_reward(0.5)
514            .refill_rate(2.5)
515            .build();
516
517        assert_eq!(bucket.max_permits, 100);
518        assert_eq!(bucket.retry_cost, 10);
519        assert_eq!(bucket.timeout_retry_cost, 20);
520        assert_eq!(bucket.success_reward, 0.5);
521        assert_eq!(bucket.refill_rate, 2.5);
522    }
523
524    #[test]
525    fn test_builder_refill_rate_validation() {
526        // Test negative values are clamped to 0.0
527        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
528        assert_eq!(bucket.refill_rate, 0.0);
529
530        // Test valid positive value
531        let bucket = TokenBucket::builder().refill_rate(1.5).build();
532        assert_eq!(bucket.refill_rate, 1.5);
533
534        // Test zero is valid
535        let bucket = TokenBucket::builder().refill_rate(0.0).build();
536        assert_eq!(bucket.refill_rate, 0.0);
537    }
538
539    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
540    #[test]
541    fn test_builder_custom_time_source() {
542        use aws_smithy_async::test_util::ManualTimeSource;
543        use std::time::UNIX_EPOCH;
544
545        // Test that TokenBucket uses provided TimeSource when specified via builder
546        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
547        let bucket = TokenBucket::builder()
548            .capacity(100)
549            .refill_rate(1.0)
550            .time_source(manual_time.clone())
551            .build();
552
553        // Verify the bucket uses the manual time source
554        assert_eq!(bucket.creation_time, UNIX_EPOCH);
555
556        // Consume all tokens to test refill from empty state
557        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
558        assert_eq!(bucket.available_permits(), 0);
559
560        // Advance time and verify tokens are added based on manual time
561        manual_time.advance(Duration::from_secs(5));
562
563        bucket.refill_tokens_based_on_time();
564        bucket.convert_fractional_tokens();
565
566        // Should have 5 tokens (5 seconds * 1 token/sec)
567        assert_eq!(bucket.available_permits(), 5);
568    }
569
570    #[test]
571    fn test_atomicf32_f32_to_bits_conversion_correctness() {
572        // This is the core functionality
573        let test_values = vec![
574            0.0,
575            -0.0,
576            1.0,
577            -1.0,
578            f32::INFINITY,
579            f32::NEG_INFINITY,
580            f32::NAN,
581            f32::MIN,
582            f32::MAX,
583            f32::MIN_POSITIVE,
584            f32::EPSILON,
585            std::f32::consts::PI,
586            std::f32::consts::E,
587            // Test values that could expose bit manipulation bugs
588            1.23456789e-38, // Very small normal number
589            1.23456789e38,  // Very large number (within f32 range)
590            1.1754944e-38,  // Near MIN_POSITIVE for f32
591        ];
592
593        for &expected in &test_values {
594            let atomic = AtomicF32::new(expected);
595            let actual = atomic.load();
596
597            // For NaN, we can't use == but must check bit patterns
598            if expected.is_nan() {
599                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
600                // Different NaN bit patterns should be preserved exactly
601                assert_eq!(expected.to_bits(), actual.to_bits());
602            } else {
603                assert_eq!(expected.to_bits(), actual.to_bits());
604            }
605        }
606    }
607
608    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
609    #[test]
610    fn test_atomicf32_store_load_preserves_exact_bits() {
611        let atomic = AtomicF32::new(0.0);
612
613        // Test that store/load cycle preserves EXACT bit patterns
614        // This would catch bugs in the to_bits/from_bits conversion
615        let critical_bit_patterns = vec![
616            0x00000000u32, // +0.0
617            0x80000000u32, // -0.0
618            0x7F800000u32, // +infinity
619            0xFF800000u32, // -infinity
620            0x7FC00000u32, // Quiet NaN
621            0x7FA00000u32, // Signaling NaN
622            0x00000001u32, // Smallest positive subnormal
623            0x007FFFFFu32, // Largest subnormal
624            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
625        ];
626
627        for &expected_bits in &critical_bit_patterns {
628            let expected_f32 = f32::from_bits(expected_bits);
629            atomic.store(expected_f32);
630            let loaded_f32 = atomic.load();
631            let actual_bits = loaded_f32.to_bits();
632
633            assert_eq!(expected_bits, actual_bits);
634        }
635    }
636
637    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
638    #[test]
639    fn test_atomicf32_concurrent_store_load_safety() {
640        use std::sync::Arc;
641        use std::thread;
642
643        let atomic = Arc::new(AtomicF32::new(0.0));
644        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
645        let mut handles = Vec::new();
646
647        // Start multiple threads that continuously write different values
648        for &value in &test_values {
649            let atomic_clone = Arc::clone(&atomic);
650            let handle = thread::spawn(move || {
651                for _ in 0..1000 {
652                    atomic_clone.store(value);
653                }
654            });
655            handles.push(handle);
656        }
657
658        // Start a reader thread that continuously reads
659        let atomic_reader = Arc::clone(&atomic);
660        let reader_handle = thread::spawn(move || {
661            let mut readings = Vec::new();
662            for _ in 0..5000 {
663                let value = atomic_reader.load();
664                readings.push(value);
665            }
666            readings
667        });
668
669        // Wait for all writers to complete
670        for handle in handles {
671            handle.join().expect("Writer thread panicked");
672        }
673
674        let readings = reader_handle.join().expect("Reader thread panicked");
675
676        // Verify that all read values are valid (one of the written values)
677        // This tests that there's no data corruption from concurrent access
678        for &reading in &readings {
679            assert!(test_values.contains(&reading) || reading == 0.0);
680
681            // More importantly, verify the reading is a valid f32
682            // (not corrupted bits that happen to parse as valid)
683            assert!(
684                reading.is_finite() || reading == 0.0,
685                "Corrupted reading detected"
686            );
687        }
688    }
689
690    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
691    #[test]
692    fn test_atomicf32_stress_concurrent_access() {
693        use std::sync::{Arc, Barrier};
694        use std::thread;
695
696        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
697        let atomic = Arc::new(AtomicF32::new(0.0));
698        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
699        let mut handles = Vec::new();
700
701        // Launch threads that all start simultaneously
702        for i in 0..10 {
703            let atomic_clone = Arc::clone(&atomic);
704            let barrier_clone = Arc::clone(&barrier);
705            let handle = thread::spawn(move || {
706                barrier_clone.wait(); // All threads start at same time
707
708                // Tight loop increases chance of race conditions
709                for _ in 0..10000 {
710                    let value = i as f32;
711                    atomic_clone.store(value);
712                    let loaded = atomic_clone.load();
713                    // Verify no corruption occurred
714                    assert!(loaded >= 0.0 && loaded <= 9.0);
715                    assert!(
716                        expected_values.contains(&loaded),
717                        "Got unexpected value: {}, expected one of {:?}",
718                        loaded,
719                        expected_values
720                    );
721                }
722            });
723            handles.push(handle);
724        }
725
726        for handle in handles {
727            handle.join().unwrap();
728        }
729    }
730
731    #[test]
732    fn test_atomicf32_integration_with_token_bucket_usage() {
733        let atomic = AtomicF32::new(0.0);
734        let success_reward = 0.3;
735        let iterations = 5;
736
737        // Accumulate fractional tokens
738        for _ in 1..=iterations {
739            let current = atomic.load();
740            atomic.store(current + success_reward);
741        }
742
743        let accumulated = atomic.load();
744        let expected_total = iterations as f32 * success_reward; // 1.5
745
746        // Test the floor() operation pattern
747        let full_tokens = accumulated.floor();
748        atomic.store(accumulated - full_tokens);
749        let remaining = atomic.load();
750
751        // These assertions should be general:
752        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
753        assert!(remaining >= 0.0 && remaining < 1.0);
754        assert_eq!(remaining, expected_total - expected_total.floor());
755    }
756
757    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
758    #[test]
759    fn test_atomicf32_clone_creates_independent_copy() {
760        let original = AtomicF32::new(123.456);
761        let cloned = original.clone();
762
763        // Verify they start with the same value
764        assert_eq!(original.load(), cloned.load());
765
766        // Verify they're independent - modifying one doesn't affect the other
767        original.store(999.0);
768        assert_eq!(
769            cloned.load(),
770            123.456,
771            "Clone should be unaffected by original changes"
772        );
773        assert_eq!(original.load(), 999.0, "Original should have new value");
774    }
775
776    #[test]
777    fn test_combined_time_and_success_rewards() {
778        use aws_smithy_async::test_util::ManualTimeSource;
779        use std::time::UNIX_EPOCH;
780
781        let time_source = ManualTimeSource::new(UNIX_EPOCH);
782        let bucket = TokenBucket {
783            refill_rate: 1.0,
784            success_reward: 0.5,
785            time_source: time_source.clone().into(),
786            creation_time: time_source.now(),
787            semaphore: Arc::new(Semaphore::new(0)),
788            max_permits: 100,
789            ..Default::default()
790        };
791
792        // Add success rewards: 2 * 0.5 = 1.0 token
793        bucket.reward_success();
794        bucket.reward_success();
795
796        // Advance time by 2 seconds
797        time_source.advance(Duration::from_secs(2));
798
799        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
800        // Total: 1.0 + 2.0 = 3.0 tokens
801        bucket.refill_tokens_based_on_time();
802        bucket.convert_fractional_tokens();
803
804        assert_eq!(bucket.available_permits(), 3);
805        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
806    }
807
808    #[test]
809    fn test_refill_rates() {
810        use aws_smithy_async::test_util::ManualTimeSource;
811        use std::time::UNIX_EPOCH;
812        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
813        let test_cases = [
814            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
815            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
816            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
817            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
818            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
819            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
820        ];
821
822        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
823            let time_source = ManualTimeSource::new(UNIX_EPOCH);
824            let bucket = TokenBucket {
825                refill_rate,
826                time_source: time_source.clone().into(),
827                creation_time: time_source.now(),
828                semaphore: Arc::new(Semaphore::new(0)),
829                max_permits: 100,
830                ..Default::default()
831            };
832
833            // Advance time by the specified duration
834            time_source.advance(Duration::from_secs(elapsed_secs));
835
836            bucket.refill_tokens_based_on_time();
837            bucket.convert_fractional_tokens();
838
839            assert_eq!(
840                bucket.available_permits(),
841                expected_permits,
842                "Rate {}: After {}s expected {} permits",
843                refill_rate,
844                elapsed_secs,
845                expected_permits
846            );
847            assert!(
848                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
849                "Rate {}: After {}s expected {} fractional, got {}",
850                refill_rate,
851                elapsed_secs,
852                expected_fractional,
853                bucket.fractional_tokens.load()
854            );
855        }
856    }
857
858    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
859    #[test]
860    fn test_rewards_capped_at_max_capacity() {
861        use aws_smithy_async::test_util::ManualTimeSource;
862        use std::time::UNIX_EPOCH;
863
864        let time_source = ManualTimeSource::new(UNIX_EPOCH);
865        let bucket = TokenBucket {
866            refill_rate: 50.0,
867            success_reward: 2.0,
868            time_source: time_source.clone().into(),
869            creation_time: time_source.now(),
870            semaphore: Arc::new(Semaphore::new(5)),
871            max_permits: 10,
872            ..Default::default()
873        };
874
875        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
876        for _ in 0..50 {
877            bucket.reward_success();
878        }
879
880        // Fractional tokens capped at 10 from success rewards
881        assert_eq!(bucket.fractional_tokens.load(), 10.0);
882
883        // Advance time by 100 seconds
884        time_source.advance(Duration::from_secs(100));
885
886        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
887        // But fractional is already at 10, so it stays at 10
888        bucket.refill_tokens_based_on_time();
889
890        // Fractional tokens should be capped at max_permits (10)
891        assert_eq!(
892            bucket.fractional_tokens.load(),
893            10.0,
894            "Fractional tokens should be capped at max_permits"
895        );
896        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
897        bucket.convert_fractional_tokens();
898        assert_eq!(bucket.available_permits(), 10);
899    }
900
901    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
902    #[test]
903    fn test_concurrent_time_based_refill_no_over_generation() {
904        use aws_smithy_async::test_util::ManualTimeSource;
905        use std::sync::{Arc, Barrier};
906        use std::thread;
907        use std::time::UNIX_EPOCH;
908
909        let time_source = ManualTimeSource::new(UNIX_EPOCH);
910        // Create bucket with 1 token/sec refill
911        let bucket = Arc::new(TokenBucket {
912            refill_rate: 1.0,
913            time_source: time_source.clone().into(),
914            creation_time: time_source.now(),
915            semaphore: Arc::new(Semaphore::new(0)),
916            max_permits: 100,
917            ..Default::default()
918        });
919
920        // Advance time by 10 seconds
921        time_source.advance(Duration::from_secs(10));
922
923        // Launch 100 threads that all try to refill simultaneously
924        let barrier = Arc::new(Barrier::new(100));
925        let mut handles = Vec::new();
926
927        for _ in 0..100 {
928            let bucket_clone1 = Arc::clone(&bucket);
929            let barrier_clone1 = Arc::clone(&barrier);
930            let bucket_clone2 = Arc::clone(&bucket);
931            let barrier_clone2 = Arc::clone(&barrier);
932
933            let handle1 = thread::spawn(move || {
934                // Wait for all threads to be ready
935                barrier_clone1.wait();
936
937                // All threads call refill at the same time
938                bucket_clone1.refill_tokens_based_on_time();
939            });
940
941            let handle2 = thread::spawn(move || {
942                // Wait for all threads to be ready
943                barrier_clone2.wait();
944
945                // All threads call refill at the same time
946                bucket_clone2.refill_tokens_based_on_time();
947            });
948            handles.push(handle1);
949            handles.push(handle2);
950        }
951
952        // Wait for all threads to complete
953        for handle in handles {
954            handle.join().unwrap();
955        }
956
957        // Convert fractional tokens to whole tokens
958        bucket.convert_fractional_tokens();
959
960        // Should have exactly 10 tokens (10 seconds * 1 token/sec)
961        // Not 1000 tokens (100 threads * 10 tokens each)
962        assert_eq!(
963            bucket.available_permits(),
964            10,
965            "Only one thread should have added tokens, not all 100"
966        );
967
968        // Fractional should be 0 after conversion
969        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
970    }
971}