aws_config/default_provider/
account_id_endpoint_mode.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::provider_config::ProviderConfig;
7use aws_runtime::env_config::EnvConfigValue;
8use aws_smithy_types::error::display::DisplayErrorContext;
9use aws_types::endpoint_config::AccountIdEndpointMode;
10use std::str::FromStr;
11
/// Names of the environment variables consulted by this provider.
mod env {
    /// Environment variable holding the account ID endpoint mode
    /// (`preferred`, `disabled` or `required`).
    pub(super) const ACCOUNT_ID_ENDPOINT_MODE: &str = "AWS_ACCOUNT_ID_ENDPOINT_MODE";
}
15
/// Names of the shared-config profile keys consulted by this provider.
mod profile_key {
    /// Profile key holding the account ID endpoint mode
    /// (`preferred`, `disabled` or `required`).
    pub(super) const ACCOUNT_ID_ENDPOINT_MODE: &str = "account_id_endpoint_mode";
}
19
20/// Load the value for the Account-based endpoint mode
21///
22/// This checks the following sources:
23/// 1. The environment variable `AWS_ACCOUNT_ID_ENDPOINT_MODE=preferred/disabled/required`
24/// 2. The profile key `account_id_endpoint_mode=preferred/disabled/required`
25///
26/// If invalid values are found, the provider will return `None` and an error will be logged.
27pub(crate) async fn account_id_endpoint_mode_provider(
28    provider_config: &ProviderConfig,
29) -> Option<AccountIdEndpointMode> {
30    let env = provider_config.env();
31    let profiles = provider_config.profile().await;
32
33    EnvConfigValue::new()
34        .env(env::ACCOUNT_ID_ENDPOINT_MODE)
35        .profile(profile_key::ACCOUNT_ID_ENDPOINT_MODE)
36        .validate(&env, profiles, AccountIdEndpointMode::from_str)
37        .map_err(|err| tracing::warn!(err = %DisplayErrorContext(&err), "invalid value for `AccountIdEndpointMode`"))
38        .unwrap_or(None)
39}
40
#[cfg(test)]
mod test {
    use super::account_id_endpoint_mode_provider;
    use super::env;
    #[allow(deprecated)]
    use crate::profile::profile_file::{ProfileFileKind, ProfileFiles};
    use crate::provider_config::ProviderConfig;
    use aws_types::os_shim_internal::{Env, Fs};
    use tracing_test::traced_test;

    /// Warning emitted by the provider when a configured value fails to parse.
    const INVALID_VALUE_LOG: &str = "invalid value for `AccountIdEndpointMode`";

    /// Build a `ProviderConfig` whose environment is exactly `env_vars` and whose only profile
    /// source is an in-memory config file named `conf` containing `config_contents`.
    ///
    /// The environment is always supplied explicitly via `Env::from_slice` (even when empty) so
    /// the outcome does not depend on variables set in the process running the tests.
    fn config_with_profile(env_vars: &[(&str, &str)], config_contents: &str) -> ProviderConfig {
        ProviderConfig::empty()
            .with_env(Env::from_slice(env_vars))
            .with_profile_config(
                Some(
                    #[allow(deprecated)]
                    ProfileFiles::builder()
                        .with_file(
                            #[allow(deprecated)]
                            ProfileFileKind::Config,
                            "conf",
                        )
                        .build(),
                ),
                None,
            )
            .with_fs(Fs::from_slice(&[("conf", config_contents)]))
    }

    #[tokio::test]
    #[traced_test]
    async fn log_error_on_invalid_value() {
        let conf = ProviderConfig::empty().with_env(Env::from_slice(&[(
            env::ACCOUNT_ID_ENDPOINT_MODE,
            "invalid",
        )]));
        assert_eq!(None, account_id_endpoint_mode_provider(&conf).await);
        assert!(logs_contain(INVALID_VALUE_LOG));
    }

    #[tokio::test]
    #[traced_test]
    async fn environment_priority() {
        let conf = config_with_profile(
            &[(env::ACCOUNT_ID_ENDPOINT_MODE, "disabled")],
            "[default]\naccount_id_endpoint_mode = required",
        );
        assert_eq!(
            "disabled".to_owned(),
            account_id_endpoint_mode_provider(&conf)
                .await
                .unwrap()
                .to_string(),
        );
    }

    /// With no environment variable set, the profile key must be used.
    #[tokio::test]
    async fn profile_used_when_env_unset() {
        let conf = config_with_profile(&[], "[default]\naccount_id_endpoint_mode = required");
        assert_eq!(
            "required".to_owned(),
            account_id_endpoint_mode_provider(&conf)
                .await
                .unwrap()
                .to_string(),
        );
    }

    /// An unparseable profile value is reported just like an unparseable environment value.
    #[tokio::test]
    #[traced_test]
    async fn log_error_on_invalid_profile_value() {
        let conf = config_with_profile(&[], "[default]\naccount_id_endpoint_mode = invalid");
        assert_eq!(None, account_id_endpoint_mode_provider(&conf).await);
        assert!(logs_contain(INVALID_VALUE_LOG));
    }

    /// When neither source sets the mode, the provider yields `None`.
    #[tokio::test]
    async fn unset_everywhere_returns_none() {
        let conf = config_with_profile(&[], "[default]\nregion = us-west-2");
        assert_eq!(None, account_id_endpoint_mode_provider(&conf).await);
    }
}