// aws_smithy_runtime/client/retries/token_bucket.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Permit count (initial and maximum) of a bucket built with default settings.
const DEFAULT_CAPACITY: usize = 500;

// `Semaphore::MAX_PERMITS` is only 536,870,911 on 32-bit targets, so the cap below sits under
// that limit to keep behavior identical on every platform. The leftover headroom also absorbs
// a slight overfill when a bucket is already full and another thread drops a permit it held.
/// The maximum number of permits a token bucket can have.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;

/// Permits consumed when retrying an error that is not an `ErrorKind::TransientError`.
const DEFAULT_RETRY_COST: u32 = 5;
/// Permits consumed when retrying an `ErrorKind::TransientError`: double the ordinary cost.
const DEFAULT_RETRY_TIMEOUT_COST: u32 = DEFAULT_RETRY_COST * 2;
/// Permits handed back to the bucket each time a token is regenerated.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Fractional tokens credited per successful request by default (0.0 = no success reward).
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
28
/// Token bucket used for standard and adaptive retry.
///
/// Cloning is shallow: the semaphore, the fractional-token accumulator and the last-refill
/// timestamp live behind `Arc`s, so all clones draw from and refill the same shared bucket.
#[derive(Clone, Debug)]
pub struct TokenBucket {
    // Whole tokens. A retry takes `retry_cost` / `timeout_retry_cost` permits from here.
    semaphore: Arc<Semaphore>,
    // Capacity cap applied when permits are added back through `add_permits`.
    max_permits: usize,
    // Permits needed to retry an `ErrorKind::TransientError`.
    timeout_retry_cost: u32,
    // Permits needed to retry any other error kind.
    retry_cost: u32,
    // Fractional tokens credited per successful request; only positive values have an effect.
    success_reward: f32,
    // Accumulator for partial tokens from success rewards and time-based refill; its whole
    // part is moved into `semaphore` when fractional tokens are converted.
    fractional_tokens: Arc<AtomicF32>,
    // Tokens per second added by time-based refill; only positive values have an effect.
    refill_rate: f32,
    // Note this value is only an AtomicU32 so it works on 32bit powerpc architectures.
    // If we ever remove the need for that compatibility it should become an AtomicU64
    // Seconds since the UNIX epoch at which time-based refill was last credited.
    last_refill_time_secs: Arc<AtomicU32>,
}
43
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}

/// An `f32` that can be shared between threads, kept as its raw bit pattern in an
/// [`AtomicU32`]. Every method here uses `Ordering::Relaxed`.
struct AtomicF32 {
    storage: AtomicU32,
}

impl AtomicF32 {
    /// Wraps `value`.
    fn new(value: f32) -> Self {
        Self {
            storage: AtomicU32::new(value.to_bits()),
        }
    }

    /// Replaces the stored value with `value`.
    fn store(&self, value: f32) {
        self.storage.store(value.to_bits(), Ordering::Relaxed)
    }

    /// Returns the currently stored value.
    fn load(&self) -> f32 {
        f32::from_bits(self.storage.load(Ordering::Relaxed))
    }
}

impl fmt::Debug for AtomicF32 {
    /// Renders as `AtomicF32 { value: <current value> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value = self.load();
        f.debug_struct("AtomicF32").field("value", &value).finish()
    }
}

impl Clone for AtomicF32 {
    /// Snapshots the current bits into a new, independent atomic.
    fn clone(&self) -> Self {
        let bits = self.storage.load(Ordering::Relaxed);
        AtomicF32 {
            storage: AtomicU32::new(bits),
        }
    }
}
83
84impl Storable for TokenBucket {
85    type Storer = StoreReplace<Self>;
86}
87
88impl Default for TokenBucket {
89    fn default() -> Self {
90        Self {
91            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
92            max_permits: DEFAULT_CAPACITY,
93            timeout_retry_cost: DEFAULT_RETRY_TIMEOUT_COST,
94            retry_cost: DEFAULT_RETRY_COST,
95            success_reward: DEFAULT_SUCCESS_REWARD,
96            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
97            refill_rate: 0.0,
98            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
99        }
100    }
101}
102
103impl TokenBucket {
104    /// Creates a new `TokenBucket` with the given initial quota.
105    pub fn new(initial_quota: usize) -> Self {
106        Self {
107            semaphore: Arc::new(Semaphore::new(initial_quota)),
108            max_permits: initial_quota,
109            ..Default::default()
110        }
111    }
112
113    /// A token bucket with unlimited capacity that allows retries at no cost.
114    pub fn unlimited() -> Self {
115        Self {
116            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
117            max_permits: MAXIMUM_CAPACITY,
118            timeout_retry_cost: 0,
119            retry_cost: 0,
120            success_reward: 0.0,
121            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
122            refill_rate: 0.0,
123            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
124        }
125    }
126
127    /// Creates a builder for constructing a `TokenBucket`.
128    pub fn builder() -> TokenBucketBuilder {
129        TokenBucketBuilder::default()
130    }
131
132    pub(crate) fn acquire(
133        &self,
134        err: &ErrorKind,
135        time_source: &impl TimeSource,
136    ) -> Option<OwnedSemaphorePermit> {
137        // Add time-based tokens to fractional accumulator
138        self.refill_tokens_based_on_time(time_source);
139        // Convert accumulated fractional tokens to whole tokens
140        self.convert_fractional_tokens();
141
142        let retry_cost = if err == &ErrorKind::TransientError {
143            self.timeout_retry_cost
144        } else {
145            self.retry_cost
146        };
147
148        self.semaphore
149            .clone()
150            .try_acquire_many_owned(retry_cost)
151            .ok()
152    }
153
154    pub(crate) fn success_reward(&self) -> f32 {
155        self.success_reward
156    }
157
158    pub(crate) fn regenerate_a_token(&self) {
159        self.add_permits(PERMIT_REGENERATION_AMOUNT);
160    }
161
162    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
163    /// Stores the remaining fractional amount back.
164    /// This is shared by both time-based refill and success rewards.
165    #[inline]
166    fn convert_fractional_tokens(&self) {
167        let mut calc_fractional_tokens = self.fractional_tokens.load();
168        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
169        if !calc_fractional_tokens.is_finite() {
170            tracing::error!(
171                "Fractional tokens corrupted to: {}, resetting to 0.0",
172                calc_fractional_tokens
173            );
174            self.fractional_tokens.store(0.0);
175            return;
176        }
177
178        let full_tokens_accumulated = calc_fractional_tokens.floor();
179        if full_tokens_accumulated >= 1.0 {
180            self.add_permits(full_tokens_accumulated as usize);
181            calc_fractional_tokens -= full_tokens_accumulated;
182        }
183        // Always store the updated fractional tokens back, even if no conversion happened
184        self.fractional_tokens.store(calc_fractional_tokens);
185    }
186
187    /// Refills tokens based on elapsed time since last refill.
188    /// This method implements lazy evaluation - tokens are only calculated when accessed.
189    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
190    #[inline]
191    fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
192        if self.refill_rate > 0.0 {
193            // The cast to u32 here is safe until 2106, and I will be long dead then so ¯\_(ツ)_/¯
194            let current_time_secs = time_source
195                .now()
196                .duration_since(SystemTime::UNIX_EPOCH)
197                .unwrap_or(Duration::ZERO)
198                .as_secs() as u32;
199
200            let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
201
202            // Early exit if no time elapsed - most threads take this path
203            if current_time_secs == last_refill_secs {
204                return;
205            }
206
207            // Try to atomically claim this time window with a single CAS
208            // If we lose, another thread is handling the refill, so we can exit
209            if self
210                .last_refill_time_secs
211                .compare_exchange(
212                    last_refill_secs,
213                    current_time_secs,
214                    Ordering::Relaxed,
215                    Ordering::Relaxed,
216                )
217                .is_err()
218            {
219                // Another thread claimed this time window, we're done
220                return;
221            }
222
223            // We won the CAS - we're responsible for adding tokens for this time window
224            let current_fractional = self.fractional_tokens.load();
225            let max_fractional = self.max_permits as f32;
226
227            // Skip token addition if already at cap
228            if current_fractional >= max_fractional {
229                return;
230            }
231
232            let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
233            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
234
235            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
236            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
237            self.fractional_tokens.store(new_fractional);
238        }
239    }
240
241    #[inline]
242    pub(crate) fn reward_success(&self) {
243        if self.success_reward > 0.0 {
244            let current = self.fractional_tokens.load();
245            let max_fractional = self.max_permits as f32;
246            // Early exit if already at cap - no point calculating
247            if current >= max_fractional {
248                return;
249            }
250            // Cap fractional tokens at max_permits to prevent unbounded growth
251            let new_fractional = (current + self.success_reward).min(max_fractional);
252            self.fractional_tokens.store(new_fractional);
253        }
254    }
255
256    pub(crate) fn add_permits(&self, amount: usize) {
257        let available = self.semaphore.available_permits();
258        if available >= self.max_permits {
259            return;
260        }
261        self.semaphore
262            .add_permits(amount.min(self.max_permits - available));
263    }
264
265    /// Returns true if the token bucket is full, false otherwise
266    pub fn is_full(&self) -> bool {
267        self.semaphore.available_permits() >= self.max_permits
268    }
269
270    /// Returns true if the token bucket is empty, false otherwise
271    pub fn is_empty(&self) -> bool {
272        self.semaphore.available_permits() == 0
273    }
274
275    #[allow(dead_code)] // only used in tests
276    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
277    pub(crate) fn available_permits(&self) -> usize {
278        self.semaphore.available_permits()
279    }
280
281    /// Only used in tests
282    #[allow(dead_code)]
283    #[doc(hidden)]
284    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
285    pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
286        self.last_refill_time_secs.clone()
287    }
288}
289
/// Builder for constructing a `TokenBucket`.
///
/// Any option left unset is replaced by a default value in `build`.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    // Initial and maximum permit count; the setter clamps it to MAXIMUM_CAPACITY.
    capacity: Option<usize>,
    // Permits needed to retry errors other than `ErrorKind::TransientError`.
    retry_cost: Option<u32>,
    // Permits needed to retry an `ErrorKind::TransientError`.
    timeout_retry_cost: Option<u32>,
    // Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    // Tokens per second for time-based refill; the setter stores only finite, non-negative values.
    refill_rate: Option<f32>,
}
299
300impl TokenBucketBuilder {
301    /// Creates a new `TokenBucketBuilder` with default values.
302    pub fn new() -> Self {
303        Self::default()
304    }
305
306    /// Sets the maximum bucket capacity for the builder.
307    pub fn capacity(mut self, mut capacity: usize) -> Self {
308        if capacity > MAXIMUM_CAPACITY {
309            capacity = MAXIMUM_CAPACITY;
310        }
311        self.capacity = Some(capacity);
312        self
313    }
314
315    /// Sets the specified retry cost for the builder.
316    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
317        self.retry_cost = Some(retry_cost);
318        self
319    }
320
321    /// Sets the specified timeout retry cost for the builder.
322    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
323        self.timeout_retry_cost = Some(timeout_retry_cost);
324        self
325    }
326
327    /// Sets the reward for any successful request for the builder.
328    pub fn success_reward(mut self, reward: f32) -> Self {
329        self.success_reward = Some(reward);
330        self
331    }
332
333    /// Sets the refill rate (tokens per second) for time-based token regeneration.
334    ///
335    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
336    /// Non-finite values (NaN, infinity) are treated as 0.0.
337    pub fn refill_rate(mut self, rate: f32) -> Self {
338        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
339        self.refill_rate = Some(validated_rate);
340        self
341    }
342
343    /// Builds a `TokenBucket`.
344    pub fn build(self) -> TokenBucket {
345        TokenBucket {
346            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
347            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
348            retry_cost: self.retry_cost.unwrap_or(DEFAULT_RETRY_COST),
349            timeout_retry_cost: self
350                .timeout_retry_cost
351                .unwrap_or(DEFAULT_RETRY_TIMEOUT_COST),
352            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
353            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
354            refill_rate: self.refill_rate.unwrap_or(0.0),
355            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
356        }
357    }
358}
359
360#[cfg(test)]
361mod tests {
362
363    use super::*;
364    use aws_smithy_async::test_util::ManualTimeSource;
365    use std::{sync::LazyLock, time::UNIX_EPOCH};
366
367    static TIME_SOURCE: LazyLock<ManualTimeSource> =
368        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));
369
370    #[test]
371    fn test_unlimited_token_bucket() {
372        let bucket = TokenBucket::unlimited();
373
374        // Should always acquire permits regardless of error type
375        assert!(bucket
376            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
377            .is_some());
378        assert!(bucket
379            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
380            .is_some());
381
382        // Should have maximum capacity
383        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
384
385        // Should have zero retry costs
386        assert_eq!(bucket.retry_cost, 0);
387        assert_eq!(bucket.timeout_retry_cost, 0);
388
389        // The loop count is arbitrary; should obtain permits without limit
390        let mut permits = Vec::new();
391        for _ in 0..100 {
392            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
393            assert!(permit.is_some());
394            permits.push(permit);
395            // Available permits should stay constant
396            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
397        }
398    }
399
400    #[test]
401    fn test_bounded_permits_exhaustion() {
402        let bucket = TokenBucket::new(10);
403        let mut permits = Vec::new();
404
405        for _ in 0..100 {
406            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
407            if let Some(p) = permit {
408                permits.push(p);
409            } else {
410                break;
411            }
412        }
413
414        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits
415
416        // Verify next acquisition fails
417        assert!(bucket
418            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
419            .is_none());
420    }
421
422    #[test]
423    fn test_fractional_tokens_accumulate_and_convert() {
424        let bucket = TokenBucket::builder()
425            .capacity(10)
426            .success_reward(0.4)
427            .build();
428
429        // acquire 10 tokens to bring capacity below max so we can test accumulation
430        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
431        assert_eq!(bucket.semaphore.available_permits(), 0);
432
433        // First success: 0.4 fractional tokens
434        bucket.reward_success();
435        bucket.convert_fractional_tokens();
436        assert_eq!(bucket.semaphore.available_permits(), 0);
437
438        // Second success: 0.8 fractional tokens
439        bucket.reward_success();
440        bucket.convert_fractional_tokens();
441        assert_eq!(bucket.semaphore.available_permits(), 0);
442
443        // Third success: 1.2 fractional tokens -> 1 full token added
444        bucket.reward_success();
445        bucket.convert_fractional_tokens();
446        assert_eq!(bucket.semaphore.available_permits(), 1);
447    }
448
449    #[test]
450    fn test_fractional_tokens_respect_max_capacity() {
451        let bucket = TokenBucket::builder()
452            .capacity(10)
453            .success_reward(2.0)
454            .build();
455
456        for _ in 0..20 {
457            bucket.reward_success();
458        }
459
460        assert!(bucket.semaphore.available_permits() == 10);
461    }
462
463    #[test]
464    fn test_convert_fractional_tokens() {
465        // (input, expected_permits_added, expected_remaining)
466        let test_cases = [
467            (0.7, 0, 0.7),
468            (1.0, 1, 0.0),
469            (2.3, 2, 0.3),
470            (5.8, 5, 0.8),
471            (10.0, 10, 0.0),
472            // verify that if fractional permits are corrupted, we reset to 0 gracefully
473            (f32::NAN, 0, 0.0),
474            (f32::INFINITY, 0, 0.0),
475        ];
476
477        for (input, expected_permits, expected_remaining) in test_cases {
478            let bucket = TokenBucket::builder().capacity(10).build();
479            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
480            let initial = bucket.semaphore.available_permits();
481
482            bucket.fractional_tokens.store(input);
483            bucket.convert_fractional_tokens();
484
485            assert_eq!(
486                bucket.semaphore.available_permits() - initial,
487                expected_permits
488            );
489            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
490        }
491    }
492
493    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
494    #[test]
495    fn test_builder_with_custom_values() {
496        let bucket = TokenBucket::builder()
497            .capacity(100)
498            .retry_cost(10)
499            .timeout_retry_cost(20)
500            .success_reward(0.5)
501            .refill_rate(2.5)
502            .build();
503
504        assert_eq!(bucket.max_permits, 100);
505        assert_eq!(bucket.retry_cost, 10);
506        assert_eq!(bucket.timeout_retry_cost, 20);
507        assert_eq!(bucket.success_reward, 0.5);
508        assert_eq!(bucket.refill_rate, 2.5);
509    }
510
511    #[test]
512    fn test_builder_refill_rate_validation() {
513        // Test negative values are clamped to 0.0
514        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
515        assert_eq!(bucket.refill_rate, 0.0);
516
517        // Test valid positive value
518        let bucket = TokenBucket::builder().refill_rate(1.5).build();
519        assert_eq!(bucket.refill_rate, 1.5);
520
521        // Test zero is valid
522        let bucket = TokenBucket::builder().refill_rate(0.0).build();
523        assert_eq!(bucket.refill_rate, 0.0);
524    }
525
526    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
527    #[test]
528    fn test_builder_custom_time_source() {
529        use aws_smithy_async::test_util::ManualTimeSource;
530        use std::time::UNIX_EPOCH;
531
532        // Test that TokenBucket uses provided TimeSource when specified via builder
533        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
534        let bucket = TokenBucket::builder()
535            .capacity(100)
536            .refill_rate(1.0)
537            .build();
538
539        // Consume all tokens to test refill from empty state
540        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
541        assert_eq!(bucket.available_permits(), 0);
542
543        // Advance time and verify tokens are added based on manual time
544        manual_time.advance(Duration::from_secs(5));
545
546        bucket.refill_tokens_based_on_time(&manual_time);
547        bucket.convert_fractional_tokens();
548
549        // Should have 5 tokens (5 seconds * 1 token/sec)
550        assert_eq!(bucket.available_permits(), 5);
551    }
552
553    #[test]
554    fn test_atomicf32_f32_to_bits_conversion_correctness() {
555        // This is the core functionality
556        let test_values = vec![
557            0.0,
558            -0.0,
559            1.0,
560            -1.0,
561            f32::INFINITY,
562            f32::NEG_INFINITY,
563            f32::NAN,
564            f32::MIN,
565            f32::MAX,
566            f32::MIN_POSITIVE,
567            f32::EPSILON,
568            std::f32::consts::PI,
569            std::f32::consts::E,
570            // Test values that could expose bit manipulation bugs
571            1.23456789e-38, // Very small normal number
572            1.23456789e38,  // Very large number (within f32 range)
573            1.1754944e-38,  // Near MIN_POSITIVE for f32
574        ];
575
576        for &expected in &test_values {
577            let atomic = AtomicF32::new(expected);
578            let actual = atomic.load();
579
580            // For NaN, we can't use == but must check bit patterns
581            if expected.is_nan() {
582                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
583                // Different NaN bit patterns should be preserved exactly
584                assert_eq!(expected.to_bits(), actual.to_bits());
585            } else {
586                assert_eq!(expected.to_bits(), actual.to_bits());
587            }
588        }
589    }
590
591    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
592    #[test]
593    fn test_atomicf32_store_load_preserves_exact_bits() {
594        let atomic = AtomicF32::new(0.0);
595
596        // Test that store/load cycle preserves EXACT bit patterns
597        // This would catch bugs in the to_bits/from_bits conversion
598        let critical_bit_patterns = vec![
599            0x00000000u32, // +0.0
600            0x80000000u32, // -0.0
601            0x7F800000u32, // +infinity
602            0xFF800000u32, // -infinity
603            0x7FC00000u32, // Quiet NaN
604            0x7FA00000u32, // Signaling NaN
605            0x00000001u32, // Smallest positive subnormal
606            0x007FFFFFu32, // Largest subnormal
607            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
608        ];
609
610        for &expected_bits in &critical_bit_patterns {
611            let expected_f32 = f32::from_bits(expected_bits);
612            atomic.store(expected_f32);
613            let loaded_f32 = atomic.load();
614            let actual_bits = loaded_f32.to_bits();
615
616            assert_eq!(expected_bits, actual_bits);
617        }
618    }
619
620    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
621    #[test]
622    fn test_atomicf32_concurrent_store_load_safety() {
623        use std::sync::Arc;
624        use std::thread;
625
626        let atomic = Arc::new(AtomicF32::new(0.0));
627        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
628        let mut handles = Vec::new();
629
630        // Start multiple threads that continuously write different values
631        for &value in &test_values {
632            let atomic_clone = Arc::clone(&atomic);
633            let handle = thread::spawn(move || {
634                for _ in 0..1000 {
635                    atomic_clone.store(value);
636                }
637            });
638            handles.push(handle);
639        }
640
641        // Start a reader thread that continuously reads
642        let atomic_reader = Arc::clone(&atomic);
643        let reader_handle = thread::spawn(move || {
644            let mut readings = Vec::new();
645            for _ in 0..5000 {
646                let value = atomic_reader.load();
647                readings.push(value);
648            }
649            readings
650        });
651
652        // Wait for all writers to complete
653        for handle in handles {
654            handle.join().expect("Writer thread panicked");
655        }
656
657        let readings = reader_handle.join().expect("Reader thread panicked");
658
659        // Verify that all read values are valid (one of the written values)
660        // This tests that there's no data corruption from concurrent access
661        for &reading in &readings {
662            assert!(test_values.contains(&reading) || reading == 0.0);
663
664            // More importantly, verify the reading is a valid f32
665            // (not corrupted bits that happen to parse as valid)
666            assert!(
667                reading.is_finite() || reading == 0.0,
668                "Corrupted reading detected"
669            );
670        }
671    }
672
673    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
674    #[test]
675    fn test_atomicf32_stress_concurrent_access() {
676        use std::sync::{Arc, Barrier};
677        use std::thread;
678
679        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
680        let atomic = Arc::new(AtomicF32::new(0.0));
681        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
682        let mut handles = Vec::new();
683
684        // Launch threads that all start simultaneously
685        for i in 0..10 {
686            let atomic_clone = Arc::clone(&atomic);
687            let barrier_clone = Arc::clone(&barrier);
688            let handle = thread::spawn(move || {
689                barrier_clone.wait(); // All threads start at same time
690
691                // Tight loop increases chance of race conditions
692                for _ in 0..10000 {
693                    let value = i as f32;
694                    atomic_clone.store(value);
695                    let loaded = atomic_clone.load();
696                    // Verify no corruption occurred
697                    assert!(loaded >= 0.0 && loaded <= 9.0);
698                    assert!(
699                        expected_values.contains(&loaded),
700                        "Got unexpected value: {}, expected one of {:?}",
701                        loaded,
702                        expected_values
703                    );
704                }
705            });
706            handles.push(handle);
707        }
708
709        for handle in handles {
710            handle.join().unwrap();
711        }
712    }
713
714    #[test]
715    fn test_atomicf32_integration_with_token_bucket_usage() {
716        let atomic = AtomicF32::new(0.0);
717        let success_reward = 0.3;
718        let iterations = 5;
719
720        // Accumulate fractional tokens
721        for _ in 1..=iterations {
722            let current = atomic.load();
723            atomic.store(current + success_reward);
724        }
725
726        let accumulated = atomic.load();
727        let expected_total = iterations as f32 * success_reward; // 1.5
728
729        // Test the floor() operation pattern
730        let full_tokens = accumulated.floor();
731        atomic.store(accumulated - full_tokens);
732        let remaining = atomic.load();
733
734        // These assertions should be general:
735        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
736        assert!(remaining >= 0.0 && remaining < 1.0);
737        assert_eq!(remaining, expected_total - expected_total.floor());
738    }
739
740    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
741    #[test]
742    fn test_atomicf32_clone_creates_independent_copy() {
743        let original = AtomicF32::new(123.456);
744        let cloned = original.clone();
745
746        // Verify they start with the same value
747        assert_eq!(original.load(), cloned.load());
748
749        // Verify they're independent - modifying one doesn't affect the other
750        original.store(999.0);
751        assert_eq!(
752            cloned.load(),
753            123.456,
754            "Clone should be unaffected by original changes"
755        );
756        assert_eq!(original.load(), 999.0, "Original should have new value");
757    }
758
759    #[test]
760    fn test_combined_time_and_success_rewards() {
761        use aws_smithy_async::test_util::ManualTimeSource;
762        use std::time::UNIX_EPOCH;
763
764        let time_source = ManualTimeSource::new(UNIX_EPOCH);
765        let current_time_secs = UNIX_EPOCH
766            .duration_since(SystemTime::UNIX_EPOCH)
767            .unwrap()
768            .as_secs() as u32;
769
770        let bucket = TokenBucket {
771            refill_rate: 1.0,
772            success_reward: 0.5,
773            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
774            semaphore: Arc::new(Semaphore::new(0)),
775            max_permits: 100,
776            ..Default::default()
777        };
778
779        // Add success rewards: 2 * 0.5 = 1.0 token
780        bucket.reward_success();
781        bucket.reward_success();
782
783        // Advance time by 2 seconds
784        time_source.advance(Duration::from_secs(2));
785
786        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
787        // Total: 1.0 + 2.0 = 3.0 tokens
788        bucket.refill_tokens_based_on_time(&time_source);
789        bucket.convert_fractional_tokens();
790
791        assert_eq!(bucket.available_permits(), 3);
792        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
793    }
794
795    #[test]
796    fn test_refill_rates() {
797        use aws_smithy_async::test_util::ManualTimeSource;
798        use std::time::UNIX_EPOCH;
799        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
800        let test_cases = [
801            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
802            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
803            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
804            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
805            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
806            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
807        ];
808
809        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
810            let time_source = ManualTimeSource::new(UNIX_EPOCH);
811            let current_time_secs = UNIX_EPOCH
812                .duration_since(SystemTime::UNIX_EPOCH)
813                .unwrap()
814                .as_secs() as u32;
815
816            let bucket = TokenBucket {
817                refill_rate,
818                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
819                semaphore: Arc::new(Semaphore::new(0)),
820                max_permits: 100,
821                ..Default::default()
822            };
823
824            // Advance time by the specified duration
825            time_source.advance(Duration::from_secs(elapsed_secs));
826
827            bucket.refill_tokens_based_on_time(&time_source);
828            bucket.convert_fractional_tokens();
829
830            assert_eq!(
831                bucket.available_permits(),
832                expected_permits,
833                "Rate {}: After {}s expected {} permits",
834                refill_rate,
835                elapsed_secs,
836                expected_permits
837            );
838            assert!(
839                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
840                "Rate {}: After {}s expected {} fractional, got {}",
841                refill_rate,
842                elapsed_secs,
843                expected_fractional,
844                bucket.fractional_tokens.load()
845            );
846        }
847    }
848
849    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
850    #[test]
851    fn test_rewards_capped_at_max_capacity() {
852        use aws_smithy_async::test_util::ManualTimeSource;
853        use std::time::UNIX_EPOCH;
854
855        let time_source = ManualTimeSource::new(UNIX_EPOCH);
856        let current_time_secs = UNIX_EPOCH
857            .duration_since(SystemTime::UNIX_EPOCH)
858            .unwrap()
859            .as_secs() as u32;
860
861        let bucket = TokenBucket {
862            refill_rate: 50.0,
863            success_reward: 2.0,
864            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
865            semaphore: Arc::new(Semaphore::new(5)),
866            max_permits: 10,
867            ..Default::default()
868        };
869
870        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
871        for _ in 0..50 {
872            bucket.reward_success();
873        }
874
875        // Fractional tokens capped at 10 from success rewards
876        assert_eq!(bucket.fractional_tokens.load(), 10.0);
877
878        // Advance time by 100 seconds
879        time_source.advance(Duration::from_secs(100));
880
881        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
882        // But fractional is already at 10, so it stays at 10
883        bucket.refill_tokens_based_on_time(&time_source);
884
885        // Fractional tokens should be capped at max_permits (10)
886        assert_eq!(
887            bucket.fractional_tokens.load(),
888            10.0,
889            "Fractional tokens should be capped at max_permits"
890        );
891        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
892        bucket.convert_fractional_tokens();
893        assert_eq!(bucket.available_permits(), 10);
894    }
895
896    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
897    #[test]
898    fn test_concurrent_time_based_refill_no_over_generation() {
899        use aws_smithy_async::test_util::ManualTimeSource;
900        use std::sync::{Arc, Barrier};
901        use std::thread;
902        use std::time::UNIX_EPOCH;
903
904        let time_source = ManualTimeSource::new(UNIX_EPOCH);
905        let current_time_secs = UNIX_EPOCH
906            .duration_since(SystemTime::UNIX_EPOCH)
907            .unwrap()
908            .as_secs() as u32;
909
910        // Create bucket with 1 token/sec refill
911        let bucket = Arc::new(TokenBucket {
912            refill_rate: 1.0,
913            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
914            semaphore: Arc::new(Semaphore::new(0)),
915            max_permits: 100,
916            ..Default::default()
917        });
918
919        // Advance time by 10 seconds
920        time_source.advance(Duration::from_secs(10));
921        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
922
923        // Launch 100 threads that all try to refill simultaneously
924        let barrier = Arc::new(Barrier::new(100));
925        let mut handles = Vec::new();
926
927        for _ in 0..100 {
928            let bucket_clone1 = Arc::clone(&bucket);
929            let barrier_clone1 = Arc::clone(&barrier);
930            let time_source_clone1 = shared_time_source.clone();
931            let bucket_clone2 = Arc::clone(&bucket);
932            let barrier_clone2 = Arc::clone(&barrier);
933            let time_source_clone2 = shared_time_source.clone();
934
935            let handle1 = thread::spawn(move || {
936                // Wait for all threads to be ready
937                barrier_clone1.wait();
938
939                // All threads call refill at the same time
940                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
941            });
942
943            let handle2 = thread::spawn(move || {
944                // Wait for all threads to be ready
945                barrier_clone2.wait();
946
947                // All threads call refill at the same time
948                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
949            });
950            handles.push(handle1);
951            handles.push(handle2);
952        }
953
954        // Wait for all threads to complete
955        for handle in handles {
956            handle.join().unwrap();
957        }
958
959        // Convert fractional tokens to whole tokens
960        bucket.convert_fractional_tokens();
961
962        // Should have exactly 10 tokens (10 seconds * 1 token/sec)
963        // Not 1000 tokens (100 threads * 10 tokens each)
964        assert_eq!(
965            bucket.available_permits(),
966            10,
967            "Only one thread should have added tokens, not all 100"
968        );
969
970        // Fractional should be 0 after conversion
971        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
972    }
973}