aws_smithy_runtime/client/retries/
token_bucket.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_async::time::TimeSource;
7use aws_smithy_types::config_bag::{Storable, StoreReplace};
8use aws_smithy_types::retry::ErrorKind;
9use std::fmt;
10use std::sync::atomic::AtomicU32;
11use std::sync::atomic::Ordering;
12use std::sync::Arc;
13use std::time::{Duration, SystemTime};
14use tokio::sync::{OwnedSemaphorePermit, Semaphore};
15
/// Number of permits a token bucket holds when no capacity is configured.
pub(crate) const DEFAULT_CAPACITY: usize = 500;

// `Semaphore::MAX_PERMITS` is only 536,870,911 on 32-bit architectures, so the cap below is
// deliberately smaller than that value to keep behavior identical on every platform. The
// leftover headroom also tolerates a slight overfill when the bucket is already at maximum
// capacity and another thread drops a permit it was holding.
/// The maximum number of permits a token bucket can have.
pub const MAXIMUM_CAPACITY: usize = 500_000_000;

// Standard (non-legacy) retry costs. NOTE(review): none of these are referenced in this file,
// which is why each carries `allow(dead_code)`.
/// Standard retry cost.
#[allow(dead_code)]
pub(crate) const DEFAULT_RETRY_COST: u32 = 14;
/// Standard retry cost for timeouts.
#[allow(dead_code)]
pub(crate) const DEFAULT_RETRY_TIMEOUT_COST: u32 = 14;
/// Standard retry cost for throttling errors.
#[allow(dead_code)]
pub(crate) const THROTTLING_RETRY_COST: u32 = 5;

// Legacy (Retry 2.0) costs. These are the fallbacks used by `TokenBucket::default()` and
// `TokenBucketBuilder::build()`.
/// Legacy cost of retrying a non-timeout error.
const LEGACY_RETRY_COST: u32 = 5;
/// Legacy cost of retrying a timeout: twice the ordinary legacy cost.
const LEGACY_RETRY_TIMEOUT_COST: u32 = LEGACY_RETRY_COST * 2;
/// Number of permits handed back by `TokenBucket::regenerate_a_token`.
const PERMIT_REGENERATION_AMOUNT: usize = 1;
/// Success reward used when none is configured; a reward of `0.0` disables success rewards.
const DEFAULT_SUCCESS_REWARD: f32 = 0.0;
36
37/// Token bucket used for standard and adaptive retry.
38#[derive(Clone, Debug)]
39pub struct TokenBucket {
40    semaphore: Arc<Semaphore>,
41    max_permits: usize,
42    timeout_retry_cost: u32,
43    retry_cost: u32,
44    throttling_retry_cost: u32,
45    success_reward: f32,
46    fractional_tokens: Arc<AtomicF32>,
47    refill_rate: f32,
48    // Note this value is only an AtomicU32 so it works on 32bit powerpc architectures.
49    // If we ever remove the need for that compatibility it should become an AtomicU64
50    last_refill_time_secs: Arc<AtomicU32>,
51}
52
/// An `f32` that can be read and written from several threads at once.
///
/// The value is kept as its IEEE-754 bit pattern inside an [`AtomicU32`], so every bit
/// (including NaN payloads and the sign of zero) is preserved.
struct AtomicF32 {
    storage: AtomicU32,
}

// Explicit unwind-safety markers, presumably so types holding an `AtomicF32` stay unwind-safe.
// NOTE(review): `AtomicU32` already implements both traits, so these impls are probably
// redundant; they are kept to make the guarantee explicit.
impl std::panic::UnwindSafe for AtomicF32 {}
impl std::panic::RefUnwindSafe for AtomicF32 {}
58impl AtomicF32 {
59    fn new(value: f32) -> Self {
60        let as_u32 = value.to_bits();
61        Self {
62            storage: AtomicU32::new(as_u32),
63        }
64    }
65    fn store(&self, value: f32) {
66        let as_u32 = value.to_bits();
67        self.storage.store(as_u32, Ordering::Relaxed)
68    }
69    fn load(&self) -> f32 {
70        let as_u32 = self.storage.load(Ordering::Relaxed);
71        f32::from_bits(as_u32)
72    }
73}
74
75impl fmt::Debug for AtomicF32 {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        // Use debug_struct, debug_tuple, or write! for formatting
78        f.debug_struct("AtomicF32")
79            .field("value", &self.load())
80            .finish()
81    }
82}
83
84impl Clone for AtomicF32 {
85    fn clone(&self) -> Self {
86        // Manually clone each field
87        AtomicF32 {
88            storage: AtomicU32::new(self.storage.load(Ordering::Relaxed)),
89        }
90    }
91}
92
93impl Storable for TokenBucket {
94    type Storer = StoreReplace<Self>;
95}
96
97impl Default for TokenBucket {
98    fn default() -> Self {
99        Self {
100            semaphore: Arc::new(Semaphore::new(DEFAULT_CAPACITY)),
101            max_permits: DEFAULT_CAPACITY,
102            timeout_retry_cost: LEGACY_RETRY_TIMEOUT_COST,
103            retry_cost: LEGACY_RETRY_COST,
104            throttling_retry_cost: LEGACY_RETRY_COST,
105            success_reward: DEFAULT_SUCCESS_REWARD,
106            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
107            refill_rate: 0.0,
108            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
109        }
110    }
111}
112
113impl TokenBucket {
114    /// Creates a new `TokenBucket` with the given initial quota.
115    pub fn new(initial_quota: usize) -> Self {
116        Self {
117            semaphore: Arc::new(Semaphore::new(initial_quota)),
118            max_permits: initial_quota,
119            ..Default::default()
120        }
121    }
122
123    /// A token bucket with unlimited capacity that allows retries at no cost.
124    pub fn unlimited() -> Self {
125        Self {
126            semaphore: Arc::new(Semaphore::new(MAXIMUM_CAPACITY)),
127            max_permits: MAXIMUM_CAPACITY,
128            timeout_retry_cost: 0,
129            retry_cost: 0,
130            throttling_retry_cost: 0,
131            success_reward: 0.0,
132            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
133            refill_rate: 0.0,
134            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
135        }
136    }
137
138    /// Creates a builder for constructing a `TokenBucket`.
139    pub fn builder() -> TokenBucketBuilder {
140        TokenBucketBuilder::default()
141    }
142
143    pub(crate) fn acquire(
144        &self,
145        err: &ErrorKind,
146        time_source: &impl TimeSource,
147    ) -> Option<OwnedSemaphorePermit> {
148        // Add time-based tokens to fractional accumulator
149        self.refill_tokens_based_on_time(time_source);
150        // Convert accumulated fractional tokens to whole tokens
151        self.convert_fractional_tokens();
152
153        let retry_cost = match err {
154            ErrorKind::TransientError => self.timeout_retry_cost,
155            ErrorKind::ThrottlingError => self.throttling_retry_cost,
156            _ => self.retry_cost,
157        };
158
159        self.semaphore
160            .clone()
161            .try_acquire_many_owned(retry_cost)
162            .ok()
163    }
164
165    pub(crate) fn success_reward(&self) -> f32 {
166        self.success_reward
167    }
168
169    pub(crate) fn regenerate_a_token(&self) {
170        self.add_permits(PERMIT_REGENERATION_AMOUNT);
171    }
172
173    /// Converts accumulated fractional tokens to whole tokens and adds them as permits.
174    /// Stores the remaining fractional amount back.
175    /// This is shared by both time-based refill and success rewards.
176    #[inline]
177    fn convert_fractional_tokens(&self) {
178        let mut calc_fractional_tokens = self.fractional_tokens.load();
179        // Verify that fractional tokens have not become corrupted - if they have, reset to zero
180        if !calc_fractional_tokens.is_finite() {
181            tracing::error!(
182                "Fractional tokens corrupted to: {}, resetting to 0.0",
183                calc_fractional_tokens
184            );
185            self.fractional_tokens.store(0.0);
186            return;
187        }
188
189        let full_tokens_accumulated = calc_fractional_tokens.floor();
190        if full_tokens_accumulated >= 1.0 {
191            self.add_permits(full_tokens_accumulated as usize);
192            calc_fractional_tokens -= full_tokens_accumulated;
193        }
194        // Always store the updated fractional tokens back, even if no conversion happened
195        self.fractional_tokens.store(calc_fractional_tokens);
196    }
197
198    /// Refills tokens based on elapsed time since last refill.
199    /// This method implements lazy evaluation - tokens are only calculated when accessed.
200    /// Uses a single compare-and-swap to ensure only one thread processes each time window.
201    #[inline]
202    fn refill_tokens_based_on_time(&self, time_source: &impl TimeSource) {
203        if self.refill_rate > 0.0 {
204            // The cast to u32 here is safe until 2106, and I will be long dead then so ¯\_(ツ)_/¯
205            let current_time_secs = time_source
206                .now()
207                .duration_since(SystemTime::UNIX_EPOCH)
208                .unwrap_or(Duration::ZERO)
209                .as_secs() as u32;
210
211            let last_refill_secs = self.last_refill_time_secs.load(Ordering::Relaxed);
212
213            // Early exit if no time elapsed - most threads take this path
214            if current_time_secs == last_refill_secs {
215                return;
216            }
217
218            // Try to atomically claim this time window with a single CAS
219            // If we lose, another thread is handling the refill, so we can exit
220            if self
221                .last_refill_time_secs
222                .compare_exchange(
223                    last_refill_secs,
224                    current_time_secs,
225                    Ordering::Relaxed,
226                    Ordering::Relaxed,
227                )
228                .is_err()
229            {
230                // Another thread claimed this time window, we're done
231                return;
232            }
233
234            // We won the CAS - we're responsible for adding tokens for this time window
235            let current_fractional = self.fractional_tokens.load();
236            let max_fractional = self.max_permits as f32;
237
238            // Skip token addition if already at cap
239            if current_fractional >= max_fractional {
240                return;
241            }
242
243            let elapsed_secs = current_time_secs.saturating_sub(last_refill_secs);
244            let tokens_to_add = elapsed_secs as f32 * self.refill_rate;
245
246            // Add tokens to fractional accumulator, capping at max_permits to prevent unbounded growth
247            let new_fractional = (current_fractional + tokens_to_add).min(max_fractional);
248            self.fractional_tokens.store(new_fractional);
249        }
250    }
251
252    #[inline]
253    pub(crate) fn reward_success(&self) {
254        if self.success_reward > 0.0 {
255            let current = self.fractional_tokens.load();
256            let max_fractional = self.max_permits as f32;
257            // Early exit if already at cap - no point calculating
258            if current >= max_fractional {
259                return;
260            }
261            // Cap fractional tokens at max_permits to prevent unbounded growth
262            let new_fractional = (current + self.success_reward).min(max_fractional);
263            self.fractional_tokens.store(new_fractional);
264        }
265    }
266
267    pub(crate) fn add_permits(&self, amount: usize) {
268        let available = self.semaphore.available_permits();
269        if available >= self.max_permits {
270            return;
271        }
272        self.semaphore
273            .add_permits(amount.min(self.max_permits - available));
274    }
275
276    /// Returns true if the token bucket is full, false otherwise
277    pub fn is_full(&self) -> bool {
278        self.convert_fractional_tokens();
279        self.semaphore.available_permits() >= self.max_permits
280    }
281
282    /// Returns true if the token bucket is empty, false otherwise
283    pub fn is_empty(&self) -> bool {
284        self.convert_fractional_tokens();
285        self.semaphore.available_permits() == 0
286    }
287
288    #[allow(dead_code)] // only used in tests
289    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
290    pub(crate) fn available_permits(&self) -> usize {
291        self.semaphore.available_permits()
292    }
293
294    /// Only used in tests
295    #[allow(dead_code)]
296    #[doc(hidden)]
297    #[cfg(any(test, feature = "test-util", feature = "legacy-test-util"))]
298    pub fn last_refill_time_secs(&self) -> Arc<AtomicU32> {
299        self.last_refill_time_secs.clone()
300    }
301}
302
/// Builder for constructing a `TokenBucket`.
///
/// Every setting is optional; anything left unset falls back to a default when
/// [`TokenBucketBuilder::build`] is called.
#[derive(Clone, Debug, Default)]
pub struct TokenBucketBuilder {
    /// Capacity of the bucket (also its initial number of permits).
    capacity: Option<usize>,
    /// Cost of retrying errors that are neither transient nor throttling.
    retry_cost: Option<u32>,
    /// Cost of retrying throttling errors.
    throttling_retry_cost: Option<u32>,
    /// Cost of retrying transient (timeout) errors.
    timeout_retry_cost: Option<u32>,
    /// Fractional tokens credited per successful request.
    success_reward: Option<f32>,
    /// Tokens regenerated per second.
    refill_rate: Option<f32>,
}
313
314impl TokenBucketBuilder {
315    /// Creates a new `TokenBucketBuilder` with default values.
316    pub fn new() -> Self {
317        Self::default()
318    }
319
320    /// Sets the maximum bucket capacity for the builder.
321    pub fn capacity(mut self, mut capacity: usize) -> Self {
322        if capacity > MAXIMUM_CAPACITY {
323            capacity = MAXIMUM_CAPACITY;
324        }
325        self.capacity = Some(capacity);
326        self
327    }
328
329    /// Sets the specified retry cost for the builder.
330    pub fn retry_cost(mut self, retry_cost: u32) -> Self {
331        self.retry_cost = Some(retry_cost);
332        self
333    }
334
335    /// Sets the throttling retry cost for the builder.
336    pub fn throttling_retry_cost(mut self, throttling_retry_cost: u32) -> Self {
337        self.throttling_retry_cost = Some(throttling_retry_cost);
338        self
339    }
340
341    /// Sets the specified timeout retry cost for the builder.
342    pub fn timeout_retry_cost(mut self, timeout_retry_cost: u32) -> Self {
343        self.timeout_retry_cost = Some(timeout_retry_cost);
344        self
345    }
346
347    /// Sets the reward for any successful request for the builder.
348    pub fn success_reward(mut self, reward: f32) -> Self {
349        self.success_reward = Some(reward);
350        self
351    }
352
353    /// Sets the refill rate (tokens per second) for time-based token regeneration.
354    ///
355    /// Negative values are clamped to 0.0. A refill rate of 0.0 disables time-based regeneration.
356    /// Non-finite values (NaN, infinity) are treated as 0.0.
357    pub fn refill_rate(mut self, rate: f32) -> Self {
358        let validated_rate = if rate.is_finite() { rate.max(0.0) } else { 0.0 };
359        self.refill_rate = Some(validated_rate);
360        self
361    }
362
363    /// Builds a `TokenBucket`.
364    pub fn build(self) -> TokenBucket {
365        TokenBucket {
366            semaphore: Arc::new(Semaphore::new(self.capacity.unwrap_or(DEFAULT_CAPACITY))),
367            max_permits: self.capacity.unwrap_or(DEFAULT_CAPACITY),
368            retry_cost: self.retry_cost.unwrap_or(LEGACY_RETRY_COST),
369            throttling_retry_cost: self.throttling_retry_cost.unwrap_or(LEGACY_RETRY_COST),
370            timeout_retry_cost: self.timeout_retry_cost.unwrap_or(LEGACY_RETRY_TIMEOUT_COST),
371            success_reward: self.success_reward.unwrap_or(DEFAULT_SUCCESS_REWARD),
372            fractional_tokens: Arc::new(AtomicF32::new(0.0)),
373            refill_rate: self.refill_rate.unwrap_or(0.0),
374            last_refill_time_secs: Arc::new(AtomicU32::new(0)),
375        }
376    }
377}
378
379#[cfg(test)]
380mod tests {
381
382    use super::*;
383    use aws_smithy_async::test_util::ManualTimeSource;
384    use std::{sync::LazyLock, time::UNIX_EPOCH};
385
386    static TIME_SOURCE: LazyLock<ManualTimeSource> =
387        LazyLock::new(|| ManualTimeSource::new(UNIX_EPOCH + Duration::from_secs(12344321)));
388
389    #[test]
390    fn test_unlimited_token_bucket() {
391        let bucket = TokenBucket::unlimited();
392
393        // Should always acquire permits regardless of error type
394        assert!(bucket
395            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
396            .is_some());
397        assert!(bucket
398            .acquire(&ErrorKind::TransientError, &*TIME_SOURCE)
399            .is_some());
400
401        // Should have maximum capacity
402        assert_eq!(bucket.max_permits, MAXIMUM_CAPACITY);
403
404        // Should have zero retry costs
405        assert_eq!(bucket.retry_cost, 0);
406        assert_eq!(bucket.timeout_retry_cost, 0);
407
408        // The loop count is arbitrary; should obtain permits without limit
409        let mut permits = Vec::new();
410        for _ in 0..100 {
411            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
412            assert!(permit.is_some());
413            permits.push(permit);
414            // Available permits should stay constant
415            assert_eq!(MAXIMUM_CAPACITY, bucket.semaphore.available_permits());
416        }
417    }
418
419    #[test]
420    fn test_bounded_permits_exhaustion() {
421        let bucket = TokenBucket::new(10);
422        let mut permits = Vec::new();
423
424        for _ in 0..100 {
425            let permit = bucket.acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE);
426            if let Some(p) = permit {
427                permits.push(p);
428            } else {
429                break;
430            }
431        }
432
433        assert_eq!(permits.len(), 2); // 10 capacity / 5 retry cost = 2 permits
434
435        // Verify next acquisition fails
436        assert!(bucket
437            .acquire(&ErrorKind::ThrottlingError, &*TIME_SOURCE)
438            .is_none());
439    }
440
441    #[test]
442    fn test_fractional_tokens_accumulate_and_convert() {
443        let bucket = TokenBucket::builder()
444            .capacity(10)
445            .success_reward(0.4)
446            .build();
447
448        // acquire 10 tokens to bring capacity below max so we can test accumulation
449        let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
450        assert_eq!(bucket.semaphore.available_permits(), 0);
451
452        // First success: 0.4 fractional tokens
453        bucket.reward_success();
454        bucket.convert_fractional_tokens();
455        assert_eq!(bucket.semaphore.available_permits(), 0);
456
457        // Second success: 0.8 fractional tokens
458        bucket.reward_success();
459        bucket.convert_fractional_tokens();
460        assert_eq!(bucket.semaphore.available_permits(), 0);
461
462        // Third success: 1.2 fractional tokens -> 1 full token added
463        bucket.reward_success();
464        bucket.convert_fractional_tokens();
465        assert_eq!(bucket.semaphore.available_permits(), 1);
466    }
467
468    #[test]
469    fn test_fractional_tokens_respect_max_capacity() {
470        let bucket = TokenBucket::builder()
471            .capacity(10)
472            .success_reward(2.0)
473            .build();
474
475        for _ in 0..20 {
476            bucket.reward_success();
477        }
478
479        assert!(bucket.semaphore.available_permits() == 10);
480    }
481
482    #[test]
483    fn test_convert_fractional_tokens() {
484        // (input, expected_permits_added, expected_remaining)
485        let test_cases = [
486            (0.7, 0, 0.7),
487            (1.0, 1, 0.0),
488            (2.3, 2, 0.3),
489            (5.8, 5, 0.8),
490            (10.0, 10, 0.0),
491            // verify that if fractional permits are corrupted, we reset to 0 gracefully
492            (f32::NAN, 0, 0.0),
493            (f32::INFINITY, 0, 0.0),
494        ];
495
496        for (input, expected_permits, expected_remaining) in test_cases {
497            let bucket = TokenBucket::builder().capacity(10).build();
498            let _hold_permit = bucket.acquire(&ErrorKind::TransientError, &*TIME_SOURCE);
499            let initial = bucket.semaphore.available_permits();
500
501            bucket.fractional_tokens.store(input);
502            bucket.convert_fractional_tokens();
503
504            assert_eq!(
505                bucket.semaphore.available_permits() - initial,
506                expected_permits
507            );
508            assert!((bucket.fractional_tokens.load() - expected_remaining).abs() < 0.0001);
509        }
510    }
511
512    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
513    #[test]
514    fn test_builder_with_custom_values() {
515        let bucket = TokenBucket::builder()
516            .capacity(100)
517            .retry_cost(10)
518            .timeout_retry_cost(20)
519            .success_reward(0.5)
520            .refill_rate(2.5)
521            .build();
522
523        assert_eq!(bucket.max_permits, 100);
524        assert_eq!(bucket.retry_cost, 10);
525        assert_eq!(bucket.timeout_retry_cost, 20);
526        assert_eq!(bucket.success_reward, 0.5);
527        assert_eq!(bucket.refill_rate, 2.5);
528    }
529
530    #[test]
531    fn test_builder_refill_rate_validation() {
532        // Test negative values are clamped to 0.0
533        let bucket = TokenBucket::builder().refill_rate(-5.0).build();
534        assert_eq!(bucket.refill_rate, 0.0);
535
536        // Test valid positive value
537        let bucket = TokenBucket::builder().refill_rate(1.5).build();
538        assert_eq!(bucket.refill_rate, 1.5);
539
540        // Test zero is valid
541        let bucket = TokenBucket::builder().refill_rate(0.0).build();
542        assert_eq!(bucket.refill_rate, 0.0);
543    }
544
545    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
546    #[test]
547    fn test_builder_custom_time_source() {
548        use aws_smithy_async::test_util::ManualTimeSource;
549        use std::time::UNIX_EPOCH;
550
551        // Test that TokenBucket uses provided TimeSource when specified via builder
552        let manual_time = ManualTimeSource::new(UNIX_EPOCH);
553        let bucket = TokenBucket::builder()
554            .capacity(100)
555            .refill_rate(1.0)
556            .build();
557
558        // Consume all tokens to test refill from empty state
559        let _permits = bucket.semaphore.try_acquire_many(100).unwrap();
560        assert_eq!(bucket.available_permits(), 0);
561
562        // Advance time and verify tokens are added based on manual time
563        manual_time.advance(Duration::from_secs(5));
564
565        bucket.refill_tokens_based_on_time(&manual_time);
566        bucket.convert_fractional_tokens();
567
568        // Should have 5 tokens (5 seconds * 1 token/sec)
569        assert_eq!(bucket.available_permits(), 5);
570    }
571
572    #[test]
573    fn test_atomicf32_f32_to_bits_conversion_correctness() {
574        // This is the core functionality
575        let test_values = vec![
576            0.0,
577            -0.0,
578            1.0,
579            -1.0,
580            f32::INFINITY,
581            f32::NEG_INFINITY,
582            f32::NAN,
583            f32::MIN,
584            f32::MAX,
585            f32::MIN_POSITIVE,
586            f32::EPSILON,
587            std::f32::consts::PI,
588            std::f32::consts::E,
589            // Test values that could expose bit manipulation bugs
590            1.23456789e-38, // Very small normal number
591            1.23456789e38,  // Very large number (within f32 range)
592            1.1754944e-38,  // Near MIN_POSITIVE for f32
593        ];
594
595        for &expected in &test_values {
596            let atomic = AtomicF32::new(expected);
597            let actual = atomic.load();
598
599            // For NaN, we can't use == but must check bit patterns
600            if expected.is_nan() {
601                assert!(actual.is_nan(), "Expected NaN, got {}", actual);
602                // Different NaN bit patterns should be preserved exactly
603                assert_eq!(expected.to_bits(), actual.to_bits());
604            } else {
605                assert_eq!(expected.to_bits(), actual.to_bits());
606            }
607        }
608    }
609
610    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
611    #[test]
612    fn test_atomicf32_store_load_preserves_exact_bits() {
613        let atomic = AtomicF32::new(0.0);
614
615        // Test that store/load cycle preserves EXACT bit patterns
616        // This would catch bugs in the to_bits/from_bits conversion
617        let critical_bit_patterns = vec![
618            0x00000000u32, // +0.0
619            0x80000000u32, // -0.0
620            0x7F800000u32, // +infinity
621            0xFF800000u32, // -infinity
622            0x7FC00000u32, // Quiet NaN
623            0x7FA00000u32, // Signaling NaN
624            0x00000001u32, // Smallest positive subnormal
625            0x007FFFFFu32, // Largest subnormal
626            0x00800000u32, // Smallest positive normal (MIN_POSITIVE)
627        ];
628
629        for &expected_bits in &critical_bit_patterns {
630            let expected_f32 = f32::from_bits(expected_bits);
631            atomic.store(expected_f32);
632            let loaded_f32 = atomic.load();
633            let actual_bits = loaded_f32.to_bits();
634
635            assert_eq!(expected_bits, actual_bits);
636        }
637    }
638
639    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
640    #[test]
641    fn test_atomicf32_concurrent_store_load_safety() {
642        use std::sync::Arc;
643        use std::thread;
644
645        let atomic = Arc::new(AtomicF32::new(0.0));
646        let test_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
647        let mut handles = Vec::new();
648
649        // Start multiple threads that continuously write different values
650        for &value in &test_values {
651            let atomic_clone = Arc::clone(&atomic);
652            let handle = thread::spawn(move || {
653                for _ in 0..1000 {
654                    atomic_clone.store(value);
655                }
656            });
657            handles.push(handle);
658        }
659
660        // Start a reader thread that continuously reads
661        let atomic_reader = Arc::clone(&atomic);
662        let reader_handle = thread::spawn(move || {
663            let mut readings = Vec::new();
664            for _ in 0..5000 {
665                let value = atomic_reader.load();
666                readings.push(value);
667            }
668            readings
669        });
670
671        // Wait for all writers to complete
672        for handle in handles {
673            handle.join().expect("Writer thread panicked");
674        }
675
676        let readings = reader_handle.join().expect("Reader thread panicked");
677
678        // Verify that all read values are valid (one of the written values)
679        // This tests that there's no data corruption from concurrent access
680        for &reading in &readings {
681            assert!(test_values.contains(&reading) || reading == 0.0);
682
683            // More importantly, verify the reading is a valid f32
684            // (not corrupted bits that happen to parse as valid)
685            assert!(
686                reading.is_finite() || reading == 0.0,
687                "Corrupted reading detected"
688            );
689        }
690    }
691
692    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
693    #[test]
694    fn test_atomicf32_stress_concurrent_access() {
695        use std::sync::{Arc, Barrier};
696        use std::thread;
697
698        let expected_values = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
699        let atomic = Arc::new(AtomicF32::new(0.0));
700        let barrier = Arc::new(Barrier::new(10)); // Synchronize all threads
701        let mut handles = Vec::new();
702
703        // Launch threads that all start simultaneously
704        for i in 0..10 {
705            let atomic_clone = Arc::clone(&atomic);
706            let barrier_clone = Arc::clone(&barrier);
707            let handle = thread::spawn(move || {
708                barrier_clone.wait(); // All threads start at same time
709
710                // Tight loop increases chance of race conditions
711                for _ in 0..10000 {
712                    let value = i as f32;
713                    atomic_clone.store(value);
714                    let loaded = atomic_clone.load();
715                    // Verify no corruption occurred
716                    assert!(loaded >= 0.0 && loaded <= 9.0);
717                    assert!(
718                        expected_values.contains(&loaded),
719                        "Got unexpected value: {}, expected one of {:?}",
720                        loaded,
721                        expected_values
722                    );
723                }
724            });
725            handles.push(handle);
726        }
727
728        for handle in handles {
729            handle.join().unwrap();
730        }
731    }
732
733    #[test]
734    fn test_atomicf32_integration_with_token_bucket_usage() {
735        let atomic = AtomicF32::new(0.0);
736        let success_reward = 0.3;
737        let iterations = 5;
738
739        // Accumulate fractional tokens
740        for _ in 1..=iterations {
741            let current = atomic.load();
742            atomic.store(current + success_reward);
743        }
744
745        let accumulated = atomic.load();
746        let expected_total = iterations as f32 * success_reward; // 1.5
747
748        // Test the floor() operation pattern
749        let full_tokens = accumulated.floor();
750        atomic.store(accumulated - full_tokens);
751        let remaining = atomic.load();
752
753        // These assertions should be general:
754        assert_eq!(full_tokens, expected_total.floor()); // Could be 1.0, 2.0, 3.0, etc.
755        assert!(remaining >= 0.0 && remaining < 1.0);
756        assert_eq!(remaining, expected_total - expected_total.floor());
757    }
758
759    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
760    #[test]
761    fn test_atomicf32_clone_creates_independent_copy() {
762        let original = AtomicF32::new(123.456);
763        let cloned = original.clone();
764
765        // Verify they start with the same value
766        assert_eq!(original.load(), cloned.load());
767
768        // Verify they're independent - modifying one doesn't affect the other
769        original.store(999.0);
770        assert_eq!(
771            cloned.load(),
772            123.456,
773            "Clone should be unaffected by original changes"
774        );
775        assert_eq!(original.load(), 999.0, "Original should have new value");
776    }
777
778    #[test]
779    fn test_combined_time_and_success_rewards() {
780        use aws_smithy_async::test_util::ManualTimeSource;
781        use std::time::UNIX_EPOCH;
782
783        let time_source = ManualTimeSource::new(UNIX_EPOCH);
784        let current_time_secs = UNIX_EPOCH
785            .duration_since(SystemTime::UNIX_EPOCH)
786            .unwrap()
787            .as_secs() as u32;
788
789        let bucket = TokenBucket {
790            refill_rate: 1.0,
791            success_reward: 0.5,
792            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
793            semaphore: Arc::new(Semaphore::new(0)),
794            max_permits: 100,
795            ..Default::default()
796        };
797
798        // Add success rewards: 2 * 0.5 = 1.0 token
799        bucket.reward_success();
800        bucket.reward_success();
801
802        // Advance time by 2 seconds
803        time_source.advance(Duration::from_secs(2));
804
805        // Trigger time-based refill: 2 sec * 1.0 = 2.0 tokens
806        // Total: 1.0 + 2.0 = 3.0 tokens
807        bucket.refill_tokens_based_on_time(&time_source);
808        bucket.convert_fractional_tokens();
809
810        assert_eq!(bucket.available_permits(), 3);
811        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
812    }
813
814    #[test]
815    fn test_refill_rates() {
816        use aws_smithy_async::test_util::ManualTimeSource;
817        use std::time::UNIX_EPOCH;
818        // (refill_rate, elapsed_secs, expected_permits, expected_fractional)
819        let test_cases = [
820            (10.0, 2, 20, 0.0),      // Basic: 2 sec * 10 tokens/sec = 20 tokens
821            (0.001, 1100, 1, 0.1),   // Small: 1100 * 0.001 = 1.1 tokens
822            (0.0001, 11000, 1, 0.1), // Tiny: 11000 * 0.0001 = 1.1 tokens
823            (0.001, 1200, 1, 0.2),   // 1200 * 0.001 = 1.2 tokens
824            (0.0001, 10000, 1, 0.0), // 10000 * 0.0001 = 1.0 tokens
825            (0.001, 500, 0, 0.5),    // Fractional only: 500 * 0.001 = 0.5 tokens
826        ];
827
828        for (refill_rate, elapsed_secs, expected_permits, expected_fractional) in test_cases {
829            let time_source = ManualTimeSource::new(UNIX_EPOCH);
830            let current_time_secs = UNIX_EPOCH
831                .duration_since(SystemTime::UNIX_EPOCH)
832                .unwrap()
833                .as_secs() as u32;
834
835            let bucket = TokenBucket {
836                refill_rate,
837                last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
838                semaphore: Arc::new(Semaphore::new(0)),
839                max_permits: 100,
840                ..Default::default()
841            };
842
843            // Advance time by the specified duration
844            time_source.advance(Duration::from_secs(elapsed_secs));
845
846            bucket.refill_tokens_based_on_time(&time_source);
847            bucket.convert_fractional_tokens();
848
849            assert_eq!(
850                bucket.available_permits(),
851                expected_permits,
852                "Rate {}: After {}s expected {} permits",
853                refill_rate,
854                elapsed_secs,
855                expected_permits
856            );
857            assert!(
858                (bucket.fractional_tokens.load() - expected_fractional).abs() < 0.0001,
859                "Rate {}: After {}s expected {} fractional, got {}",
860                refill_rate,
861                elapsed_secs,
862                expected_fractional,
863                bucket.fractional_tokens.load()
864            );
865        }
866    }
867
868    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
869    #[test]
870    fn test_rewards_capped_at_max_capacity() {
871        use aws_smithy_async::test_util::ManualTimeSource;
872        use std::time::UNIX_EPOCH;
873
874        let time_source = ManualTimeSource::new(UNIX_EPOCH);
875        let current_time_secs = UNIX_EPOCH
876            .duration_since(SystemTime::UNIX_EPOCH)
877            .unwrap()
878            .as_secs() as u32;
879
880        let bucket = TokenBucket {
881            refill_rate: 50.0,
882            success_reward: 2.0,
883            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
884            semaphore: Arc::new(Semaphore::new(5)),
885            max_permits: 10,
886            ..Default::default()
887        };
888
889        // Add success rewards: 50 * 2.0 = 100 tokens (without cap)
890        for _ in 0..50 {
891            bucket.reward_success();
892        }
893
894        // Fractional tokens capped at 10 from success rewards
895        assert_eq!(bucket.fractional_tokens.load(), 10.0);
896
897        // Advance time by 100 seconds
898        time_source.advance(Duration::from_secs(100));
899
900        // Time-based refill: 100 * 50 = 5000 tokens (without cap)
901        // But fractional is already at 10, so it stays at 10
902        bucket.refill_tokens_based_on_time(&time_source);
903
904        // Fractional tokens should be capped at max_permits (10)
905        assert_eq!(
906            bucket.fractional_tokens.load(),
907            10.0,
908            "Fractional tokens should be capped at max_permits"
909        );
910        // Convert should add 5 tokens (bucket at 5, can add 5 more to reach max 10)
911        bucket.convert_fractional_tokens();
912        assert_eq!(bucket.available_permits(), 10);
913    }
914
915    #[cfg(any(feature = "test-util", feature = "legacy-test-util"))]
916    #[test]
917    fn test_concurrent_time_based_refill_no_over_generation() {
918        use aws_smithy_async::test_util::ManualTimeSource;
919        use std::sync::{Arc, Barrier};
920        use std::thread;
921        use std::time::UNIX_EPOCH;
922
923        let time_source = ManualTimeSource::new(UNIX_EPOCH);
924        let current_time_secs = UNIX_EPOCH
925            .duration_since(SystemTime::UNIX_EPOCH)
926            .unwrap()
927            .as_secs() as u32;
928
929        // Create bucket with 1 token/sec refill
930        let bucket = Arc::new(TokenBucket {
931            refill_rate: 1.0,
932            last_refill_time_secs: Arc::new(AtomicU32::new(current_time_secs)),
933            semaphore: Arc::new(Semaphore::new(0)),
934            max_permits: 100,
935            ..Default::default()
936        });
937
938        // Advance time by 10 seconds
939        time_source.advance(Duration::from_secs(10));
940        let shared_time_source = aws_smithy_async::time::SharedTimeSource::new(time_source);
941
942        // Launch 100 threads that all try to refill simultaneously
943        let barrier = Arc::new(Barrier::new(100));
944        let mut handles = Vec::new();
945
946        for _ in 0..100 {
947            let bucket_clone1 = Arc::clone(&bucket);
948            let barrier_clone1 = Arc::clone(&barrier);
949            let time_source_clone1 = shared_time_source.clone();
950            let bucket_clone2 = Arc::clone(&bucket);
951            let barrier_clone2 = Arc::clone(&barrier);
952            let time_source_clone2 = shared_time_source.clone();
953
954            let handle1 = thread::spawn(move || {
955                // Wait for all threads to be ready
956                barrier_clone1.wait();
957
958                // All threads call refill at the same time
959                bucket_clone1.refill_tokens_based_on_time(&time_source_clone1);
960            });
961
962            let handle2 = thread::spawn(move || {
963                // Wait for all threads to be ready
964                barrier_clone2.wait();
965
966                // All threads call refill at the same time
967                bucket_clone2.refill_tokens_based_on_time(&time_source_clone2);
968            });
969            handles.push(handle1);
970            handles.push(handle2);
971        }
972
973        // Wait for all threads to complete
974        for handle in handles {
975            handle.join().unwrap();
976        }
977
978        // Convert fractional tokens to whole tokens
979        bucket.convert_fractional_tokens();
980
981        // Should have exactly 10 tokens (10 seconds * 1 token/sec)
982        // Not 1000 tokens (100 threads * 10 tokens each)
983        assert_eq!(
984            bucket.available_permits(),
985            10,
986            "Only one thread should have added tokens, not all 100"
987        );
988
989        // Fractional should be 0 after conversion
990        assert!(bucket.fractional_tokens.load().abs() < 0.0001);
991    }
992
993    /// Regression test for https://github.com/awslabs/aws-sdk-rust/issues/1423
994    #[test]
995    fn test_is_full_accounts_for_fractional_tokens() {
996        let bucket = TokenBucket::builder()
997            .capacity(2)
998            .retry_cost(1)
999            .success_reward(0.9)
1000            .build();
1001
1002        assert!(bucket.is_full());
1003
1004        let _p1 = bucket
1005            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
1006            .unwrap();
1007        let _p2 = bucket
1008            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
1009            .unwrap();
1010
1011        assert!(bucket.is_empty());
1012
1013        // 3 rewards of 0.9 = 2.7 fractional tokens, which converts to 2 whole
1014        // permits — enough to fill the bucket (capacity 2).
1015        bucket.reward_success();
1016        bucket.reward_success();
1017        bucket.reward_success();
1018
1019        // Before the fix, is_full() returned false here because fractional
1020        // tokens hadn't been converted to real permits.
1021        assert!(bucket.is_full());
1022        assert!(!bucket.is_empty());
1023    }
1024
1025    #[test]
1026    fn test_is_empty_accounts_for_fractional_tokens() {
1027        let bucket = TokenBucket::builder()
1028            .capacity(10)
1029            .retry_cost(10)
1030            .success_reward(0.5)
1031            .build();
1032
1033        let _p = bucket
1034            .acquire(&ErrorKind::ServerError, &*TIME_SOURCE)
1035            .unwrap();
1036        assert_eq!(bucket.semaphore.available_permits(), 0);
1037
1038        // 0.5 fractional tokens can't convert to a whole permit
1039        bucket.reward_success();
1040        assert!(bucket.is_empty());
1041
1042        // 1.0 fractional tokens converts to a permit
1043        bucket.reward_success();
1044        assert!(!bucket.is_empty());
1045    }
1046}