aws_smithy_runtime/client/retries/token_bucket.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_types::config_bag::{Storable, StoreReplace};
use aws_smithy_types::retry::ErrorKind;
use std::fmt;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

const DEFAULT_CAPACITY: usize = 500;
/// The maximum number of permits a token bucket can have.
///
/// On a 32-bit architecture, `Semaphore::MAX_PERMITS` is 536,870,911, so we enforce a
/// lower value to ensure behavior is identical across platforms. This also leaves room
/// for slight bucket overfill in the case where a bucket is at maximum capacity and
/// another thread drops a permit it was holding.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;
const DEFAULT_RETRY_COST: u32 = 5;
const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
const PERMIT_REGENERATION_AMOUNT: usize = 1;
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;

/// Token bucket used for standard and adaptive retry.
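///
/// A minimal construction sketch (the capacity value is illustrative, not a
/// recommendation):
///
/// ```ignore
/// let bucket = TokenBucket::new(500);
/// assert!(bucket.is_full());
/// assert!(!bucket.is_empty());
/// ```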
#[derive(Clone, Debug)]
pub struct TokenBucket {
    semaphore: Arc<Semaphore>,
    max_permits: usize,
    timeout_retry_cost: u32,
    retry_cost: u32,
    success_reward: f32,
    fractional_tokens: Arc<AtomicF32>,
    refill_rate: f32,
    time_source: SharedTimeSource,
    creation_time: SystemTime,
    last_refill_age_secs: Arc<AtomicU32>,
}

/// An `f32` stored as its IEEE-754 bit pattern in an `AtomicU32`, giving lock-free
/// loads and stores that preserve the value bit-for-bit.
struct AtomicF32 {
    storage: AtomicU32,
}

impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

impl AtomicF32 {
    fn new(value: f32) -> Self {
        let as_u32 = value.to_bits();
        Self {
            storage: AtomicU32::new(as_u32),
        }
    }

    fn store(&self, value: f32) {
        let as_u32 = value.to_bits();
        self.storage.store(as_u32, Ordering::Relaxed)
    }

    fn load(&self) -> f32 {
        let as_u32 = self.storage.load(Ordering::Relaxed);
        f32::from_bits(as_u32)
    }
}
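
// A plain (non-atomic) sketch of the round trip `AtomicF32` relies on: `store` keeps the
// value's IEEE-754 bit pattern, and `load` reconstructs the same value bit-for-bit.
//
//     let bits = 0.25f32.to_bits();
//     assert_eq!(f32::from_bits(bits), 0.25);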

impl fmt::Debug for AtomicF32 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show the decoded f32 value rather than the raw bit storage.
        f.debug_struct("AtomicF32")
            .field("value", &self.load())
            .finish()
    }
}

impl Clone for AtomicF32 {
    fn clone(&self) -> Self {
        // `AtomicU32` is not `Clone`, so copy the current bit pattern into a new atomic.
        AtomicF32 {
            storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
        }
    }
}

impl Storable for TokenBucket {
    type Storer = StoreReplace<Self>;
}

impl Default for TokenBucket {
    fn default() -> Self {
        let time_source = SharedTimeSource::default();
        Self {
            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
            max_permits: DEFAULT_CAPACITY,
            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
            retry_cost: DEFAULT_RETRY_COST,
            success_reward: DEFAULT_SUCCESS_REWARD,
            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
            refill_rate: 0.0,
            time_source: time_source.clone(),
            creation_time: time_source.now(),
            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
        }
    }
}

impl TokenBucket {
    /// Creates a new `TokenBucket` with the given initial quota.
    pub fn new(initial_quota: usize) -> Self {
        let time_source = SharedTimeSource::default();
        Self {
            semaphore: Arc::new(Semaphore::new(initial_quota)),
            max_permits: initial_quota,
            time_source: time_source.clone(),
            creation_time: time_source.now(),
            ..Default::default()
        }
    }

    /// A token bucket with unlimited capacity that allows retries at no cost.
    pub fn unlimited() -> Self {
        let time_source = SharedTimeSource::default();
        Self {
            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
            max_permits: MAXIMUM_CAPACITY,
            timeout_retry_cost: 0,
            retry_cost: 0,
            success_reward: 0.0,
            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
            refill_rate: 0.0,
            time_source: time_source.clone(),
            creation_time: time_source.now(),
            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Creates a builder for constructing a `TokenBucket`.
    pub fn builder() -> TokenBucketBuilder {
        TokenBucketBuilder::default()
    }

    /// Attempts to acquire permits for a retry of the given error kind, returning
    /// `None` if the bucket does not currently hold enough tokens.
    pub(crate) fn acquire(&self, err: &ErrorKind) -> Option<OwnedSemaphorePermit> {
        // Add time-based tokens to fractional accumulator
        self.refill_tokens_based_on_time();
        // Convert accumulated fractional tokens to whole tokens
        self.convert_fractional_tokens();

        let retry_cost = if err == &ErrorKind::TransientError {
            self.timeout_retry_cost
        } else {
            self.retry_cost
        };

        self.semaphore
            .clone()
            .try_acquire_many_owned(retry_cost)
            .ok()
    }

    pub(crate) fn success_reward(&self) -> f32 {
        self.success_reward
    }

    pub(crate) fn regenerate_a_token(&self) {
        self.add_permits(PERMIT_REGENERATION_AMOUNT);
    }

    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
    /// Stores the remaining fractional amount back.
    /// This is shared by both time-based refill and success rewards.
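    ///
    /// For example, 2.3 accumulated fractional tokens become 2 whole permits, with 0.3
    /// carried forward in the accumulator.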
    #[inline]
    fn convert_fractional_tokens(&self) {
        let mut calc_fractional_tokens = self.fractional_tokens.load();
        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
        if !calc_fractional_tokens.is_finite() {
            tracing::error!(
                "Fractional tokens corrupted to: {}, resetting to 0.0",
                calc_fractional_tokens
            );
            self.fractional_tokens.store(0.0);
            return;
        }

        let full_tokens_accumulated = calc_fractional_tokens.floor();
        if full_tokens_accumulated >= 1.0 {
            self.add_permits(full_tokens_accumulated as usize);
            calc_fractional_tokens -= full_tokens_accumulated;
        }
        // Always store the updated fractional tokens back, even if no conversion happened
        self.fractional_tokens.store(calc_fractional_tokens);
    }

    /// Refills tokens based on elapsed time since last refill.
    /// This method implements lazy evaluation - tokens are only calculated when accessed.
    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
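    ///
    /// For example, with a refill rate of 0.5 tokens per second, 4 elapsed seconds add
    /// 2.0 tokens to the fractional accumulator (capped at `max_permits`).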
    #[inline]
    fn refill_tokens_based_on_time(&self) {
        if self.refill_rate > 0.0 {
            let last_refill_secs = self.last_refill_age_secs.load(Ordering::Relaxed);

            // Get current time from TimeSource and calculate current age
            let current_time = self.time_source.now();
            let current_age_secs = current_time
                .duration_since(self.creation_time)
                .unwrap_or(Duration::ZERO)
                .as_secs() as u32;

            // Early exit if no time elapsed - most threads take this path
            if current_age_secs == last_refill_secs {
                return;
            }
            // Try to atomically claim this time window with a single CAS
            // If we lose, another thread is handling the refill, so we can exit
            if self
                .last_refill_age_secs
                .compare_exchange(
                    last_refill_secs,
                    current_age_secs,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                )
                .is_err()
            {
                // Another thread claimed this time window, we're done
                return;
            }

            // We won the CAS - we're responsible for adding tokens for this time window
            let current_fractional = self.fractional_tokens.load();
            let max_fractional = self.max_permits as f32;

            // Skip token addition if already at cap
            if current_fractional >= max_fractional {
                return;
            }

            let elapsed_secs = current_age_secs - last_refill_secs;
            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;

            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
            self.fractional_tokens.store(new_fractional);
        }
    }

    #[inline]
    pub(crate) fn reward_success(&self) {
        if self.success_reward > 0.0 {
            let current = self.fractional_tokens.load();
            let max_fractional = self.max_permits as f32;
            // Early exit if already at cap - no point calculating
            if current >= max_fractional {
                return;
            }
            // Cap fractional tokens at max_permits to prevent unbounded growth
            let new_fractional = (current + self.success_reward).min(max_fractional);
            self.fractional_tokens.store(new_fractional);
        }
    }

    pub(crate) fn add_permits(&self, amount: usize) {
        let available = self.semaphore.available_permits();
        if available >= self.max_permits {
            return;
        }
        self.semaphore
            .add_permits(amount.min(self.max_permits - available));
    }

    /// Returns true if the token bucket is full, false otherwise
    pub fn is_full(&self) -> bool {
        self.semaphore.available_permits() >= self.max_permits
    }

    /// Returns true if the token bucket is empty, false otherwise
    pub fn is_empty(&self) -> bool {
        self.semaphore.available_permits() == 0
    }

    #[allow(dead_code)] // only used in tests
    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
    pub(crate) fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }

    // Allows us to create a default client but still update the time_source
    pub(crate) fn update_time_source(&mut self, new_time_source: SharedTimeSource) {
        self.time_source = new_time_source;
    }
}

/// Builder for constructing a `TokenBucket`.
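///
/// A minimal usage sketch (the specific values are illustrative assumptions, not
/// recommended settings):
///
/// ```ignore
/// let bucket = TokenBucket::builder()
///     .capacity(100)
///     .retry_cost(5)
///     .timeout_retry_cost(10)
///     .success_reward(0.5)
///     .refill_rate(1.0)
///     .build();
/// ```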
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    capacity: Option<usize>,
    retry_cost: Option<u32>,
    timeout_retry_cost: Option<u32>,
    success_reward: Option<f32>,
    refill_rate: Option<f32>,
    time_source: Option<SharedTimeSource>,
}

impl TokenBucketBuilder {
    /// Creates a new `TokenBucketBuilder` with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum bucket capacity for the builder.
    ///
    /// Values above `MAXIMUM_CAPACITY` are clamped to `MAXIMUM_CAPACITY`.
    pub fn capacity(mut self, mut capacity: usize) -> Self {
        if capacity > MAXIMUM_CAPACITY {
            capacity = MAXIMUM_CAPACITY;
        }
        self.capacity = Some(capacity);
        self
    }

    /// Sets the retry cost for the builder.
    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
        self.retry_cost = Some(retry_cost);
        self
    }

    /// Sets the timeout retry cost for the builder.
    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
        self.timeout_retry_cost = Some(timeout_retry_cost);
        self
    }

    /// Sets the reward granted for each successful request.
    pub fn success_reward(mut self, reward: f32) -> Self {
        self.success_reward = Some(reward);
        self
    }

    /// Sets the refill rate (tokens per second) for time-based token regeneration.
    ///
    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
    /// Non-finite values (NaN, infinity) are treated as 0.0.
    pub fn refill_rate(mut self, rate: f32) -> Self {
        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
        self.refill_rate = Some(validated_rate);
        self
    }

    /// Sets the time source for the token bucket.
    ///
    /// If not set, defaults to `SystemTimeSource`.
    pub fn time_source(
        mut self,
        time_source: impl aws_smithy_async::time::TimeSource + 'static,
    ) -> Self {
        self.time_source = Some(SharedTimeSource::new(time_source));
        self
    }

    /// Builds a `TokenBucket`.
    pub fn build(self) -> TokenBucket {
        let time_source = self.time_source.unwrap_or_default();
        TokenBucket {
            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
            retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
            timeout_retry_cost: self
                .timeout_retry_cost
                .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
            refill_rate: self.refill_rate.unwrap_or(0.0),
            time_source: time_source.clone(),
            creation_time: time_source.now(),
            last_refill_age_secs: Arc::new(AtomicU32::new(0)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_async::time::TimeSource;

    #[test]
    fn test_unlimited_token_bucket() {
        let bucket = TokenBucket::unlimited();

        // Should always acquire permits regardless of error type
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_some());
        assert!(bucket.acquire(&ErrorKind::TransientError).is_some());

        // Should have maximum capacity
        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);

        // Should have zero retry costs
        assert_eq!(bucket.retry_cost, 0);
        assert_eq!(bucket.timeout_retry_cost, 0);

        // The loop count is arbitrary; should obtain permits without limit
        let mut permits = Vec::new();
        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            assert!(permit.is_some());
            permits.push(permit);
            // Available permits should stay constant
            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
        }
    }

    #[test]
    fn test_bounded_permits_exhaustion() {
        let bucket = TokenBucket::new(10);
        let mut permits = Vec::new();

        for _ in 0..100 {
            let permit = bucket.acquire(&ErrorKind::ThrottlingError);
            if let Some(p) = permit {
                permits.push(p);
            } else {
                break;
            }
        }

        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits

        // Verify next acquisition fails
        assert!(bucket.acquire(&ErrorKind::ThrottlingError).is_none());
    }

    #[test]
    fn test_fractional_tokens_accumulate_and_convert() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(0.4)
            .build();

        // acquire 10 tokens to bring capacity below max so we can test accumulation
        let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // First success: 0.4 fractional tokens
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // Second success: 0.8 fractional tokens
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 0);

        // Third success: 1.2 fractional tokens -> 1 full token added
        bucket.reward_success();
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.semaphore.available_permits(), 1);
    }

    #[test]
    fn test_fractional_tokens_respect_max_capacity() {
        let bucket = TokenBucket::builder()
            .capacity(10)
            .success_reward(2.0)
            .build();

        for _ in 0..20 {
            bucket.reward_success();
        }

        assert_eq!(bucket.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_convert_fractional_tokens() {
        // (input, expected_permits_added, expected_remaining)
        let test_cases = [
            (0.7, 0, 0.7),
            (1.0, 1, 0.0),
            (2.3, 2, 0.3),
            (5.8, 5, 0.8),
            (10.0, 10, 0.0),
            // verify that if fractional permits are corrupted, we reset to 0 gracefully
            (f32::NAN, 0, 0.0),
            (f32::INFINITY, 0, 0.0),
        ];

        for (input, expected_permits, expected_remaining) in test_cases {
            let bucket = TokenBucket::builder().capacity(10).build();
            let _hold_permit = bucket.acquire(&ErrorKind::TransientError);
            let initial = bucket.semaphore.available_permits();

            bucket.fractional_tokens.store(input);
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.semaphore.available_permits() - initial,
                expected_permits
            );
            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_with_custom_values() {
        let bucket = TokenBucket::builder()
            .capacity(100)
            .retry_cost(10)
            .timeout_retry_cost(20)
            .success_reward(0.5)
            .refill_rate(2.5)
            .build();

        assert_eq!(bucket.max_permits, 100);
        assert_eq!(bucket.retry_cost, 10);
        assert_eq!(bucket.timeout_retry_cost, 20);
        assert_eq!(bucket.success_reward, 0.5);
        assert_eq!(bucket.refill_rate, 2.5);
    }

    #[test]
    fn test_builder_refill_rate_validation() {
        // Test negative values are clamped to 0.0
        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
        assert_eq!(bucket.refill_rate, 0.0);

        // Test valid positive value
        let bucket = TokenBucket::builder().refill_rate(1.5).build();
        assert_eq!(bucket.refill_rate, 1.5);

        // Test zero is valid
        let bucket = TokenBucket::builder().refill_rate(0.0).build();
        assert_eq!(bucket.refill_rate, 0.0);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_builder_custom_time_source() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        // Test that TokenBucket uses provided TimeSource when specified via builder
        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket::builder()
            .capacity(100)
            .refill_rate(1.0)
            .time_source(manual_time.clone())
            .build();

        // Verify the bucket uses the manual time source
        assert_eq!(bucket.creation_time, UNIX_EPOCH);

        // Consume all tokens to test refill from empty state
        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
        assert_eq!(bucket.available_permits(), 0);

        // Advance time and verify tokens are added based on manual time
        manual_time.advance(Duration::from_secs(5));

        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        // Should have 5 tokens (5 seconds * 1 token/sec)
        assert_eq!(bucket.available_permits(), 5);
    }

    #[test]
    fn test_atomicf32_f32_to_bits_conversion_correctness() {
        // Round-tripping through the stored bit pattern must reproduce each value exactly.
        let test_values = vec![
            0.0,
            -0.0,
            1.0,
            -1.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            f32::MIN,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::EPSILON,
            std::f32::consts::PI,
            std::f32::consts::E,
            // Test values that could expose bit manipulation bugs
            1.23456789e-38, // Very small normal number
            1.23456789e38,  // Very large number (within f32 range)
            1.1754944e-38,  // Near MIN_POSITIVE for f32
        ];

        for &expected in &test_values {
            let atomic = AtomicF32::new(expected);
            let actual = atomic.load();

            // For NaN, we can't use == but must check bit patterns
            if expected.is_nan() {
                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
                // Different NaN bit patterns should be preserved exactly
                assert_eq!(expected.to_bits(), actual.to_bits());
            } else {
                assert_eq!(expected.to_bits(), actual.to_bits());
            }
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_store_load_preserves_exact_bits() {
        let atomic = AtomicF32::new(0.0);

        // Test that store/load cycle preserves EXACT bit patterns
        // This would catch bugs in the to_bits/from_bits conversion
        let critical_bit_patterns = vec![
            0x00000000u32, // +0.0
            0x80000000u32, // -0.0
            0x7F800000u32, // +infinity
            0xFF800000u32, // -infinity
            0x7FC00000u32, // Quiet NaN
            0x7FA00000u32, // Signaling NaN
            0x00000001u32, // Smallest positive subnormal
            0x007FFFFFu32, // Largest subnormal
            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
        ];

        for &expected_bits in &critical_bit_patterns {
            let expected_f32 = f32::from_bits(expected_bits);
            atomic.store(expected_f32);
            let loaded_f32 = atomic.load();
            let actual_bits = loaded_f32.to_bits();

            assert_eq!(expected_bits, actual_bits);
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_concurrent_store_load_safety() {
        use std::sync::Arc;
        use std::thread;

        let atomic = Arc::new(AtomicF32::new(0.0));
        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut handles = Vec::new();

        // Start multiple threads that continuously write different values
        for &value in &test_values {
            let atomic_clone = Arc::clone(&atomic);
            let handle = thread::spawn(move || {
                for _ in 0..1000 {
                    atomic_clone.store(value);
                }
            });
            handles.push(handle);
        }

        // Start a reader thread that continuously reads
        let atomic_reader = Arc::clone(&atomic);
        let reader_handle = thread::spawn(move || {
            let mut readings = Vec::new();
            for _ in 0..5000 {
                let value = atomic_reader.load();
                readings.push(value);
            }
            readings
        });

        // Wait for all writers to complete
        for handle in handles {
            handle.join().expect("Writer thread panicked");
        }

        let readings = reader_handle.join().expect("Reader thread panicked");

        // Verify that all read values are valid (one of the written values)
        // This tests that there's no data corruption from concurrent access
        for &reading in &readings {
            assert!(test_values.contains(&reading) || reading == 0.0);

            // More importantly, verify the reading is a valid f32
            // (not corrupted bits that happen to parse as valid)
            assert!(
                reading.is_finite() || reading == 0.0,
                "Corrupted reading detected"
            );
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_stress_concurrent_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let atomic = Arc::new(AtomicF32::new(0.0));
        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
        let mut handles = Vec::new();

        // Launch threads that all start simultaneously
        for i in 0..10 {
            let atomic_clone = Arc::clone(&atomic);
            let barrier_clone = Arc::clone(&barrier);
            let handle = thread::spawn(move || {
                barrier_clone.wait(); // All threads start at same time

                // Tight loop increases chance of race conditions
                for _ in 0..10000 {
                    let value = i as f32;
                    atomic_clone.store(value);
                    let loaded = atomic_clone.load();
                    // Verify no corruption occurred
                    assert!(loaded >= 0.0 && loaded <= 9.0);
                    assert!(
                        expected_values.contains(&loaded),
                        "Got unexpected value: {}, expected one of {:?}",
                        loaded,
                        expected_values
                    );
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_atomicf32_integration_with_token_bucket_usage() {
        let atomic = AtomicF32::new(0.0);
        let success_reward = 0.3;
        let iterations = 5;

        // Accumulate fractional tokens
        for _ in 1..=iterations {
            let current = atomic.load();
            atomic.store(current + success_reward);
        }

        let accumulated = atomic.load();
        let expected_total = iterations as f32 * success_reward; // 1.5

        // Test the floor() operation pattern
        let full_tokens = accumulated.floor();
        atomic.store(accumulated - full_tokens);
        let remaining = atomic.load();

        // The assertions are phrased in terms of the expected total rather than hard-coded values:
        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
        assert!(remaining >= 0.0 && remaining < 1.0);
        assert_eq!(remaining, expected_total - expected_total.floor());
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_atomicf32_clone_creates_independent_copy() {
        let original = AtomicF32::new(123.456);
        let cloned = original.clone();

        // Verify they start with the same value
        assert_eq!(original.load(), cloned.load());

        // Verify they're independent - modifying one doesn't affect the other
        original.store(999.0);
        assert_eq!(
            cloned.load(),
            123.456,
            "Clone should be unaffected by original changes"
        );
        assert_eq!(original.load(), 999.0, "Original should have new value");
    }

    #[test]
    fn test_combined_time_and_success_rewards() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 1.0,
            success_reward: 0.5,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        };

        // Add success rewards: 2 * 0.5 = 1.0 token
        bucket.reward_success();
        bucket.reward_success();

        // Advance time by 2 seconds
        time_source.advance(Duration::from_secs(2));

        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
        // Total: 1.0 + 2.0 = 3.0 tokens
        bucket.refill_tokens_based_on_time();
        bucket.convert_fractional_tokens();

        assert_eq!(bucket.available_permits(), 3);
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }

    #[test]
    fn test_refill_rates() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;
        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
        let test_cases = [
            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
        ];

        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
            let time_source = ManualTimeSource::new(UNIX_EPOCH);
            let bucket = TokenBucket {
                refill_rate,
                time_source: time_source.clone().into(),
                creation_time: time_source.now(),
                semaphore: Arc::new(Semaphore::new(0)),
                max_permits: 100,
                ..Default::default()
            };

            // Advance time by the specified duration
            time_source.advance(Duration::from_secs(elapsed_secs));

            bucket.refill_tokens_based_on_time();
            bucket.convert_fractional_tokens();

            assert_eq!(
                bucket.available_permits(),
                expected_permits,
                "Rate {}: After {}s expected {} permits",
                refill_rate,
                elapsed_secs,
                expected_permits
            );
            assert!(
                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
                "Rate {}: After {}s expected {} fractional, got {}",
                refill_rate,
                elapsed_secs,
                expected_fractional,
                bucket.fractional_tokens.load()
            );
        }
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_rewards_capped_at_max_capacity() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        let bucket = TokenBucket {
            refill_rate: 50.0,
            success_reward: 2.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(5)),
            max_permits: 10,
            ..Default::default()
        };

        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
        for _ in 0..50 {
            bucket.reward_success();
        }

        // Fractional tokens capped at 10 from success rewards
        assert_eq!(bucket.fractional_tokens.load(), 10.0);

        // Advance time by 100 seconds
        time_source.advance(Duration::from_secs(100));

        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
        // But fractional is already at 10, so it stays at 10
        bucket.refill_tokens_based_on_time();

        // Fractional tokens should be capped at max_permits (10)
        assert_eq!(
            bucket.fractional_tokens.load(),
            10.0,
            "Fractional tokens should be capped at max_permits"
        );
        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
        bucket.convert_fractional_tokens();
        assert_eq!(bucket.available_permits(), 10);
    }

    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
    #[test]
    fn test_concurrent_time_based_refill_no_over_generation() {
        use aws_smithy_async::test_util::ManualTimeSource;
        use std::sync::{Arc, Barrier};
        use std::thread;
        use std::time::UNIX_EPOCH;

        let time_source = ManualTimeSource::new(UNIX_EPOCH);
        // Create bucket with 1 token/sec refill
        let bucket = Arc::new(TokenBucket {
            refill_rate: 1.0,
            time_source: time_source.clone().into(),
            creation_time: time_source.now(),
            semaphore: Arc::new(Semaphore::new(0)),
            max_permits: 100,
            ..Default::default()
        });

        // Advance time by 10 seconds
        time_source.advance(Duration::from_secs(10));

        // Launch 200 threads (two per loop iteration) that all try to refill simultaneously;
        // the barrier releases them in waves of 100.
        let barrier = Arc::new(Barrier::new(100));
        let mut handles = Vec::new();

        for _ in 0..100 {
            let bucket_clone1 = Arc::clone(&bucket);
            let barrier_clone1 = Arc::clone(&barrier);
            let bucket_clone2 = Arc::clone(&bucket);
            let barrier_clone2 = Arc::clone(&barrier);

            let handle1 = thread::spawn(move || {
                // Wait for all threads to be ready
                barrier_clone1.wait();

                // All threads call refill at the same time
                bucket_clone1.refill_tokens_based_on_time();
            });

            let handle2 = thread::spawn(move || {
                // Wait for all threads to be ready
                barrier_clone2.wait();

                // All threads call refill at the same time
                bucket_clone2.refill_tokens_based_on_time();
            });
            handles.push(handle1);
            handles.push(handle2);
        }

        // Wait for all threads to complete
        for handle in handles {
            handle.join().unwrap();
        }

        // Convert fractional tokens to whole tokens
        bucket.convert_fractional_tokens();

        // Should have exactly 10 tokens (10 seconds * 1 token/sec),
        // not 10 tokens added by every thread that raced
        assert_eq!(
            bucket.available_permits(),
            10,
            "Only one thread should have added tokens for this time window"
        );

        // Fractional should be 0 after conversion
        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
    }
}