aws_smithy_runtime/
expiring_cache.rs1use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::{OnceCell, RwLock};
11
12#[derive(Debug)]
18pub struct ExpiringCache<T, E> {
19 buffer_time: Duration,
22 value: Arc<RwLock<OnceCell<(T, SystemTime)>>>,
23 _phantom: PhantomData<E>,
24}
25
26impl<T, E> Clone for ExpiringCache<T, E> {
27 fn clone(&self) -> Self {
28 Self {
29 buffer_time: self.buffer_time,
30 value: self.value.clone(),
31 _phantom: Default::default(),
32 }
33 }
34}
35
36impl<T, E> ExpiringCache<T, E>
37where
38 T: Clone,
39{
40 pub fn new(buffer_time: Duration) -> Self {
42 ExpiringCache {
43 buffer_time,
44 value: Arc::new(RwLock::new(OnceCell::new())),
45 _phantom: Default::default(),
46 }
47 }
48
49 #[cfg(all(test, feature = "client", feature = "http-auth"))]
50 async fn get(&self) -> Option<T>
51 where
52 T: Clone,
53 {
54 self.value
55 .read()
56 .await
57 .get()
58 .cloned()
59 .map(|(creds, _expiry)| creds)
60 }
61
62 pub async fn get_or_load<F, Fut>(&self, f: F) -> Result<T, E>
68 where
69 F: FnOnce() -> Fut,
70 Fut: Future<Output = Result<(T, SystemTime), E>>,
71 {
72 let lock = self.value.read().await;
73 let future = lock.get_or_try_init(f);
74 future.await.map(|(value, _expiry)| value.clone())
75 }
76
77 pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<T> {
79 if let Some((value, expiry)) = self.value.read().await.get() {
81 if !expired(*expiry, self.buffer_time, now) {
82 return Some(value.clone());
83 } else {
84 tracing::debug!(expiry = ?expiry, delta= ?now.duration_since(*expiry), "An item existed but it expired.")
85 }
86 }
87
88 let mut lock = self.value.write().await;
92 if let Some((_value, expiration)) = lock.get() {
93 if expired(*expiration, self.buffer_time, now) {
96 *lock = OnceCell::new();
97 }
98 }
99 None
100 }
101}
102
/// Returns `true` if an entry expiring at `expiration` should be treated as expired at `now`.
///
/// The entry counts as expired from `buffer_time` before `expiration` onwards, and the boundary
/// itself counts as expired. If `expiration - buffer_time` is not representable as a
/// `SystemTime` (for example, a huge `buffer_time`), the entry is treated as expired instead of
/// panicking.
fn expired(expiration: SystemTime, buffer_time: Duration, now: SystemTime) -> bool {
    match expiration.checked_sub(buffer_time) {
        Some(threshold) => now >= threshold,
        // The buffer window reaches back past the earliest representable time, so every `now`
        // falls inside it.
        None => true,
    }
}
106
#[cfg(all(test, feature = "client", feature = "http-auth"))]
mod tests {
    use super::{expired, ExpiringCache};
    use aws_smithy_runtime_api::box_error::BoxError;
    use aws_smithy_runtime_api::client::identity::http::Token;
    use aws_smithy_runtime_api::client::identity::Identity;
    use std::time::{Duration, SystemTime};
    use tracing_test::traced_test;

    /// Builds a successful loader result whose token identity and stored expiry are both
    /// `expiry_secs` seconds after the Unix epoch.
    fn identity(expiry_secs: u64) -> Result<(Identity, SystemTime), BoxError> {
        let expires_at = epoch_secs(expiry_secs);
        let token = Token::new("test", Some(expires_at));
        Ok((Identity::new(token, Some(expires_at)), expires_at))
    }

    /// The instant `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        let offset = Duration::from_secs(secs);
        SystemTime::UNIX_EPOCH + offset
    }

    #[test]
    fn expired_check() {
        let expiry = epoch_secs(100);
        let buffer = Duration::from_secs(10);
        // Long past the expiry, and exactly at the start of the buffer window: expired.
        assert!(expired(expiry, buffer, epoch_secs(1000)));
        assert!(expired(expiry, buffer, epoch_secs(90)));
        // Well before the buffer window opens: still fresh.
        assert!(!expired(expiry, buffer, epoch_secs(10)));
    }

    #[traced_test]
    #[tokio::test]
    async fn cache_clears_if_expired_only() {
        let cache: ExpiringCache<Identity, BoxError> =
            ExpiringCache::new(Duration::from_secs(10));

        // An empty cache has nothing to yield.
        let initially = cache.yield_or_clear_if_expired(epoch_secs(100)).await;
        assert!(initially.is_none());

        // Load an identity that expires at t=100.
        cache
            .get_or_load(|| async { identity(100) })
            .await
            .unwrap();
        let stored = cache.get().await.unwrap();
        assert_eq!(Some(epoch_secs(100)), stored.expiration());

        // At t=10 the entry is outside the buffer window, so it is yielded and kept.
        let fresh = cache
            .yield_or_clear_if_expired(epoch_secs(10))
            .await
            .unwrap();
        assert_eq!(Some(epoch_secs(100)), fresh.expiration());

        // At t=500 the entry has expired, so it is cleared.
        let stale = cache.yield_or_clear_if_expired(epoch_secs(500)).await;
        assert!(stale.is_none());
        assert!(cache.get().await.is_none());
    }
}