aws_config/imds/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Raw IMDSv2 Client
7//!
8//! Client for direct access to IMDSv2.
9
10use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::metrics::MetricsRuntimePlugin;
16use aws_smithy_runtime::client::orchestrator::operation::Operation;
17use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
18use aws_smithy_runtime_api::box_error::BoxError;
19use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
20use aws_smithy_runtime_api::client::endpoint::{
21    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
22};
23use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
24use aws_smithy_runtime_api::client::orchestrator::{
25    HttpRequest, Metadata, OrchestratorError, SensitiveOutput,
26};
27use aws_smithy_runtime_api::client::result::ConnectorError;
28use aws_smithy_runtime_api::client::result::SdkError;
29use aws_smithy_runtime_api::client::retries::classifiers::{
30    ClassifyRetry, RetryAction, SharedRetryClassifier,
31};
32use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
33use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
34use aws_smithy_types::body::SdkBody;
35use aws_smithy_types::config_bag::{FrozenLayer, Layer};
36use aws_smithy_types::endpoint::Endpoint;
37use aws_smithy_types::retry::RetryConfig;
38use aws_smithy_types::timeout::TimeoutConfig;
39use aws_types::os_shim_internal::Env;
40use http::Uri;
41use std::borrow::Cow;
42use std::error::Error as _;
43use std::fmt;
44use std::str::FromStr;
45use std::sync::Arc;
46use std::time::Duration;
47
48pub mod error;
49mod token;
50
/// Session token lifetime used when no TTL is configured: six hours.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Total number of attempts used when `max_attempts` is not configured.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Connect timeout used when none is configured.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Read timeout used when none is configured.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Whole-operation timeout (across all attempts) used when none is configured.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Per-attempt timeout used when none is configured.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
58
59fn user_agent() -> AwsUserAgent {
60    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
61}
62
63/// IMDSv2 Client
64///
65/// Client for IMDSv2. This client handles fetching tokens, retrying on failure, and token
66/// caching according to the specified token TTL.
67///
68/// _Note: This client ONLY supports IMDSv2. It will not fallback to IMDSv1. See
69/// [transitioning to IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-transition-to-version-2)
70/// for more information._
71///
72/// **Note**: When running in a Docker container, all network requests will incur an additional hop. When combined with the default IMDS hop limit of 1, this will cause requests to IMDS to timeout! To fix this issue, you'll need to set the following instance metadata settings :
73/// ```txt
74/// amazonec2-metadata-token=required
75/// amazonec2-metadata-token-response-hop-limit=2
76/// ```
77///
78/// On an instance that is already running, these can be set with [ModifyInstanceMetadataOptions](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_ModifyInstanceMetadataOptions.html). On a new instance, these can be set with the `MetadataOptions` field on [RunInstances](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_RunInstances.html).
79///
80/// For more information about IMDSv2 vs. IMDSv1 see [this guide](https://docs.aws.amazon.com/AWSEC2/latest/WindowsGuide/configuring-instance-metadata-service.html)
81///
82/// # Client Configuration
83/// The IMDS client can load configuration explicitly, via environment variables, or via
84/// `~/.aws/config`. It will first attempt to resolve an endpoint override. If no endpoint
85/// override exists, it will attempt to resolve an [`EndpointMode`]. If no
86/// [`EndpointMode`] override exists, it will fallback to [`IpV4`](EndpointMode::IpV4). An exhaustive
87/// list is below:
88///
89/// ## Endpoint configuration list
90/// 1. Explicit configuration of `Endpoint` via the [builder](Builder):
91/// ```no_run
92/// use aws_config::imds::client::Client;
93/// # async fn docs() {
94/// let client = Client::builder()
95///   .endpoint("http://customimds:456/").expect("valid URI")
96///   .build();
97/// # }
98/// ```
99///
100/// 2. The `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable. Note: If this environment variable
101///    is set, it MUST contain a valid URI or client construction will fail.
102///
103/// 3. The `ec2_metadata_service_endpoint` field in `~/.aws/config`:
104/// ```ini
105/// [default]
106/// # ... other configuration
107/// ec2_metadata_service_endpoint = http://my-custom-endpoint:444
108/// ```
109///
110/// 4. An explicitly set endpoint mode:
111/// ```no_run
112/// use aws_config::imds::client::{Client, EndpointMode};
113/// # async fn docs() {
114/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
115/// # }
116/// ```
117///
118/// 5. An [endpoint mode](EndpointMode) loaded from the `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` environment
119///    variable. Valid values: `IPv4`, `IPv6`
120///
121/// 6. An [endpoint mode](EndpointMode) loaded from the `ec2_metadata_service_endpoint_mode` field in
122///    `~/.aws/config`:
123/// ```ini
124/// [default]
125/// # ... other configuration
126/// ec2_metadata_service_endpoint_mode = IPv4
127/// ```
128///
129/// 7. The default value of `http://169.254.169.254` will be used.
130///
131#[derive(Clone, Debug)]
132pub struct Client {
133    operation: Operation<String, SensitiveString, InnerImdsError>,
134}
135
136impl Client {
137    /// IMDS client builder
138    pub fn builder() -> Builder {
139        Builder::default()
140    }
141
142    /// Retrieve information from IMDS
143    ///
144    /// This method will handle loading and caching a session token, combining the `path` with the
145    /// configured IMDS endpoint, and retrying potential errors.
146    ///
147    /// For more information about IMDSv2 methods and functionality, see
148    /// [Instance metadata and user data](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
149    ///
150    /// # Examples
151    ///
152    /// ```no_run
153    /// use aws_config::imds::client::Client;
154    /// # async fn docs() {
155    /// let client = Client::builder().build();
156    /// let ami_id = client
157    ///   .get("/latest/meta-data/ami-id")
158    ///   .await
159    ///   .expect("failure communicating with IMDS");
160    /// # }
161    /// ```
162    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
163        self.operation
164            .invoke(path.into())
165            .await
166            .map_err(|err| match err {
167                SdkError::ConstructionFailure(_) if err.source().is_some() => {
168                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
169                        Ok(Ok(token_failure)) => *token_failure,
170                        Ok(Err(err)) => ImdsError::unexpected(err),
171                        Err(err) => ImdsError::unexpected(err),
172                    }
173                }
174                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
175                SdkError::ServiceError(context) => match context.err() {
176                    InnerImdsError::InvalidUtf8 => {
177                        ImdsError::unexpected("IMDS returned invalid UTF-8")
178                    }
179                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
180                },
181                // If the error source is an ImdsError, then we need to directly return that source.
182                // That way, the IMDS token provider's errors can become the top-level ImdsError.
183                // There is a unit test that checks the correct error is being extracted.
184                err @ SdkError::DispatchFailure(_) => match err.into_source() {
185                    Ok(source) => match source.downcast::<ConnectorError>() {
186                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
187                            Ok(source) => *source,
188                            Err(err) => ImdsError::unexpected(err),
189                        },
190                        Err(err) => ImdsError::unexpected(err),
191                    },
192                    Err(err) => ImdsError::unexpected(err),
193                },
194                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
195                _ => ImdsError::unexpected(err),
196            })
197    }
198}
199
/// Wrapper around a `String` whose `Debug` output never reveals the wrapped value.
///
/// Read the value explicitly through `AsRef<str>` or by converting into a `String`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "** redacted **";
        let mut tuple = formatter.debug_tuple("SensitiveString");
        tuple.field(&REDACTED);
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
229
230/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
231/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
232/// sensitive, configures user agent headers, and sets up retries and timeouts.
233#[derive(Debug)]
234struct ImdsCommonRuntimePlugin {
235    config: FrozenLayer,
236    components: RuntimeComponentsBuilder,
237}
238
239impl ImdsCommonRuntimePlugin {
240    fn new(
241        config: &ProviderConfig,
242        endpoint_resolver: ImdsEndpointResolver,
243        retry_config: RetryConfig,
244        retry_classifier: SharedRetryClassifier,
245        timeout_config: TimeoutConfig,
246    ) -> Self {
247        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
248        layer.store_put(AuthSchemeOptionResolverParams::new(()));
249        layer.store_put(EndpointResolverParams::new(()));
250        layer.store_put(SensitiveOutput);
251        layer.store_put(retry_config);
252        layer.store_put(timeout_config);
253        layer.store_put(user_agent());
254
255        Self {
256            config: layer.freeze(),
257            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
258                .with_http_client(config.http_client())
259                .with_endpoint_resolver(Some(endpoint_resolver))
260                .with_interceptor(UserAgentInterceptor::new())
261                .with_retry_classifier(retry_classifier)
262                .with_retry_strategy(Some(StandardRetryStrategy::new()))
263                .with_time_source(Some(config.time_source()))
264                .with_sleep_impl(config.sleep_impl()),
265        }
266    }
267}
268
269impl RuntimePlugin for ImdsCommonRuntimePlugin {
270    fn config(&self) -> Option<FrozenLayer> {
271        Some(self.config.clone())
272    }
273
274    fn runtime_components(
275        &self,
276        _current_components: &RuntimeComponentsBuilder,
277    ) -> Cow<'_, RuntimeComponentsBuilder> {
278        Cow::Borrowed(&self.components)
279    }
280}
281
282/// IMDSv2 Endpoint Mode
283///
284/// IMDS can be accessed in two ways:
285/// 1. Via the IpV4 endpoint: `http://169.254.169.254`
286/// 2. Via the Ipv6 endpoint: `http://[fd00:ec2::254]`
287#[derive(Debug, Clone)]
288#[non_exhaustive]
289pub enum EndpointMode {
290    /// IpV4 mode: `http://169.254.169.254`
291    ///
292    /// This mode is the default unless otherwise specified.
293    IpV4,
294    /// IpV6 mode: `http://[fd00:ec2::254]`
295    IpV6,
296}
297
298impl FromStr for EndpointMode {
299    type Err = InvalidEndpointMode;
300
301    fn from_str(value: &str) -> Result<Self, Self::Err> {
302        match value {
303            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
304            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
305            other => Err(InvalidEndpointMode::new(other.to_owned())),
306        }
307    }
308}
309
310impl EndpointMode {
311    /// IMDS URI for this endpoint mode
312    fn endpoint(&self) -> Uri {
313        match self {
314            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
315            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
316        }
317    }
318}
319
320/// IMDSv2 Client Builder
321#[derive(Default, Debug, Clone)]
322pub struct Builder {
323    max_attempts: Option<u32>,
324    endpoint: Option<EndpointSource>,
325    mode_override: Option<EndpointMode>,
326    token_ttl: Option<Duration>,
327    connect_timeout: Option<Duration>,
328    read_timeout: Option<Duration>,
329    operation_timeout: Option<Duration>,
330    operation_attempt_timeout: Option<Duration>,
331    config: Option<ProviderConfig>,
332    retry_classifier: Option<SharedRetryClassifier>,
333}
334
335impl Builder {
336    /// Override the number of retries for fetching tokens & metadata
337    ///
338    /// By default, 4 attempts will be made.
339    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
340        self.max_attempts = Some(max_attempts);
341        self
342    }
343
344    /// Configure generic options of the [`Client`]
345    ///
346    /// # Examples
347    /// ```no_run
348    /// # async fn test() {
349    /// use aws_config::imds::Client;
350    /// use aws_config::provider_config::ProviderConfig;
351    ///
352    /// let provider = Client::builder()
353    ///     .configure(&ProviderConfig::with_default_region().await)
354    ///     .build();
355    /// # }
356    /// ```
357    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
358        self.config = Some(provider_config.clone());
359        self
360    }
361
362    /// Override the endpoint for the [`Client`]
363    ///
364    /// By default, the client will resolve an endpoint from the environment, AWS config, and endpoint mode.
365    ///
366    /// See [`Client`] for more information.
367    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
368        let uri: Uri = endpoint.as_ref().parse()?;
369        self.endpoint = Some(EndpointSource::Explicit(uri));
370        Ok(self)
371    }
372
373    /// Override the endpoint mode for [`Client`]
374    ///
375    /// * When set to [`IpV4`](EndpointMode::IpV4), the endpoint will be `http://169.254.169.254`.
376    /// * When set to [`IpV6`](EndpointMode::IpV6), the endpoint will be `http://[fd00:ec2::254]`.
377    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
378        self.mode_override = Some(mode);
379        self
380    }
381
382    /// Override the time-to-live for the session token
383    ///
384    /// Requests to IMDS utilize a session token for authentication. By default, session tokens last
385    /// for 6 hours. When the TTL for the token expires, a new token must be retrieved from the
386    /// metadata service.
387    pub fn token_ttl(mut self, ttl: Duration) -> Self {
388        self.token_ttl = Some(ttl);
389        self
390    }
391
392    /// Override the connect timeout for IMDS
393    ///
394    /// This value defaults to 1 second
395    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
396        self.connect_timeout = Some(timeout);
397        self
398    }
399
400    /// Override the read timeout for IMDS
401    ///
402    /// This value defaults to 1 second
403    pub fn read_timeout(mut self, timeout: Duration) -> Self {
404        self.read_timeout = Some(timeout);
405        self
406    }
407
408    /// Override the operation timeout for IMDS
409    ///
410    /// This value defaults to 1 second
411    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
412        self.operation_timeout = Some(timeout);
413        self
414    }
415
416    /// Override the operation attempt timeout for IMDS
417    ///
418    /// This value defaults to 1 second
419    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
420        self.operation_attempt_timeout = Some(timeout);
421        self
422    }
423
424    /// Override the retry classifier for IMDS
425    ///
426    /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this
427    /// module offers some configuration options and can be wrapped by[SharedRetryClassifier::new()] for use
428    /// here or you can create your own fully customized [SharedRetryClassifier].
429    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
430        self.retry_classifier = Some(retry_classifier);
431        self
432    }
433
434    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
435    /*
436    pub fn port(mut self, port: u32) -> Self {
437        self.port_override = Some(port);
438        self
439    }*/
440
441    /// Build an IMDSv2 Client
442    pub fn build(self) -> Client {
443        let config = self.config.unwrap_or_default();
444        let timeout_config = TimeoutConfig::builder()
445            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
446            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
447            .operation_attempt_timeout(
448                self.operation_attempt_timeout
449                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
450            )
451            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
452            .build();
453        let endpoint_source = self
454            .endpoint
455            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
456        let endpoint_resolver = ImdsEndpointResolver {
457            endpoint_source: Arc::new(endpoint_source),
458            mode_override: self.mode_override,
459        };
460        let retry_config = RetryConfig::standard()
461            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
462        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
463            ImdsResponseRetryClassifier::default(),
464        ));
465        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
466            &config,
467            endpoint_resolver,
468            retry_config,
469            retry_classifier,
470            timeout_config,
471        ));
472        let operation = Operation::builder()
473            .service_name("imds")
474            .operation_name("get")
475            .runtime_plugin(common_plugin.clone())
476            .runtime_plugin(TokenRuntimePlugin::new(
477                common_plugin,
478                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
479            ))
480            .runtime_plugin(
481                MetricsRuntimePlugin::builder()
482                    .with_scope("aws_config::imds_credentials")
483                    .with_time_source(config.time_source())
484                    .with_metadata(Metadata::new("get_credentials", "imds"))
485                    .build()
486                    .expect("All required fields have been set"),
487            )
488            .with_connection_poisoning()
489            .serializer(|path| {
490                Ok(HttpRequest::try_from(
491                    http::Request::builder()
492                        .uri(path)
493                        .body(SdkBody::empty())
494                        .expect("valid request"),
495                )
496                .unwrap())
497            })
498            .deserializer(|response| {
499                if response.status().is_success() {
500                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
501                        .map(|data| SensitiveString::from(data.to_string()))
502                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
503                } else {
504                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
505                }
506            })
507            .build();
508        Client { operation }
509    }
510}
511
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Endpoint mode override (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
}

/// `~/.aws/config` profile keys consulted when resolving the IMDS endpoint.
mod profile_keys {
    /// Endpoint mode override (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
}
521
522/// Endpoint Configuration Abstraction
523#[derive(Debug, Clone)]
524enum EndpointSource {
525    Explicit(Uri),
526    Env(ProviderConfig),
527}
528
529impl EndpointSource {
530    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
531        match self {
532            EndpointSource::Explicit(uri) => {
533                if mode_override.is_some() {
534                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
535                        "Endpoint mode override was set in combination with an explicit endpoint. \
536                        The mode override will be ignored.")
537                }
538                Ok(uri.clone())
539            }
540            EndpointSource::Env(conf) => {
541                let env = conf.env();
542                // load an endpoint override from the environment
543                let profile = conf.profile().await;
544                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
545                    Some(Cow::Owned(uri))
546                } else {
547                    profile
548                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
549                        .map(Cow::Borrowed)
550                };
551                if let Some(uri) = uri_override {
552                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
553                }
554
555                // if not, load a endpoint mode from the environment
556                let mode = if let Some(mode) = mode_override {
557                    mode
558                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
559                    mode.parse::<EndpointMode>()
560                        .map_err(BuildError::invalid_endpoint_mode)?
561                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
562                {
563                    mode.parse::<EndpointMode>()
564                        .map_err(BuildError::invalid_endpoint_mode)?
565                } else {
566                    EndpointMode::IpV4
567                };
568
569                Ok(mode.endpoint())
570            }
571        }
572    }
573}
574
575#[derive(Clone, Debug)]
576struct ImdsEndpointResolver {
577    endpoint_source: Arc<EndpointSource>,
578    mode_override: Option<EndpointMode>,
579}
580
581impl ResolveEndpoint for ImdsEndpointResolver {
582    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
583        EndpointFuture::new(async move {
584            self.endpoint_source
585                .endpoint(self.mode_override.clone())
586                .await
587                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
588                .map_err(|err| err.into())
589        })
590    }
591}
592
593/// IMDS Response Retry Classifier
594///
595/// Possible status codes:
596/// - 200 (OK)
597/// - 400 (Missing or invalid parameters) **Not Retryable**
598/// - 401 (Unauthorized, expired token) **Retryable**
599/// - 403 (IMDS disabled): **Not Retryable**
600/// - 404 (Not found): **Not Retryable**
601/// - >=500 (server error): **Retryable**
602/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()]
603#[derive(Clone, Debug, Default)]
604#[non_exhaustive]
605pub struct ImdsResponseRetryClassifier {
606    retry_connect_timeouts: bool,
607}
608
609impl ImdsResponseRetryClassifier {
610    /// Indicate whether the IMDS client should retry on connection timeouts
611    pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
612        self.retry_connect_timeouts = retry_connect_timeouts;
613        self
614    }
615}
616
617impl ClassifyRetry for ImdsResponseRetryClassifier {
618    fn name(&self) -> &'static str {
619        "ImdsResponseRetryClassifier"
620    }
621
622    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
623        if let Some(response) = ctx.response() {
624            let status = response.status();
625            match status {
626                _ if status.is_server_error() => RetryAction::server_error(),
627                // 401 indicates that the token has expired, this is retryable
628                _ if status.as_u16() == 401 => RetryAction::server_error(),
629                // This catch-all includes successful responses that fail to parse. These should not be retried.
630                _ => RetryAction::NoActionIndicated,
631            }
632        } else if self.retry_connect_timeouts {
633            RetryAction::server_error()
634        } else {
635            // This is the default behavior.
636            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
637            // credentials provider chain to fail to provide credentials.
638            // Also don't retry non-responses.
639            RetryAction::NoActionIndicated
640        }
641    }
642}
643
644#[cfg(test)]
645pub(crate) mod test {
646    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
647    use crate::provider_config::ProviderConfig;
648    use aws_smithy_async::rt::sleep::TokioSleep;
649    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
650    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
651    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
652    use aws_smithy_runtime_api::client::interceptors::context::{
653        Input, InterceptorContext, Output,
654    };
655    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
656    use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
657    use aws_smithy_runtime_api::client::result::ConnectorError;
658    use aws_smithy_runtime_api::client::retries::classifiers::{
659        ClassifyRetry, RetryAction, SharedRetryClassifier,
660    };
661    use aws_smithy_types::body::SdkBody;
662    use aws_smithy_types::error::display::DisplayErrorContext;
663    use aws_types::os_shim_internal::{Env, Fs};
664    use http::header::USER_AGENT;
665    use http::Uri;
666    use serde::Deserialize;
667    use std::collections::HashMap;
668    use std::error::Error;
669    use std::io;
670    use std::time::SystemTime;
671    use std::time::{Duration, UNIX_EPOCH};
672    use tracing_test::traced_test;
673
674    macro_rules! assert_full_error_contains {
675        ($err:expr, $contains:expr) => {
676            let err = $err;
677            let message = format!(
678                "{}",
679                aws_smithy_types::error::display::DisplayErrorContext(&err)
680            );
681            assert!(
682                message.contains($contains),
683                "Error message '{message}' didn't contain text '{}'",
684                $contains
685            );
686        };
687    }
688
689    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
690    const TOKEN_B: &str = "alternatetoken==";
691
692    /// Create a simple token request
693    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
694        http::Request::builder()
695            .uri(format!("{}/latest/api/token", base))
696            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
697            .method("PUT")
698            .body(SdkBody::empty())
699            .unwrap()
700            .try_into()
701            .unwrap()
702    }
703
704    /// Create a simple token response
705    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
706        HttpResponse::try_from(
707            http::Response::builder()
708                .status(200)
709                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
710                .body(SdkBody::from(token))
711                .unwrap(),
712        )
713        .unwrap()
714    }
715
716    /// Create a simple IMDS request
717    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
718        http::Request::builder()
719            .uri(Uri::from_static(path))
720            .method("GET")
721            .header("x-aws-ec2-metadata-token", token)
722            .body(SdkBody::empty())
723            .unwrap()
724            .try_into()
725            .unwrap()
726    }
727
728    /// Create a simple IMDS response
729    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
730        HttpResponse::try_from(
731            http::Response::builder()
732                .status(200)
733                .body(SdkBody::from(body))
734                .unwrap(),
735        )
736        .unwrap()
737    }
738
739    /// Create an IMDS client with an underlying [StaticReplayClient]
740    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
741        tokio::time::pause();
742        super::Client::builder()
743            .configure(
744                &ProviderConfig::no_configuration()
745                    .with_sleep_impl(InstantSleep::unlogged())
746                    .with_http_client(http_client.clone()),
747            )
748            .build()
749    }
750
751    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
752        let http_client = StaticReplayClient::new(events);
753        let client = make_imds_client(&http_client);
754        (client, http_client)
755    }
756
757    #[tokio::test]
758    async fn client_caches_token() {
759        let (client, http_client) = mock_imds_client(vec![
760            ReplayEvent::new(
761                token_request("http://169.254.169.254", 21600),
762                token_response(21600, TOKEN_A),
763            ),
764            ReplayEvent::new(
765                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
766                imds_response(r#"test-imds-output"#),
767            ),
768            ReplayEvent::new(
769                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
770                imds_response("output2"),
771            ),
772        ]);
773        // load once
774        let metadata = client.get("/latest/metadata").await.expect("failed");
775        assert_eq!("test-imds-output", metadata.as_ref());
776        // load again: the cached token should be used
777        let metadata = client.get("/latest/metadata2").await.expect("failed");
778        assert_eq!("output2", metadata.as_ref());
779        http_client.assert_requests_match(&[]);
780    }
781
782    #[tokio::test]
783    async fn token_can_expire() {
784        let (_, http_client) = mock_imds_client(vec![
785            ReplayEvent::new(
786                token_request("http://[fd00:ec2::254]", 600),
787                token_response(600, TOKEN_A),
788            ),
789            ReplayEvent::new(
790                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
791                imds_response(r#"test-imds-output1"#),
792            ),
793            ReplayEvent::new(
794                token_request("http://[fd00:ec2::254]", 600),
795                token_response(600, TOKEN_B),
796            ),
797            ReplayEvent::new(
798                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
799                imds_response(r#"test-imds-output2"#),
800            ),
801        ]);
802        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
803        let client = super::Client::builder()
804            .configure(
805                &ProviderConfig::no_configuration()
806                    .with_http_client(http_client.clone())
807                    .with_time_source(time_source.clone())
808                    .with_sleep_impl(sleep),
809            )
810            .endpoint_mode(EndpointMode::IpV6)
811            .token_ttl(Duration::from_secs(600))
812            .build();
813
814        let resp1 = client.get("/latest/metadata").await.expect("success");
815        // now the cached credential has expired
816        time_source.advance(Duration::from_secs(600));
817        let resp2 = client.get("/latest/metadata").await.expect("success");
818        http_client.assert_requests_match(&[]);
819        assert_eq!("test-imds-output1", resp1.as_ref());
820        assert_eq!("test-imds-output2", resp2.as_ref());
821    }
822
823    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
824    #[tokio::test]
825    async fn token_refresh_buffer() {
826        let _logs = capture_test_logs();
827        let (_, http_client) = mock_imds_client(vec![
828            ReplayEvent::new(
829                token_request("http://[fd00:ec2::254]", 600),
830                token_response(600, TOKEN_A),
831            ),
832            // t = 0
833            ReplayEvent::new(
834                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
835                imds_response(r#"test-imds-output1"#),
836            ),
837            // t = 400 (no refresh)
838            ReplayEvent::new(
839                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
840                imds_response(r#"test-imds-output2"#),
841            ),
842            // t = 550 (within buffer)
843            ReplayEvent::new(
844                token_request("http://[fd00:ec2::254]", 600),
845                token_response(600, TOKEN_B),
846            ),
847            ReplayEvent::new(
848                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
849                imds_response(r#"test-imds-output3"#),
850            ),
851        ]);
852        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
853        let client = super::Client::builder()
854            .configure(
855                &ProviderConfig::no_configuration()
856                    .with_sleep_impl(sleep)
857                    .with_http_client(http_client.clone())
858                    .with_time_source(time_source.clone()),
859            )
860            .endpoint_mode(EndpointMode::IpV6)
861            .token_ttl(Duration::from_secs(600))
862            .build();
863
864        tracing::info!("resp1 -----------------------------------------------------------");
865        let resp1 = client.get("/latest/metadata").await.expect("success");
866        // now the cached credential has expired
867        time_source.advance(Duration::from_secs(400));
868        tracing::info!("resp2 -----------------------------------------------------------");
869        let resp2 = client.get("/latest/metadata").await.expect("success");
870        time_source.advance(Duration::from_secs(150));
871        tracing::info!("resp3 -----------------------------------------------------------");
872        let resp3 = client.get("/latest/metadata").await.expect("success");
873        http_client.assert_requests_match(&[]);
874        assert_eq!("test-imds-output1", resp1.as_ref());
875        assert_eq!("test-imds-output2", resp2.as_ref());
876        assert_eq!("test-imds-output3", resp3.as_ref());
877    }
878
879    /// 500 error during the GET should be retried
880    #[tokio::test]
881    #[traced_test]
882    async fn retry_500() {
883        let (client, http_client) = mock_imds_client(vec![
884            ReplayEvent::new(
885                token_request("http://169.254.169.254", 21600),
886                token_response(21600, TOKEN_A),
887            ),
888            ReplayEvent::new(
889                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
890                http::Response::builder()
891                    .status(500)
892                    .body(SdkBody::empty())
893                    .unwrap(),
894            ),
895            ReplayEvent::new(
896                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
897                imds_response("ok"),
898            ),
899        ]);
900        assert_eq!(
901            "ok",
902            client
903                .get("/latest/metadata")
904                .await
905                .expect("success")
906                .as_ref()
907        );
908        http_client.assert_requests_match(&[]);
909
910        // all requests should have a user agent header
911        for request in http_client.actual_requests() {
912            assert!(request.headers().get(USER_AGENT).is_some());
913        }
914    }
915
916    /// 500 error during token acquisition should be retried
917    #[tokio::test]
918    #[traced_test]
919    async fn retry_token_failure() {
920        let (client, http_client) = mock_imds_client(vec![
921            ReplayEvent::new(
922                token_request("http://169.254.169.254", 21600),
923                http::Response::builder()
924                    .status(500)
925                    .body(SdkBody::empty())
926                    .unwrap(),
927            ),
928            ReplayEvent::new(
929                token_request("http://169.254.169.254", 21600),
930                token_response(21600, TOKEN_A),
931            ),
932            ReplayEvent::new(
933                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
934                imds_response("ok"),
935            ),
936        ]);
937        assert_eq!(
938            "ok",
939            client
940                .get("/latest/metadata")
941                .await
942                .expect("success")
943                .as_ref()
944        );
945        http_client.assert_requests_match(&[]);
946    }
947
948    /// 401 error during metadata retrieval must be retried
949    #[tokio::test]
950    #[traced_test]
951    async fn retry_metadata_401() {
952        let (client, http_client) = mock_imds_client(vec![
953            ReplayEvent::new(
954                token_request("http://169.254.169.254", 21600),
955                token_response(0, TOKEN_A),
956            ),
957            ReplayEvent::new(
958                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
959                http::Response::builder()
960                    .status(401)
961                    .body(SdkBody::empty())
962                    .unwrap(),
963            ),
964            ReplayEvent::new(
965                token_request("http://169.254.169.254", 21600),
966                token_response(21600, TOKEN_B),
967            ),
968            ReplayEvent::new(
969                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
970                imds_response("ok"),
971            ),
972        ]);
973        assert_eq!(
974            "ok",
975            client
976                .get("/latest/metadata")
977                .await
978                .expect("success")
979                .as_ref()
980        );
981        http_client.assert_requests_match(&[]);
982    }
983
984    /// 403 responses from IMDS during token acquisition MUST NOT be retried
985    #[tokio::test]
986    #[traced_test]
987    async fn no_403_retry() {
988        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
989            token_request("http://169.254.169.254", 21600),
990            http::Response::builder()
991                .status(403)
992                .body(SdkBody::empty())
993                .unwrap(),
994        )]);
995        let err = client.get("/latest/metadata").await.expect_err("no token");
996        assert_full_error_contains!(err, "forbidden");
997        http_client.assert_requests_match(&[]);
998    }
999
1000    /// The classifier should return `None` when classifying a successful response.
1001    #[test]
1002    fn successful_response_properly_classified() {
1003        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1004        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
1005        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
1006        let classifier = ImdsResponseRetryClassifier::default();
1007        assert_eq!(
1008            RetryAction::NoActionIndicated,
1009            classifier.classify_retry(&ctx)
1010        );
1011
1012        // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
1013        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1014        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1015            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1016        ))));
1017        assert_eq!(
1018            RetryAction::NoActionIndicated,
1019            classifier.classify_retry(&ctx)
1020        );
1021    }
1022
1023    /// User provided retry classifier works
1024    #[tokio::test]
1025    async fn user_provided_retry_classifier() {
1026        #[derive(Clone, Debug)]
1027        struct UserProvidedRetryClassifier;
1028
1029        impl ClassifyRetry for UserProvidedRetryClassifier {
1030            fn name(&self) -> &'static str {
1031                "UserProvidedRetryClassifier"
1032            }
1033
1034            // Don't retry anything
1035            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1036                RetryAction::RetryForbidden
1037            }
1038        }
1039
1040        let events = vec![
1041            ReplayEvent::new(
1042                token_request("http://169.254.169.254", 21600),
1043                token_response(0, TOKEN_A),
1044            ),
1045            ReplayEvent::new(
1046                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1047                http::Response::builder()
1048                    .status(401)
1049                    .body(SdkBody::empty())
1050                    .unwrap(),
1051            ),
1052            ReplayEvent::new(
1053                token_request("http://169.254.169.254", 21600),
1054                token_response(21600, TOKEN_B),
1055            ),
1056            ReplayEvent::new(
1057                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1058                imds_response("ok"),
1059            ),
1060        ];
1061        let http_client = StaticReplayClient::new(events);
1062
1063        let imds_client = super::Client::builder()
1064            .configure(
1065                &ProviderConfig::no_configuration()
1066                    .with_sleep_impl(InstantSleep::unlogged())
1067                    .with_http_client(http_client.clone()),
1068            )
1069            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1070            .build();
1071
1072        let res = imds_client
1073            .get("/latest/metadata")
1074            .await
1075            .expect_err("Client should error");
1076
1077        // Assert that the operation errored on the initial 401 and did not retry and get
1078        // the 200 (since the user provided retry classifier never retries)
1079        assert_full_error_contains!(res, "401");
1080    }
1081
1082    // since tokens are sent as headers, the tokens need to be valid header values
1083    #[tokio::test]
1084    async fn invalid_token() {
1085        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1086            token_request("http://169.254.169.254", 21600),
1087            token_response(21600, "invalid\nheader\nvalue\0"),
1088        )]);
1089        let err = client.get("/latest/metadata").await.expect_err("no token");
1090        assert_full_error_contains!(err, "invalid token");
1091        http_client.assert_requests_match(&[]);
1092    }
1093
1094    #[tokio::test]
1095    async fn non_utf8_response() {
1096        let (client, http_client) = mock_imds_client(vec![
1097            ReplayEvent::new(
1098                token_request("http://169.254.169.254", 21600),
1099                token_response(21600, TOKEN_A).map(SdkBody::from),
1100            ),
1101            ReplayEvent::new(
1102                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1103                http::Response::builder()
1104                    .status(200)
1105                    .body(SdkBody::from(vec![0xA0, 0xA1]))
1106                    .unwrap(),
1107            ),
1108        ]);
1109        let err = client.get("/latest/metadata").await.expect_err("no token");
1110        assert_full_error_contains!(err, "invalid UTF-8");
1111        http_client.assert_requests_match(&[]);
1112    }
1113
1114    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
1115    #[cfg_attr(windows, ignore)]
1116    /// Verify that the end-to-end real client has a 1-second connect timeout
1117    #[tokio::test]
1118    #[cfg(feature = "default-https-client")]
1119    async fn one_second_connect_timeout() {
1120        use crate::imds::client::ImdsError;
1121        let client = Client::builder()
1122            // 240.* can never be resolved
1123            .endpoint("http://240.0.0.0")
1124            .expect("valid uri")
1125            .build();
1126        let now = SystemTime::now();
1127        let resp = client
1128            .get("/latest/metadata")
1129            .await
1130            .expect_err("240.0.0.0 will never resolve");
1131        match resp {
1132            err @ ImdsError::FailedToLoadToken(_)
1133                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
1134            other => panic!(
1135                "wrong error, expected construction failure with TimedOutError inside: {}",
1136                DisplayErrorContext(&other)
1137            ),
1138        }
1139        let time_elapsed = now.elapsed().unwrap();
1140        assert!(
1141            time_elapsed > Duration::from_secs(1),
1142            "time_elapsed should be greater than 1s but was {:?}",
1143            time_elapsed
1144        );
1145        assert!(
1146            time_elapsed < Duration::from_secs(2),
1147            "time_elapsed should be less than 2s but was {:?}",
1148            time_elapsed
1149        );
1150    }
1151
1152    /// Retry classifier properly retries timeouts when configured to (meaning it takes ~30s to fail)
1153    #[tokio::test]
1154    async fn retry_connect_timeouts() {
1155        let http_client = StaticReplayClient::new(vec![]);
1156        let imds_client = super::Client::builder()
1157            .retry_classifier(SharedRetryClassifier::new(
1158                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1159            ))
1160            .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1161            .operation_timeout(Duration::from_secs(1))
1162            .endpoint("http://240.0.0.0")
1163            .expect("valid uri")
1164            .build();
1165
1166        let now = SystemTime::now();
1167        let _res = imds_client
1168            .get("/latest/metadata")
1169            .await
1170            .expect_err("240.0.0.0 will never resolve");
1171        let time_elapsed: Duration = now.elapsed().unwrap();
1172
1173        assert!(
1174            time_elapsed > Duration::from_secs(1),
1175            "time_elapsed should be greater than 1s but was {:?}",
1176            time_elapsed
1177        );
1178
1179        assert!(
1180            time_elapsed < Duration::from_secs(2),
1181            "time_elapsed should be less than 2s but was {:?}",
1182            time_elapsed
1183        );
1184    }
1185
1186    #[derive(Debug, Deserialize)]
1187    struct ImdsConfigTest {
1188        env: HashMap<String, String>,
1189        fs: HashMap<String, String>,
1190        endpoint_override: Option<String>,
1191        mode_override: Option<String>,
1192        result: Result<String, String>,
1193        docs: String,
1194    }
1195
1196    #[tokio::test]
1197    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1198        let _logs = capture_test_logs();
1199
1200        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1201        #[derive(Deserialize)]
1202        struct TestCases {
1203            tests: Vec<ImdsConfigTest>,
1204        }
1205
1206        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1207        let test_cases = test_cases.tests;
1208        for test in test_cases {
1209            check(test).await;
1210        }
1211        Ok(())
1212    }
1213
1214    async fn check(test_case: ImdsConfigTest) {
1215        let (http_client, watcher) = capture_request(None);
1216        let provider_config = ProviderConfig::no_configuration()
1217            .with_sleep_impl(TokioSleep::new())
1218            .with_env(Env::from(test_case.env))
1219            .with_fs(Fs::from_map(test_case.fs))
1220            .with_http_client(http_client);
1221        let mut imds_client = Client::builder().configure(&provider_config);
1222        if let Some(endpoint_override) = test_case.endpoint_override {
1223            imds_client = imds_client
1224                .endpoint(endpoint_override)
1225                .expect("invalid URI");
1226        }
1227
1228        if let Some(mode_override) = test_case.mode_override {
1229            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1230        }
1231
1232        let imds_client = imds_client.build();
1233        match &test_case.result {
1234            Ok(uri) => {
1235                // this request will fail, we just want to capture the endpoint configuration
1236                let _ = imds_client.get("/hello").await;
1237                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1238            }
1239            Err(expected) => {
1240                let err = imds_client.get("/hello").await.expect_err("it should fail");
1241                let message = format!("{}", DisplayErrorContext(&err));
1242                assert!(
1243                    message.contains(expected),
1244                    "{}\nexpected error: {expected}\nactual error: {message}",
1245                    test_case.docs
1246                );
1247            }
1248        };
1249    }
1250}