aws_config/meta/credentials/
chain.rs1use aws_credential_types::{
7 provider::{self, error::CredentialsError, future, ProvideCredentials},
8 Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use std::fmt::Debug;
13use tracing::Instrument;
14
15pub struct CredentialsProviderChain {
36 providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl Debug for CredentialsProviderChain {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("CredentialsProviderChain")
42 .field(
43 "providers",
44 &self
45 .providers
46 .iter()
47 .map(|provider| &provider.0)
48 .collect::<Vec<&Cow<'static, str>>>(),
49 )
50 .finish()
51 }
52}
53
54impl CredentialsProviderChain {
55 pub fn first_try(
57 name: impl Into<Cow<'static, str>>,
58 provider: impl ProvideCredentials + 'static,
59 ) -> Self {
60 CredentialsProviderChain {
61 providers: vec![(name.into(), Box::new(provider))],
62 }
63 }
64
65 pub fn or_else(
67 mut self,
68 name: impl Into<Cow<'static, str>>,
69 provider: impl ProvideCredentials + 'static,
70 ) -> Self {
71 self.providers.push((name.into(), Box::new(provider)));
72 self
73 }
74
75 #[cfg(any(feature = "default-https-client", feature = "rustls"))]
77 pub async fn or_default_provider(self) -> Self {
78 self.or_else(
79 "DefaultProviderChain",
80 crate::default_provider::credentials::default_provider().await,
81 )
82 }
83
84 #[cfg(any(feature = "default-https-client", feature = "rustls"))]
86 pub async fn default_provider() -> Self {
87 Self::first_try(
88 "DefaultProviderChain",
89 crate::default_provider::credentials::default_provider().await,
90 )
91 }
92
93 async fn credentials(&self) -> provider::Result {
94 for (name, provider) in &self.providers {
95 let span = tracing::debug_span!("credentials_provider_chain", provider = %name);
96 match provider.provide_credentials().instrument(span).await {
97 Ok(credentials) => {
98 tracing::debug!(provider = %name, "loaded credentials");
99 return Ok(credentials);
100 }
101 Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
102 tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
103 }
104 Err(err) => {
105 tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
106 return Err(err);
107 }
108 }
109 }
110 Err(CredentialsError::not_loaded(
111 "no providers in chain provided credentials",
112 ))
113 }
114}
115
116impl ProvideCredentials for CredentialsProviderChain {
117 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
118 where
119 Self: 'a,
120 {
121 future::ProvideCredentials::new(self.credentials())
122 }
123
124 fn fallback_on_interrupt(&self) -> Option<Credentials> {
125 for (_, provider) in &self.providers {
126 if let creds @ Some(_) = provider.fallback_on_interrupt() {
127 return creds;
128 }
129 }
130 None
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// How long each test provider sleeps before answering.
    const PROVIDER_DELAY: Duration = Duration::from_millis(200);

    /// Test provider that answers with its credentials only after
    /// `PROVIDER_DELAY`, but hands them out immediately from
    /// `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(PROVIDER_DELAY).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds the chain shared by the tests below:
    /// - `provider1` sleeps for `PROVIDER_DELAY`, then reports "not loaded";
    /// - `provider2` is a `FallbackCredentials` holding test credentials.
    fn two_provider_chain() -> CredentialsProviderChain {
        CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(PROVIDER_DELAY).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Interrupts `provide_credentials` once `deadline` elapses. It then asserts
    /// that `fallback_on_interrupt` yields the same credentials an uninterrupted
    /// run produces.
    async fn assert_fallback_after_interrupt(deadline: Duration) {
        let chain = two_provider_chain();
        let expected = chain.provide_credentials().await.unwrap();

        let timeout = Timeout::new(chain.provide_credentials(), tokio::time::sleep(deadline));
        match timeout.await {
            Ok(_) => panic!("provide_credentials completed before timeout future"),
            Err(_err) => match chain.fallback_on_interrupt() {
                Some(actual) => assert_eq!(actual, expected),
                None => panic!(
                    "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
                ),
            },
        };
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // provider1 gives up at ~200ms and provider2 would answer at ~400ms,
        // so a 300ms deadline interrupts provider2 mid-flight.
        assert_fallback_after_interrupt(Duration::from_millis(300)).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // A 100ms deadline fires while provider1 is still sleeping.
        assert_fallback_after_interrupt(Duration::from_millis(100)).await;
    }
}