aws_config/meta/credentials/
chain.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_credential_types::{
7    provider::{self, error::CredentialsError, future, ProvideCredentials},
8    Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use std::fmt::Debug;
13use tracing::Instrument;
14
15/// Credentials provider that checks a series of inner providers
16///
17/// Each provider will be evaluated in order:
18/// * If a provider returns valid [`Credentials`] they will be returned immediately.
19///   No other credential providers will be used.
20/// * Otherwise, if a provider returns [`CredentialsError::CredentialsNotLoaded`], the next provider will be checked.
21/// * Finally, if a provider returns any other error condition, an error will be returned immediately.
22///
23/// # Examples
24///
25/// ```no_run
26/// # fn example() {
27/// use aws_config::meta::credentials::CredentialsProviderChain;
28/// use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
29/// use aws_config::profile::ProfileFileCredentialsProvider;
30///
31/// let provider = CredentialsProviderChain::first_try("Environment", EnvironmentVariableCredentialsProvider::new())
32///     .or_else("Profile", ProfileFileCredentialsProvider::builder().build());
33/// # }
34/// ```
35pub struct CredentialsProviderChain {
36    providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl Debug for CredentialsProviderChain {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("CredentialsProviderChain")
42            .field(
43                "providers",
44                &self
45                    .providers
46                    .iter()
47                    .map(|provider| &provider.0)
48                    .collect::<Vec<&Cow<'static, str>>>(),
49            )
50            .finish()
51    }
52}
53
54impl CredentialsProviderChain {
55    /// Create a `CredentialsProviderChain` that begins by evaluating this provider
56    pub fn first_try(
57        name: impl Into<Cow<'static, str>>,
58        provider: impl ProvideCredentials + 'static,
59    ) -> Self {
60        CredentialsProviderChain {
61            providers: vec![(name.into(), Box::new(provider))],
62        }
63    }
64
65    /// Add a fallback provider to the credentials provider chain
66    pub fn or_else(
67        mut self,
68        name: impl Into<Cow<'static, str>>,
69        provider: impl ProvideCredentials + 'static,
70    ) -> Self {
71        self.providers.push((name.into(), Box::new(provider)));
72        self
73    }
74
75    /// Add a fallback to the default provider chain
76    #[cfg(any(feature = "default-https-client", feature = "rustls"))]
77    pub async fn or_default_provider(self) -> Self {
78        self.or_else(
79            "DefaultProviderChain",
80            crate::default_provider::credentials::default_provider().await,
81        )
82    }
83
84    /// Creates a credential provider chain that starts with the default provider
85    #[cfg(any(feature = "default-https-client", feature = "rustls"))]
86    pub async fn default_provider() -> Self {
87        Self::first_try(
88            "DefaultProviderChain",
89            crate::default_provider::credentials::default_provider().await,
90        )
91    }
92
93    async fn credentials(&self) -> provider::Result {
94        for (name, provider) in &self.providers {
95            let span = tracing::debug_span!("credentials_provider_chain", provider = %name);
96            match provider.provide_credentials().instrument(span).await {
97                Ok(credentials) => {
98                    tracing::debug!(provider = %name, "loaded credentials");
99                    return Ok(credentials);
100                }
101                Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
102                    tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
103                }
104                Err(err) => {
105                    tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
106                    return Err(err);
107                }
108            }
109        }
110        Err(CredentialsError::not_loaded(
111            "no providers in chain provided credentials",
112        ))
113    }
114}
115
116impl ProvideCredentials for CredentialsProviderChain {
117    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
118    where
119        Self: 'a,
120    {
121        future::ProvideCredentials::new(self.credentials())
122    }
123
124    fn fallback_on_interrupt(&self) -> Option<Credentials> {
125        for (_, provider) in &self.providers {
126            if let creds @ Some(_) = provider.fallback_on_interrupt() {
127                return creds;
128            }
129        }
130        None
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider that takes 200ms to hand out its credentials, but offers
    /// the same credentials immediately through `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds a two-provider chain: "provider1" spends 200ms before reporting
    /// that nothing was loaded, and "provider2" is a `FallbackCredentials`.
    fn slow_chain() -> CredentialsProviderChain {
        let provider1 = provide_credentials_fn(|| async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            Err(CredentialsError::not_loaded(
                "no providers in chain provided credentials",
            ))
        });
        CredentialsProviderChain::first_try("provider1", provider1)
            .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Interrupts a second `provide_credentials` call after `interrupt_after`
    /// and checks that `fallback_on_interrupt` yields the same credentials an
    /// uninterrupted call produced.
    async fn assert_fallback_matches_after_interrupt(interrupt_after: Duration) {
        let chain = slow_chain();

        // Let the first call to `provide_credentials` succeed.
        let expected = chain.provide_credentials().await.unwrap();

        // Let the second call fail with an external timeout.
        let timeout = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(interrupt_after),
        );
        if timeout.await.is_ok() {
            panic!("provide_credentials completed before timeout future");
        }

        let actual = chain.fallback_on_interrupt().unwrap_or_else(|| {
            panic!(
                "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
            )
        });
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // 300ms lands after provider1's 200ms delay, while provider2 is still sleeping.
        assert_fallback_matches_after_interrupt(Duration::from_millis(300)).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // 100ms lands while provider1 is still inside its 200ms delay.
        assert_fallback_matches_after_interrupt(Duration::from_millis(100)).await;
    }
}