aws_config/
ecs.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Ecs Credentials Provider
7//!
8//! This credential provider is frequently used with an AWS-provided credentials service (e.g.
9//! [IAM Roles for tasks](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html)).
10//! However, it's possible to use environment variables to configure this provider to use your own
11//! credentials sources.
12//!
13//! This provider is part of the [default credentials chain](crate::default_provider::credentials).
14//!
15//! ## Configuration
16//! **First**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`. It will use this
17//! to construct a URI rooted at `http://169.254.170.2`. For example, if the value of the environment
18//! variable was `/credentials`, the SDK would look for credentials at `http://169.254.170.2/credentials`.
19//!
20//! **Next**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_FULL_URI`. This specifies the full
//! URL to load credentials. The URL MUST satisfy one of the following two properties:
22//! 1. The URL begins with `https`
23//! 2. The URL refers to an allowed IP address. If a URL contains a domain name instead of an IP address,
24//!    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP address, or
25//!    the credentials provider will return `CredentialsError::InvalidConfiguration`. Valid IP addresses are:
26//!     a) Loopback interfaces
//!     b) The [ECS Task Metadata V2](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html)
//!        address, i.e. `169.254.170.2`.
//!     c) [EKS Pod Identity](https://docs.aws.amazon.com/eks/latest/userguide/pod-identities.html) addresses,
//!        i.e. `169.254.170.23` or `fd00:ec2::23`.
31//!
32//! **Next**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE`. If this is set,
33//! the filename specified will be read, and the value passed in the `Authorization` header. If the file
34//! cannot be read, an error is returned.
35//!
36//! **Finally**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN`. If this is set, the
37//! value will be passed in the `Authorization` header.
38//!
39//! ## Credentials Format
40//! Credentials MUST be returned in a JSON format:
41//! ```json
42//! {
43//!    "AccessKeyId" : "MUA...",
44//!    "SecretAccessKey" : "/7PC5om....",
45//!    "Token" : "AQoDY....=",
46//!    "Expiration" : "2016-02-25T06:03:31Z"
47//!  }
48//! ```
49//!
50//! Credentials errors MAY be returned with a `code` and `message` field:
51//! ```json
52//! {
53//!   "code": "ErrorCode",
54//!   "message": "Helpful error message."
55//! }
56//! ```
57
58use crate::http_credential_provider::HttpCredentialProvider;
59use crate::provider_config::ProviderConfig;
60use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
61use aws_smithy_http::endpoint::apply_endpoint;
62use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
63use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
64use aws_smithy_runtime_api::shared::IntoShared;
65use aws_smithy_types::error::display::DisplayErrorContext;
66use aws_types::os_shim_internal::{Env, Fs};
67use http::header::InvalidHeaderValue;
68use http::uri::{InvalidUri, PathAndQuery, Scheme};
69use http::{HeaderValue, Uri};
70use std::error::Error;
71use std::fmt::{Display, Formatter};
72use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
73use std::time::Duration;
74use tokio::sync::OnceCell;
75
/// Default read timeout (5 seconds) for the credentials HTTP request.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default connect timeout (2 seconds) for the credentials HTTP request.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(2_000);

// ECS task metadata endpoint; relative URIs are rooted here.
// See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html
const BASE_HOST: &str = "http://169.254.170.2";
/// Path (plus optional query) appended to `BASE_HOST`.
const ENV_RELATIVE_URI: &str = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
/// Complete credentials URL; subject to `validate_full_uri`.
const ENV_FULL_URI: &str = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
/// Literal value for the `Authorization` header.
const ENV_AUTHORIZATION_TOKEN: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
/// Path of a file whose contents are used as the `Authorization` header.
const ENV_AUTHORIZATION_TOKEN_FILE: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE";
85
86/// Credential provider for ECS and generalized HTTP credentials
87///
88/// See the [module](crate::ecs) documentation for more details.
89///
90/// This credential provider is part of the default chain.
91#[derive(Debug)]
92pub struct EcsCredentialsProvider {
93    inner: OnceCell<Provider>,
94    env: Env,
95    fs: Fs,
96    builder: Builder,
97}
98
99impl EcsCredentialsProvider {
100    /// Builder for [`EcsCredentialsProvider`]
101    pub fn builder() -> Builder {
102        Builder::default()
103    }
104
105    /// Load credentials from this credentials provider
106    pub async fn credentials(&self) -> provider::Result {
107        let env_token_file = self.env.get(ENV_AUTHORIZATION_TOKEN_FILE).ok();
108        let env_token = self.env.get(ENV_AUTHORIZATION_TOKEN).ok();
109        let auth = if let Some(auth_token_file) = env_token_file {
110            let auth = self
111                .fs
112                .read_to_end(auth_token_file)
113                .await
114                .map_err(CredentialsError::provider_error)?;
115            Some(HeaderValue::from_bytes(auth.as_slice()).map_err(|err| {
116                let auth_token = String::from_utf8_lossy(auth.as_slice()).to_string();
117                tracing::warn!(token = %auth_token, "invalid auth token");
118                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
119                    err,
120                    value: auth_token,
121                })
122            })?)
123        } else if let Some(auth_token) = env_token {
124            Some(HeaderValue::from_str(&auth_token).map_err(|err| {
125                tracing::warn!(token = %auth_token, "invalid auth token");
126                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
127                    err,
128                    value: auth_token,
129                })
130            })?)
131        } else {
132            None
133        };
134        match self.provider().await {
135            Provider::NotConfigured => {
136                Err(CredentialsError::not_loaded("ECS provider not configured"))
137            }
138            Provider::InvalidConfiguration(err) => {
139                Err(CredentialsError::invalid_configuration(format!("{}", err)))
140            }
141            Provider::Configured(provider) => provider.credentials(auth).await,
142        }
143    }
144
145    async fn provider(&self) -> &Provider {
146        self.inner
147            .get_or_init(|| Provider::make(self.builder.clone()))
148            .await
149    }
150}
151
152impl ProvideCredentials for EcsCredentialsProvider {
153    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
154    where
155        Self: 'a,
156    {
157        future::ProvideCredentials::new(self.credentials())
158    }
159}
160
161/// Inner Provider that can record failed configuration state
162#[derive(Debug)]
163#[allow(clippy::large_enum_variant)]
164enum Provider {
165    Configured(HttpCredentialProvider),
166    NotConfigured,
167    InvalidConfiguration(EcsConfigurationError),
168}
169
170impl Provider {
171    async fn uri(env: Env, dns: Option<SharedDnsResolver>) -> Result<Uri, EcsConfigurationError> {
172        let relative_uri = env.get(ENV_RELATIVE_URI).ok();
173        let full_uri = env.get(ENV_FULL_URI).ok();
174        if let Some(relative_uri) = relative_uri {
175            Self::build_full_uri(relative_uri)
176        } else if let Some(full_uri) = full_uri {
177            let dns = dns.or_else(default_dns);
178            validate_full_uri(&full_uri, dns)
179                .await
180                .map_err(|err| EcsConfigurationError::InvalidFullUri { err, uri: full_uri })
181        } else {
182            Err(EcsConfigurationError::NotConfigured)
183        }
184    }
185
186    async fn make(builder: Builder) -> Self {
187        let provider_config = builder.provider_config.unwrap_or_default();
188        let env = provider_config.env();
189        let uri = match Self::uri(env, builder.dns).await {
190            Ok(uri) => uri,
191            Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
192            Err(err) => return Provider::InvalidConfiguration(err),
193        };
194        let path_and_query = match uri.path_and_query() {
195            Some(path_and_query) => path_and_query.to_string(),
196            None => uri.path().to_string(),
197        };
198        let endpoint = {
199            let mut parts = uri.into_parts();
200            parts.path_and_query = Some(PathAndQuery::from_static("/"));
201            Uri::from_parts(parts)
202        }
203        .expect("parts will be valid")
204        .to_string();
205
206        let http_provider = HttpCredentialProvider::builder()
207            .configure(&provider_config)
208            .http_connector_settings(
209                HttpConnectorSettings::builder()
210                    .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
211                    .read_timeout(DEFAULT_READ_TIMEOUT)
212                    .build(),
213            )
214            .build("EcsContainer", &endpoint, path_and_query);
215        Provider::Configured(http_provider)
216    }
217
218    fn build_full_uri(relative_uri: String) -> Result<Uri, EcsConfigurationError> {
219        let mut relative_uri = match relative_uri.parse::<Uri>() {
220            Ok(uri) => uri,
221            Err(invalid_uri) => {
222                tracing::warn!(uri = %DisplayErrorContext(&invalid_uri), "invalid URI loaded from environment");
223                return Err(EcsConfigurationError::InvalidRelativeUri {
224                    err: invalid_uri,
225                    uri: relative_uri,
226                });
227            }
228        };
229        let endpoint = Uri::from_static(BASE_HOST);
230        apply_endpoint(&mut relative_uri, &endpoint, None)
231            .expect("appending relative URLs to the ECS endpoint should always succeed");
232        Ok(relative_uri)
233    }
234}
235
236#[derive(Debug)]
237enum EcsConfigurationError {
238    InvalidRelativeUri {
239        err: InvalidUri,
240        uri: String,
241    },
242    InvalidFullUri {
243        err: InvalidFullUriError,
244        uri: String,
245    },
246    InvalidAuthToken {
247        err: InvalidHeaderValue,
248        value: String,
249    },
250    NotConfigured,
251}
252
253impl Display for EcsConfigurationError {
254    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
255        match self {
256            EcsConfigurationError::InvalidRelativeUri { err, uri } => write!(
257                f,
258                "invalid relative URI for ECS provider ({}): {}",
259                err, uri
260            ),
261            EcsConfigurationError::InvalidFullUri { err, uri } => {
262                write!(f, "invalid full URI for ECS provider ({}): {}", err, uri)
263            }
264            EcsConfigurationError::NotConfigured => write!(
265                f,
266                "No environment variables were set to configure ECS provider"
267            ),
268            EcsConfigurationError::InvalidAuthToken { err, value } => write!(
269                f,
270                "`{}` could not be used as a header value for the auth token. {}",
271                value, err
272            ),
273        }
274    }
275}
276
277impl Error for EcsConfigurationError {
278    fn source(&self) -> Option<&(dyn Error + 'static)> {
279        match &self {
280            EcsConfigurationError::InvalidRelativeUri { err, .. } => Some(err),
281            EcsConfigurationError::InvalidFullUri { err, .. } => Some(err),
282            EcsConfigurationError::InvalidAuthToken { err, .. } => Some(err),
283            EcsConfigurationError::NotConfigured => None,
284        }
285    }
286}
287
288/// Builder for [`EcsCredentialsProvider`]
289#[derive(Default, Debug, Clone)]
290pub struct Builder {
291    provider_config: Option<ProviderConfig>,
292    dns: Option<SharedDnsResolver>,
293    connect_timeout: Option<Duration>,
294    read_timeout: Option<Duration>,
295}
296
297impl Builder {
298    /// Override the configuration used for this provider
299    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
300        self.provider_config = Some(provider_config.clone());
301        self
302    }
303
304    /// Override the DNS resolver used to validate URIs
305    ///
306    /// URIs must refer to valid IP addresses as defined in the module documentation. The [`ResolveDns`]
307    /// implementation is used to retrieve IP addresses for a given domain.
308    pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
309        self.dns = Some(dns.into_shared());
310        self
311    }
312
313    /// Override the connect timeout for the HTTP client
314    ///
315    /// This value defaults to 2 seconds
316    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
317        self.connect_timeout = Some(timeout);
318        self
319    }
320
321    /// Override the read timeout for the HTTP client
322    ///
323    /// This value defaults to 5 seconds
324    pub fn read_timeout(mut self, timeout: Duration) -> Self {
325        self.read_timeout = Some(timeout);
326        self
327    }
328
329    /// Create an [`EcsCredentialsProvider`] from this builder
330    pub fn build(self) -> EcsCredentialsProvider {
331        let env = self
332            .provider_config
333            .as_ref()
334            .map(|config| config.env())
335            .unwrap_or_default();
336        let fs = self
337            .provider_config
338            .as_ref()
339            .map(|config| config.fs())
340            .unwrap_or_default();
341        EcsCredentialsProvider {
342            inner: OnceCell::new(),
343            env,
344            fs,
345            builder: self,
346        }
347    }
348}
349
350#[derive(Debug)]
351enum InvalidFullUriErrorKind {
352    /// The provided URI could not be parsed as a URI
353    #[non_exhaustive]
354    InvalidUri(InvalidUri),
355
356    /// No Dns resolver was provided
357    #[non_exhaustive]
358    NoDnsResolver,
359
360    /// The URI did not specify a host
361    #[non_exhaustive]
362    MissingHost,
363
364    /// The URI did not refer to an allowed IP address
365    #[non_exhaustive]
366    DisallowedIP,
367
368    /// DNS lookup failed when attempting to resolve the host to an IP Address for validation.
369    DnsLookupFailed(ResolveDnsError),
370}
371
372/// Invalid Full URI
373///
374/// When the full URI setting is used, the URI must either be HTTPS, point to a loopback interface,
375/// or point to known ECS/EKS container IPs.
376#[derive(Debug)]
377pub struct InvalidFullUriError {
378    kind: InvalidFullUriErrorKind,
379}
380
381impl Display for InvalidFullUriError {
382    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
383        use InvalidFullUriErrorKind::*;
384        match self.kind {
385            InvalidUri(_) => write!(f, "URI was invalid"),
386            MissingHost => write!(f, "URI did not specify a host"),
387            DisallowedIP => {
388                write!(f, "URI did not refer to an allowed IP address")
389            }
390            DnsLookupFailed(_) => {
391                write!(
392                    f,
393                    "failed to perform DNS lookup while validating URI"
394                )
395            }
396            NoDnsResolver => write!(f, "no DNS resolver was provided. Enable `rt-tokio` or provide a `dns` resolver to the builder.")
397        }
398    }
399}
400
401impl Error for InvalidFullUriError {
402    fn source(&self) -> Option<&(dyn Error + 'static)> {
403        use InvalidFullUriErrorKind::*;
404        match &self.kind {
405            InvalidUri(err) => Some(err),
406            DnsLookupFailed(err) => Some(err as _),
407            _ => None,
408        }
409    }
410}
411
412impl From<InvalidFullUriErrorKind> for InvalidFullUriError {
413    fn from(kind: InvalidFullUriErrorKind) -> Self {
414        Self { kind }
415    }
416}
417
418/// Validate that `uri` is valid to be used as a full provider URI
419/// Either:
420/// 1. The URL is uses `https`
421/// 2. The URL refers to an allowed IP. If a URL contains a domain name instead of an IP address,
422///    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP, or
423///    the credentials provider will return `CredentialsError::InvalidConfiguration`. Allowed IPs
424///    are the loopback interfaces, and the known ECS/EKS container IPs.
425async fn validate_full_uri(
426    uri: &str,
427    dns: Option<SharedDnsResolver>,
428) -> Result<Uri, InvalidFullUriError> {
429    let uri = uri
430        .parse::<Uri>()
431        .map_err(InvalidFullUriErrorKind::InvalidUri)?;
432    if uri.scheme() == Some(&Scheme::HTTPS) {
433        return Ok(uri);
434    }
435    // For HTTP URIs, we need to validate that it points to a valid IP
436    let host = uri.host().ok_or(InvalidFullUriErrorKind::MissingHost)?;
437    let maybe_ip = if host.starts_with('[') && host.ends_with(']') {
438        host[1..host.len() - 1].parse::<IpAddr>()
439    } else {
440        host.parse::<IpAddr>()
441    };
442    let is_allowed = match maybe_ip {
443        Ok(addr) => is_full_uri_ip_allowed(&addr),
444        Err(_domain_name) => {
445            let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
446            dns.resolve_dns(host)
447                .await
448                .map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
449                .iter()
450                    .all(|addr| {
451                        if !is_full_uri_ip_allowed(addr) {
452                            tracing::warn!(
453                                addr = ?addr,
454                                "HTTP credential provider cannot be used: Address does not resolve to an allowed IP."
455                            )
456                        };
457                        is_full_uri_ip_allowed(addr)
458                    })
459        }
460    };
461    match is_allowed {
462        true => Ok(uri),
463        false => Err(InvalidFullUriErrorKind::DisallowedIP.into()),
464    }
465}
466
/// ECS task metadata endpoint: 169.254.170.2
const ECS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 2));

/// EKS Pod Identity agent (IPv4): 169.254.170.23
const EKS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 23));

/// EKS Pod Identity agent (IPv6): fd00:ec2::23
const EKS_CONTAINER_IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0xfd00, 0xec2, 0, 0, 0, 0, 0, 0x23));

/// Return `true` if `ip` may serve credentials over plain HTTP.
///
/// Any loopback address is allowed, plus exactly the ECS/EKS container addresses above.
fn is_full_uri_ip_allowed(ip: &IpAddr) -> bool {
    const CONTAINER_IPS: [IpAddr; 3] = [ECS_CONTAINER_IPV4, EKS_CONTAINER_IPV4, EKS_CONTAINER_IPV6];
    if ip.is_loopback() {
        return true;
    }
    CONTAINER_IPS.contains(ip)
}
481
482/// Default DNS resolver impl
483///
484/// DNS resolution is required to validate that provided URIs point to a valid IP address
485#[cfg(any(not(feature = "rt-tokio"), target_family = "wasm"))]
486fn default_dns() -> Option<SharedDnsResolver> {
487    None
488}
489#[cfg(all(feature = "rt-tokio", not(target_family = "wasm")))]
490fn default_dns() -> Option<SharedDnsResolver> {
491    use aws_smithy_runtime::client::dns::TokioDnsResolver;
492    Some(TokioDnsResolver::new().into_shared())
493}
494
495#[cfg(test)]
496mod test {
497    use super::*;
498    use crate::provider_config::ProviderConfig;
499    use crate::test_case::{no_traffic_client, GenericTestResult};
500    use aws_credential_types::provider::ProvideCredentials;
501    use aws_credential_types::Credentials;
502    use aws_smithy_async::future::never::Never;
503    use aws_smithy_async::rt::sleep::TokioSleep;
504    use aws_smithy_http_client::test_util::{ReplayEvent, StaticReplayClient};
505    use aws_smithy_runtime_api::client::dns::DnsFuture;
506    use aws_smithy_runtime_api::client::http::HttpClient;
507    use aws_smithy_runtime_api::shared::IntoShared;
508    use aws_smithy_types::body::SdkBody;
509    use aws_types::os_shim_internal::Env;
510    use futures_util::FutureExt;
511    use http::header::AUTHORIZATION;
512    use http::Uri;
513    use serde::Deserialize;
514    use std::collections::HashMap;
515    use std::error::Error;
516    use std::ffi::OsString;
517    use std::net::IpAddr;
518    use std::time::{Duration, UNIX_EPOCH};
519    use tracing_test::traced_test;
520
521    fn provider(
522        env: Env,
523        fs: Fs,
524        http_client: impl HttpClient + 'static,
525    ) -> EcsCredentialsProvider {
526        let provider_config = ProviderConfig::empty()
527            .with_env(env)
528            .with_fs(fs)
529            .with_http_client(http_client)
530            .with_sleep_impl(TokioSleep::new());
531        Builder::default().configure(&provider_config).build()
532    }
533
534    #[derive(Deserialize)]
535    struct EcsUriTest {
536        env: HashMap<String, String>,
537        result: GenericTestResult<String>,
538    }
539
540    impl EcsUriTest {
541        async fn check(&self) {
542            let env = Env::from(self.env.clone());
543            let uri = Provider::uri(env, Some(TestDns::default().into_shared()))
544                .await
545                .map(|uri| uri.to_string());
546            self.result.assert_matches(uri.as_ref());
547        }
548    }
549
550    #[tokio::test]
551    async fn run_config_tests() -> Result<(), Box<dyn Error>> {
552        let test_cases = std::fs::read_to_string("test-data/ecs-tests.json")?;
553        #[derive(Deserialize)]
554        struct TestCases {
555            tests: Vec<EcsUriTest>,
556        }
557
558        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
559        let test_cases = test_cases.tests;
560        for test in test_cases {
561            test.check().await
562        }
563        Ok(())
564    }
565
566    #[test]
567    fn validate_uri_https() {
568        // over HTTPs, any URI is fine
569        let dns = Some(NeverDns.into_shared());
570        assert_eq!(
571            validate_full_uri("https://amazon.com", None)
572                .now_or_never()
573                .unwrap()
574                .expect("valid"),
575            Uri::from_static("https://amazon.com")
576        );
577        // over HTTP, it will try to lookup
578        assert!(
579            validate_full_uri("http://amazon.com", dns)
580                .now_or_never()
581                .is_none(),
582            "DNS lookup should occur, but it will never return"
583        );
584
585        let no_dns_error = validate_full_uri("http://amazon.com", None)
586            .now_or_never()
587            .unwrap()
588            .expect_err("DNS service is required");
589        assert!(
590            matches!(
591                no_dns_error,
592                InvalidFullUriError {
593                    kind: InvalidFullUriErrorKind::NoDnsResolver
594                }
595            ),
596            "expected no dns service, got: {}",
597            no_dns_error
598        );
599    }
600
601    #[test]
602    fn valid_uri_loopback() {
603        assert_eq!(
604            validate_full_uri("http://127.0.0.1:8080/get-credentials", None)
605                .now_or_never()
606                .unwrap()
607                .expect("valid uri"),
608            Uri::from_static("http://127.0.0.1:8080/get-credentials")
609        );
610
611        let err = validate_full_uri("http://192.168.10.120/creds", None)
612            .now_or_never()
613            .unwrap()
614            .expect_err("not a loopback");
615        assert!(matches!(
616            err,
617            InvalidFullUriError {
618                kind: InvalidFullUriErrorKind::DisallowedIP
619            }
620        ));
621    }
622
623    #[test]
624    fn valid_uri_ecs_eks() {
625        assert_eq!(
626            validate_full_uri("http://169.254.170.2:8080/get-credentials", None)
627                .now_or_never()
628                .unwrap()
629                .expect("valid uri"),
630            Uri::from_static("http://169.254.170.2:8080/get-credentials")
631        );
632        assert_eq!(
633            validate_full_uri("http://169.254.170.23:8080/get-credentials", None)
634                .now_or_never()
635                .unwrap()
636                .expect("valid uri"),
637            Uri::from_static("http://169.254.170.23:8080/get-credentials")
638        );
639        assert_eq!(
640            validate_full_uri("http://[fd00:ec2::23]:8080/get-credentials", None)
641                .now_or_never()
642                .unwrap()
643                .expect("valid uri"),
644            Uri::from_static("http://[fd00:ec2::23]:8080/get-credentials")
645        );
646
647        let err = validate_full_uri("http://169.254.171.23/creds", None)
648            .now_or_never()
649            .unwrap()
650            .expect_err("not an ecs/eks container address");
651        assert!(matches!(
652            err,
653            InvalidFullUriError {
654                kind: InvalidFullUriErrorKind::DisallowedIP
655            }
656        ));
657
658        let err = validate_full_uri("http://[fd00:ec2::2]/creds", None)
659            .now_or_never()
660            .unwrap()
661            .expect_err("not an ecs/eks container address");
662        assert!(matches!(
663            err,
664            InvalidFullUriError {
665                kind: InvalidFullUriErrorKind::DisallowedIP
666            }
667        ));
668    }
669
670    #[test]
671    fn all_addrs_local() {
672        let dns = Some(
673            TestDns::with_fallback(vec![
674                "127.0.0.1".parse().unwrap(),
675                "127.0.0.2".parse().unwrap(),
676                "169.254.170.23".parse().unwrap(),
677                "fd00:ec2::23".parse().unwrap(),
678            ])
679            .into_shared(),
680        );
681        let resp = validate_full_uri("http://localhost:8888", dns)
682            .now_or_never()
683            .unwrap();
684        assert!(resp.is_ok(), "Should be valid: {:?}", resp);
685    }
686
687    #[test]
688    fn all_addrs_not_local() {
689        let dns = Some(
690            TestDns::with_fallback(vec![
691                "127.0.0.1".parse().unwrap(),
692                "192.168.0.1".parse().unwrap(),
693            ])
694            .into_shared(),
695        );
696        let resp = validate_full_uri("http://localhost:8888", dns)
697            .now_or_never()
698            .unwrap();
699        assert!(
700            matches!(
701                resp,
702                Err(InvalidFullUriError {
703                    kind: InvalidFullUriErrorKind::DisallowedIP
704                })
705            ),
706            "Should be invalid: {:?}",
707            resp
708        );
709    }
710
711    fn creds_request(uri: &str, auth: Option<&str>) -> http::Request<SdkBody> {
712        let mut builder = http::Request::builder();
713        if let Some(auth) = auth {
714            builder = builder.header(AUTHORIZATION, auth);
715        }
716        builder.uri(uri).body(SdkBody::empty()).unwrap()
717    }
718
719    fn ok_creds_response() -> http::Response<SdkBody> {
720        http::Response::builder()
721            .status(200)
722            .body(SdkBody::from(
723                r#" {
724                       "AccessKeyId" : "AKID",
725                       "SecretAccessKey" : "SECRET",
726                       "Token" : "TOKEN....=",
727                       "AccountId" : "AID",
728                       "Expiration" : "2009-02-13T23:31:30Z"
729                     }"#,
730            ))
731            .unwrap()
732    }
733
734    #[track_caller]
735    fn assert_correct(creds: Credentials) {
736        assert_eq!(creds.access_key_id(), "AKID");
737        assert_eq!(creds.secret_access_key(), "SECRET");
738        assert_eq!(creds.account_id().unwrap().as_str(), "AID");
739        assert_eq!(creds.session_token().unwrap(), "TOKEN....=");
740        assert_eq!(
741            creds.expiry().unwrap(),
742            UNIX_EPOCH + Duration::from_secs(1234567890)
743        );
744    }
745
746    #[tokio::test]
747    async fn load_valid_creds_auth() {
748        let env = Env::from_slice(&[
749            ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials"),
750            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "Basic password"),
751        ]);
752        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
753            creds_request("http://169.254.170.2/credentials", Some("Basic password")),
754            ok_creds_response(),
755        )]);
756        let provider = provider(env, Fs::default(), http_client.clone());
757        let creds = provider
758            .provide_credentials()
759            .await
760            .expect("valid credentials");
761        assert_correct(creds);
762        http_client.assert_requests_match(&[]);
763    }
764
765    #[tokio::test]
766    async fn load_valid_creds_auth_file() {
767        let env = Env::from_slice(&[
768            (
769                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
770                "http://169.254.170.23/v1/credentials",
771            ),
772            (
773                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
774                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
775            ),
776        ]);
777        let fs = Fs::from_raw_map(HashMap::from([(
778            OsString::from(
779                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
780            ),
781            "Basic password".into(),
782        )]));
783
784        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
785            creds_request(
786                "http://169.254.170.23/v1/credentials",
787                Some("Basic password"),
788            ),
789            ok_creds_response(),
790        )]);
791        let provider = provider(env, fs, http_client.clone());
792        let creds = provider
793            .provide_credentials()
794            .await
795            .expect("valid credentials");
796        assert_correct(creds);
797        http_client.assert_requests_match(&[]);
798    }
799
800    #[tokio::test]
801    async fn auth_file_precedence_over_env() {
802        let env = Env::from_slice(&[
803            (
804                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
805                "http://169.254.170.23/v1/credentials",
806            ),
807            (
808                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
809                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
810            ),
811            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
812        ]);
813        let fs = Fs::from_raw_map(HashMap::from([(
814            OsString::from(
815                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
816            ),
817            "Basic password".into(),
818        )]));
819
820        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
821            creds_request(
822                "http://169.254.170.23/v1/credentials",
823                Some("Basic password"),
824            ),
825            ok_creds_response(),
826        )]);
827        let provider = provider(env, fs, http_client.clone());
828        let creds = provider
829            .provide_credentials()
830            .await
831            .expect("valid credentials");
832        assert_correct(creds);
833        http_client.assert_requests_match(&[]);
834    }
835
836    #[tokio::test]
837    async fn query_params_should_be_included_in_credentials_http_request() {
838        let env = Env::from_slice(&[
839            (
840                "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
841                "/my-credentials/?applicationName=test2024",
842            ),
843            (
844                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
845                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
846            ),
847            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
848        ]);
849        let fs = Fs::from_raw_map(HashMap::from([(
850            OsString::from(
851                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
852            ),
853            "Basic password".into(),
854        )]));
855
856        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
857            creds_request(
858                "http://169.254.170.2/my-credentials/?applicationName=test2024",
859                Some("Basic password"),
860            ),
861            ok_creds_response(),
862        )]);
863        let provider = provider(env, fs, http_client.clone());
864        let creds = provider
865            .provide_credentials()
866            .await
867            .expect("valid credentials");
868        assert_correct(creds);
869        http_client.assert_requests_match(&[]);
870    }
871
872    #[tokio::test]
873    async fn fs_missing_file() {
874        let env = Env::from_slice(&[
875            (
876                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
877                "http://169.254.170.23/v1/credentials",
878            ),
879            (
880                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
881                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
882            ),
883        ]);
884        let fs = Fs::from_raw_map(HashMap::new());
885
886        let provider = provider(env, fs, no_traffic_client());
887        let err = provider.credentials().await.expect_err("no JWT token file");
888        match err {
889            CredentialsError::ProviderError { .. } => { /* ok */ }
890            _ => panic!("incorrect error variant"),
891        }
892    }
893
894    #[tokio::test]
895    async fn retry_5xx() {
896        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
897        let http_client = StaticReplayClient::new(vec![
898            ReplayEvent::new(
899                creds_request("http://169.254.170.2/credentials", None),
900                http::Response::builder()
901                    .status(500)
902                    .body(SdkBody::empty())
903                    .unwrap(),
904            ),
905            ReplayEvent::new(
906                creds_request("http://169.254.170.2/credentials", None),
907                ok_creds_response(),
908            ),
909        ]);
910        tokio::time::pause();
911        let provider = provider(env, Fs::default(), http_client.clone());
912        let creds = provider
913            .provide_credentials()
914            .await
915            .expect("valid credentials");
916        assert_correct(creds);
917    }
918
919    #[tokio::test]
920    async fn load_valid_creds_no_auth() {
921        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
922        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
923            creds_request("http://169.254.170.2/credentials", None),
924            ok_creds_response(),
925        )]);
926        let provider = provider(env, Fs::default(), http_client.clone());
927        let creds = provider
928            .provide_credentials()
929            .await
930            .expect("valid credentials");
931        assert_correct(creds);
932        http_client.assert_requests_match(&[]);
933    }
934
935    // ignored by default because it relies on actual DNS resolution
936    #[allow(unused_attributes)]
937    #[tokio::test]
938    #[traced_test]
939    #[ignore]
940    async fn real_dns_lookup() {
941        let dns = Some(
942            default_dns()
943                .expect("feature must be enabled")
944                .into_shared(),
945        );
946        let err = validate_full_uri("http://www.amazon.com/creds", dns.clone())
947            .await
948            .expect_err("not a valid IP");
949        assert!(
950            matches!(
951                err,
952                InvalidFullUriError {
953                    kind: InvalidFullUriErrorKind::DisallowedIP
954                }
955            ),
956            "{:?}",
957            err
958        );
959        assert!(logs_contain("Address does not resolve to an allowed IP"));
960        validate_full_uri("http://localhost:8888/creds", dns.clone())
961            .await
962            .expect("localhost is the loopback interface");
963        validate_full_uri("http://169.254.170.2.backname.io:8888/creds", dns.clone())
964            .await
965            .expect("169.254.170.2.backname.io is the ecs container address");
966        validate_full_uri("http://169.254.170.23.backname.io:8888/creds", dns.clone())
967            .await
968            .expect("169.254.170.23.backname.io is the eks pod identity address");
969        validate_full_uri("http://fd00-ec2--23.backname.io:8888/creds", dns)
970            .await
971            .expect("fd00-ec2--23.backname.io is the eks pod identity address");
972    }
973
974    /// Always returns the same IP addresses
975    #[derive(Clone, Debug)]
976    struct TestDns {
977        addrs: HashMap<String, Vec<IpAddr>>,
978        fallback: Vec<IpAddr>,
979    }
980
981    /// Default that returns a loopback for `localhost` and a non-loopback for all other hostnames
982    impl Default for TestDns {
983        fn default() -> Self {
984            let mut addrs = HashMap::new();
985            addrs.insert(
986                "localhost".into(),
987                vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
988            );
989            TestDns {
990                addrs,
991                // non-loopback address
992                fallback: vec!["72.21.210.29".parse().unwrap()],
993            }
994        }
995    }
996
997    impl TestDns {
998        fn with_fallback(fallback: Vec<IpAddr>) -> Self {
999            TestDns {
1000                addrs: Default::default(),
1001                fallback,
1002            }
1003        }
1004    }
1005
1006    impl ResolveDns for TestDns {
1007        fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
1008            DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
1009        }
1010    }
1011
1012    #[derive(Debug)]
1013    struct NeverDns;
1014    impl ResolveDns for NeverDns {
1015        fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
1016            DnsFuture::new(async {
1017                Never::new().await;
1018                unreachable!()
1019            })
1020        }
1021    }
1022}