1use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::metrics::MetricsRuntimePlugin;
16use aws_smithy_runtime::client::orchestrator::operation::Operation;
17use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
18use aws_smithy_runtime_api::box_error::BoxError;
19use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
20use aws_smithy_runtime_api::client::endpoint::{
21    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
22};
23use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
24use aws_smithy_runtime_api::client::orchestrator::{
25    HttpRequest, Metadata, OrchestratorError, SensitiveOutput,
26};
27use aws_smithy_runtime_api::client::result::ConnectorError;
28use aws_smithy_runtime_api::client::result::SdkError;
29use aws_smithy_runtime_api::client::retries::classifiers::{
30    ClassifyRetry, RetryAction, SharedRetryClassifier,
31};
32use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
33use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
34use aws_smithy_types::body::SdkBody;
35use aws_smithy_types::config_bag::{FrozenLayer, Layer};
36use aws_smithy_types::endpoint::Endpoint;
37use aws_smithy_types::retry::RetryConfig;
38use aws_smithy_types::timeout::TimeoutConfig;
39use aws_types::os_shim_internal::Env;
40use http::Uri;
41use std::borrow::Cow;
42use std::error::Error as _;
43use std::fmt;
44use std::str::FromStr;
45use std::sync::Arc;
46use std::time::Duration;
47
48pub mod error;
49mod token;
50
/// Default lifetime requested for IMDS session tokens: six hours (21,600 seconds).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default maximum number of attempts, fed to `RetryConfig::with_max_attempts`.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default connect timeout for IMDS requests (one second).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default read timeout for IMDS requests (one second).
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default operation timeout, fed to `TimeoutConfig::operation_timeout`.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Default per-attempt timeout, fed to `TimeoutConfig::operation_attempt_timeout`.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
58
59fn user_agent() -> AwsUserAgent {
60    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
61}
62
/// Client for the EC2 Instance Metadata Service (IMDS).
///
/// Construct one with [`Client::builder`]. Each [`Client::get`] call invokes a single `get`
/// operation. The runtime plugins attached in `Builder::build` handle the session token
/// (`TokenRuntimePlugin`), endpoint resolution, timeouts, retries and metrics.
#[derive(Clone, Debug)]
pub struct Client {
    // Input is the request path; output is the response body; InnerImdsError is the
    // deserializer's error type.
    operation: Operation<String, SensitiveString, InnerImdsError>,
}

impl Client {
    /// Returns a [`Builder`] with every option unset; defaults are applied by `Builder::build`.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Fetches the IMDS resource at `path` and returns the response body.
    ///
    /// `path` is a request path such as `/latest/metadata` (the form used by this module's
    /// tests). The body is wrapped in [`SensitiveString`] so it is redacted from `Debug` output.
    ///
    /// # Errors
    ///
    /// Every failure is converted into an [`ImdsError`]:
    /// - an `ImdsError` found as the source of a construction or dispatch failure is returned
    ///   unchanged (presumably raised by the token plugin; verify in `token.rs`);
    /// - a non-success status becomes `ImdsError::error_response`;
    /// - a body that is not valid UTF-8 becomes `ImdsError::unexpected`;
    /// - timeouts and response errors become `ImdsError::io_error`;
    /// - anything else becomes `ImdsError::unexpected`.
    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
        self.operation
            .invoke(path.into())
            .await
            .map_err(|err| match err {
                // Construction failed with an underlying cause: recover a typed ImdsError if
                // that is what the cause is, otherwise wrap it as unexpected.
                SdkError::ConstructionFailure(_) if err.source().is_some() => {
                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
                        Ok(Ok(token_failure)) => *token_failure,
                        Ok(Err(err)) => ImdsError::unexpected(err),
                        Err(err) => ImdsError::unexpected(err),
                    }
                }
                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
                // Errors produced by the deserializer in `Builder::build`.
                SdkError::ServiceError(context) => match context.err() {
                    InnerImdsError::InvalidUtf8 => {
                        ImdsError::unexpected("IMDS returned invalid UTF-8")
                    }
                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
                },
                err @ SdkError::DispatchFailure(_) => match err.into_source() {
                    // A dispatch failure's source is expected to be a ConnectorError. If the
                    // connector error in turn wraps an ImdsError, surface that ImdsError
                    // directly; every other shape is reported as unexpected.
                    Ok(source) => match source.downcast::<ConnectorError>() {
                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
                            Ok(source) => *source,
                            Err(err) => ImdsError::unexpected(err),
                        },
                        Err(err) => ImdsError::unexpected(err),
                    },
                    Err(err) => ImdsError::unexpected(err),
                },
                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
                _ => ImdsError::unexpected(err),
            })
    }
}
199
/// A string returned by IMDS whose contents never appear in `Debug` output.
///
/// Read the value through [`AsRef<str>`] or by converting into a [`String`].
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only a fixed placeholder is printed so the secret cannot leak through logs.
        let placeholder = "** redacted **";
        f.debug_tuple("SensitiveString")
            .field(&placeholder)
            .finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
229
230#[derive(Debug)]
234struct ImdsCommonRuntimePlugin {
235    config: FrozenLayer,
236    components: RuntimeComponentsBuilder,
237}
238
239impl ImdsCommonRuntimePlugin {
240    fn new(
241        config: &ProviderConfig,
242        endpoint_resolver: ImdsEndpointResolver,
243        retry_config: RetryConfig,
244        retry_classifier: SharedRetryClassifier,
245        timeout_config: TimeoutConfig,
246    ) -> Self {
247        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
248        layer.store_put(AuthSchemeOptionResolverParams::new(()));
249        layer.store_put(EndpointResolverParams::new(()));
250        layer.store_put(SensitiveOutput);
251        layer.store_put(retry_config);
252        layer.store_put(timeout_config);
253        layer.store_put(user_agent());
254
255        Self {
256            config: layer.freeze(),
257            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
258                .with_http_client(config.http_client())
259                .with_endpoint_resolver(Some(endpoint_resolver))
260                .with_interceptor(UserAgentInterceptor::new())
261                .with_retry_classifier(retry_classifier)
262                .with_retry_strategy(Some(StandardRetryStrategy::new()))
263                .with_time_source(Some(config.time_source()))
264                .with_sleep_impl(config.sleep_impl()),
265        }
266    }
267}
268
269impl RuntimePlugin for ImdsCommonRuntimePlugin {
270    fn config(&self) -> Option<FrozenLayer> {
271        Some(self.config.clone())
272    }
273
274    fn runtime_components(
275        &self,
276        _current_components: &RuntimeComponentsBuilder,
277    ) -> Cow<'_, RuntimeComponentsBuilder> {
278        Cow::Borrowed(&self.components)
279    }
280}
281
282#[derive(Debug, Clone)]
288#[non_exhaustive]
289pub enum EndpointMode {
290    IpV4,
294    IpV6,
296}
297
298impl FromStr for EndpointMode {
299    type Err = InvalidEndpointMode;
300
301    fn from_str(value: &str) -> Result<Self, Self::Err> {
302        match value {
303            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
304            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
305            other => Err(InvalidEndpointMode::new(other.to_owned())),
306        }
307    }
308}
309
310impl EndpointMode {
311    fn endpoint(&self) -> Uri {
313        match self {
314            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
315            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
316        }
317    }
318}
319
320#[derive(Default, Debug, Clone)]
322pub struct Builder {
323    max_attempts: Option<u32>,
324    endpoint: Option<EndpointSource>,
325    mode_override: Option<EndpointMode>,
326    token_ttl: Option<Duration>,
327    connect_timeout: Option<Duration>,
328    read_timeout: Option<Duration>,
329    operation_timeout: Option<Duration>,
330    operation_attempt_timeout: Option<Duration>,
331    config: Option<ProviderConfig>,
332    retry_classifier: Option<SharedRetryClassifier>,
333}
334
335impl Builder {
336    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
340        self.max_attempts = Some(max_attempts);
341        self
342    }
343
344    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
358        self.config = Some(provider_config.clone());
359        self
360    }
361
362    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
368        let uri: Uri = endpoint.as_ref().parse()?;
369        self.endpoint = Some(EndpointSource::Explicit(uri));
370        Ok(self)
371    }
372
373    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
378        self.mode_override = Some(mode);
379        self
380    }
381
382    pub fn token_ttl(mut self, ttl: Duration) -> Self {
388        self.token_ttl = Some(ttl);
389        self
390    }
391
392    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
396        self.connect_timeout = Some(timeout);
397        self
398    }
399
400    pub fn read_timeout(mut self, timeout: Duration) -> Self {
404        self.read_timeout = Some(timeout);
405        self
406    }
407
408    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
412        self.operation_timeout = Some(timeout);
413        self
414    }
415
416    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
420        self.operation_attempt_timeout = Some(timeout);
421        self
422    }
423
424    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
430        self.retry_classifier = Some(retry_classifier);
431        self
432    }
433
434    pub fn build(self) -> Client {
443        let config = self.config.unwrap_or_default();
444        let timeout_config = TimeoutConfig::builder()
445            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
446            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
447            .operation_attempt_timeout(
448                self.operation_attempt_timeout
449                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
450            )
451            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
452            .build();
453        let endpoint_source = self
454            .endpoint
455            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
456        let endpoint_resolver = ImdsEndpointResolver {
457            endpoint_source: Arc::new(endpoint_source),
458            mode_override: self.mode_override,
459        };
460        let retry_config = RetryConfig::standard()
461            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
462        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
463            ImdsResponseRetryClassifier::default(),
464        ));
465        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
466            &config,
467            endpoint_resolver,
468            retry_config,
469            retry_classifier,
470            timeout_config,
471        ));
472        let operation = Operation::builder()
473            .service_name("imds")
474            .operation_name("get")
475            .runtime_plugin(common_plugin.clone())
476            .runtime_plugin(TokenRuntimePlugin::new(
477                common_plugin,
478                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
479            ))
480            .runtime_plugin(
481                MetricsRuntimePlugin::builder()
482                    .with_scope("aws_config::imds_credentials")
483                    .with_time_source(config.time_source())
484                    .with_metadata(Metadata::new("get_credentials", "imds"))
485                    .build()
486                    .expect("All required fields have been set"),
487            )
488            .with_connection_poisoning()
489            .serializer(|path| {
490                Ok(HttpRequest::try_from(
491                    http::Request::builder()
492                        .uri(path)
493                        .body(SdkBody::empty())
494                        .expect("valid request"),
495                )
496                .unwrap())
497            })
498            .deserializer(|response| {
499                if response.status().is_success() {
500                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
501                        .map(|data| SensitiveString::from(data.to_string()))
502                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
503                } else {
504                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
505                }
506            })
507            .build();
508        Client { operation }
509    }
510}
511
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Full endpoint URI override; when set, no endpoint mode is consulted.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// Endpoint mode (`IPv4` / `IPv6`, matched case-insensitively).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
516
/// Profile keys consulted when the corresponding environment variables are not set.
mod profile_keys {
    /// Full endpoint URI override; when set, no endpoint mode is consulted.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// Endpoint mode (`IPv4` / `IPv6`, matched case-insensitively).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
521
/// Where the IMDS endpoint URI comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
    /// A URI set directly on the builder via `Builder::endpoint`.
    Explicit(Uri),
    /// Resolve the URI from this provider config's environment and profile.
    Env(ProviderConfig),
}

impl EndpointSource {
    /// Resolves the IMDS endpoint URI.
    ///
    /// - `Explicit`: returns the stored URI. `mode_override` is ignored, and a warning is
    ///   logged if it was set.
    /// - `Env`: the first match wins:
    ///   1. the `AWS_EC2_METADATA_SERVICE_ENDPOINT` env var, then the
    ///      `ec2_metadata_service_endpoint` profile key;
    ///   2. otherwise an [`EndpointMode`] taken from `mode_override`, then the
    ///      `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` env var, then the
    ///      `ec2_metadata_service_endpoint_mode` profile key, defaulting to IPv4.
    ///
    /// # Errors
    ///
    /// Returns a `BuildError` if an endpoint override is not a valid URI or an endpoint mode
    /// fails to parse.
    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
        match self {
            EndpointSource::Explicit(uri) => {
                if mode_override.is_some() {
                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
                        "Endpoint mode override was set in combination with an explicit endpoint. \
                        The mode override will be ignored.")
                }
                Ok(uri.clone())
            }
            EndpointSource::Env(conf) => {
                let env = conf.env();
                let profile = conf.profile().await;
                // An endpoint URI from the env var wins over one from the profile.
                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
                    Some(Cow::Owned(uri))
                } else {
                    profile
                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
                        .map(Cow::Borrowed)
                };
                if let Some(uri) = uri_override {
                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
                }

                let mode = if let Some(mode) = mode_override {
                    // A mode set on the builder beats any mode from the env or profile.
                    mode
                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
                    mode.parse::<EndpointMode>()
                        .map_err(BuildError::invalid_endpoint_mode)?
                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
                {
                    mode.parse::<EndpointMode>()
                        .map_err(BuildError::invalid_endpoint_mode)?
                } else {
                    EndpointMode::IpV4
                };

                Ok(mode.endpoint())
            }
        }
    }
}
574
575#[derive(Clone, Debug)]
576struct ImdsEndpointResolver {
577    endpoint_source: Arc<EndpointSource>,
578    mode_override: Option<EndpointMode>,
579}
580
581impl ResolveEndpoint for ImdsEndpointResolver {
582    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
583        EndpointFuture::new(async move {
584            self.endpoint_source
585                .endpoint(self.mode_override.clone())
586                .await
587                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
588                .map_err(|err| err.into())
589        })
590    }
591}
592
/// Retry classifier the IMDS client uses unless one is supplied via `Builder::retry_classifier`.
///
/// HTTP 5xx and 401 responses are retried. Attempts that produced no HTTP response are retried
/// only when [`with_retry_connect_timeouts`](Self::with_retry_connect_timeouts) is enabled; it
/// is disabled by default.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct ImdsResponseRetryClassifier {
    retry_connect_timeouts: bool,
}

impl ImdsResponseRetryClassifier {
    /// Sets whether attempts that received no HTTP response at all should be retried.
    ///
    /// NOTE(review): despite the name, this covers every response-less attempt, not only
    /// connect timeouts.
    pub fn with_retry_connect_timeouts(self, retry_connect_timeouts: bool) -> Self {
        let mut updated = self;
        updated.retry_connect_timeouts = retry_connect_timeouts;
        updated
    }
}
616
617impl ClassifyRetry for ImdsResponseRetryClassifier {
618    fn name(&self) -> &'static str {
619        "ImdsResponseRetryClassifier"
620    }
621
622    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
623        if let Some(response) = ctx.response() {
624            let status = response.status();
625            match status {
626                _ if status.is_server_error() => RetryAction::server_error(),
627                _ if status.as_u16() == 401 => RetryAction::server_error(),
629                _ => RetryAction::NoActionIndicated,
631            }
632        } else if self.retry_connect_timeouts {
633            RetryAction::server_error()
634        } else {
635            RetryAction::NoActionIndicated
640        }
641    }
642}
643
644#[cfg(test)]
645pub(crate) mod test {
646    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
647    use crate::provider_config::ProviderConfig;
648    use aws_smithy_async::rt::sleep::TokioSleep;
649    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
650    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
651    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
652    use aws_smithy_runtime_api::client::interceptors::context::{
653        Input, InterceptorContext, Output,
654    };
655    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
656    use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
657    use aws_smithy_runtime_api::client::result::ConnectorError;
658    use aws_smithy_runtime_api::client::retries::classifiers::{
659        ClassifyRetry, RetryAction, SharedRetryClassifier,
660    };
661    use aws_smithy_types::body::SdkBody;
662    use aws_smithy_types::error::display::DisplayErrorContext;
663    use aws_types::os_shim_internal::{Env, Fs};
664    use http::header::USER_AGENT;
665    use http::Uri;
666    use serde::Deserialize;
667    use std::collections::HashMap;
668    use std::error::Error;
669    use std::io;
670    use std::time::SystemTime;
671    use std::time::{Duration, UNIX_EPOCH};
672    use tracing_test::traced_test;
673
674    macro_rules! assert_full_error_contains {
675        ($err:expr, $contains:expr) => {
676            let err = $err;
677            let message = format!(
678                "{}",
679                aws_smithy_types::error::display::DisplayErrorContext(&err)
680            );
681            assert!(
682                message.contains($contains),
683                "Error message '{message}' didn't contain text '{}'",
684                $contains
685            );
686        };
687    }
688
689    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
690    const TOKEN_B: &str = "alternatetoken==";
691
692    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
694        http::Request::builder()
695            .uri(format!("{}/latest/api/token", base))
696            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
697            .method("PUT")
698            .body(SdkBody::empty())
699            .unwrap()
700            .try_into()
701            .unwrap()
702    }
703
704    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
706        HttpResponse::try_from(
707            http::Response::builder()
708                .status(200)
709                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
710                .body(SdkBody::from(token))
711                .unwrap(),
712        )
713        .unwrap()
714    }
715
716    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
718        http::Request::builder()
719            .uri(Uri::from_static(path))
720            .method("GET")
721            .header("x-aws-ec2-metadata-token", token)
722            .body(SdkBody::empty())
723            .unwrap()
724            .try_into()
725            .unwrap()
726    }
727
728    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
730        HttpResponse::try_from(
731            http::Response::builder()
732                .status(200)
733                .body(SdkBody::from(body))
734                .unwrap(),
735        )
736        .unwrap()
737    }
738
739    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
741        tokio::time::pause();
742        super::Client::builder()
743            .configure(
744                &ProviderConfig::no_configuration()
745                    .with_sleep_impl(InstantSleep::unlogged())
746                    .with_http_client(http_client.clone()),
747            )
748            .build()
749    }
750
751    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
752        let http_client = StaticReplayClient::new(events);
753        let client = make_imds_client(&http_client);
754        (client, http_client)
755    }
756
757    #[tokio::test]
758    async fn client_caches_token() {
759        let (client, http_client) = mock_imds_client(vec![
760            ReplayEvent::new(
761                token_request("http://169.254.169.254", 21600),
762                token_response(21600, TOKEN_A),
763            ),
764            ReplayEvent::new(
765                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
766                imds_response(r#"test-imds-output"#),
767            ),
768            ReplayEvent::new(
769                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
770                imds_response("output2"),
771            ),
772        ]);
773        let metadata = client.get("/latest/metadata").await.expect("failed");
775        assert_eq!("test-imds-output", metadata.as_ref());
776        let metadata = client.get("/latest/metadata2").await.expect("failed");
778        assert_eq!("output2", metadata.as_ref());
779        http_client.assert_requests_match(&[]);
780    }
781
782    #[tokio::test]
783    async fn token_can_expire() {
784        let (_, http_client) = mock_imds_client(vec![
785            ReplayEvent::new(
786                token_request("http://[fd00:ec2::254]", 600),
787                token_response(600, TOKEN_A),
788            ),
789            ReplayEvent::new(
790                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
791                imds_response(r#"test-imds-output1"#),
792            ),
793            ReplayEvent::new(
794                token_request("http://[fd00:ec2::254]", 600),
795                token_response(600, TOKEN_B),
796            ),
797            ReplayEvent::new(
798                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
799                imds_response(r#"test-imds-output2"#),
800            ),
801        ]);
802        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
803        let client = super::Client::builder()
804            .configure(
805                &ProviderConfig::no_configuration()
806                    .with_http_client(http_client.clone())
807                    .with_time_source(time_source.clone())
808                    .with_sleep_impl(sleep),
809            )
810            .endpoint_mode(EndpointMode::IpV6)
811            .token_ttl(Duration::from_secs(600))
812            .build();
813
814        let resp1 = client.get("/latest/metadata").await.expect("success");
815        time_source.advance(Duration::from_secs(600));
817        let resp2 = client.get("/latest/metadata").await.expect("success");
818        http_client.assert_requests_match(&[]);
819        assert_eq!("test-imds-output1", resp1.as_ref());
820        assert_eq!("test-imds-output2", resp2.as_ref());
821    }
822
823    #[tokio::test]
825    async fn token_refresh_buffer() {
826        let _logs = capture_test_logs();
827        let (_, http_client) = mock_imds_client(vec![
828            ReplayEvent::new(
829                token_request("http://[fd00:ec2::254]", 600),
830                token_response(600, TOKEN_A),
831            ),
832            ReplayEvent::new(
834                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
835                imds_response(r#"test-imds-output1"#),
836            ),
837            ReplayEvent::new(
839                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
840                imds_response(r#"test-imds-output2"#),
841            ),
842            ReplayEvent::new(
844                token_request("http://[fd00:ec2::254]", 600),
845                token_response(600, TOKEN_B),
846            ),
847            ReplayEvent::new(
848                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
849                imds_response(r#"test-imds-output3"#),
850            ),
851        ]);
852        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
853        let client = super::Client::builder()
854            .configure(
855                &ProviderConfig::no_configuration()
856                    .with_sleep_impl(sleep)
857                    .with_http_client(http_client.clone())
858                    .with_time_source(time_source.clone()),
859            )
860            .endpoint_mode(EndpointMode::IpV6)
861            .token_ttl(Duration::from_secs(600))
862            .build();
863
864        tracing::info!("resp1 -----------------------------------------------------------");
865        let resp1 = client.get("/latest/metadata").await.expect("success");
866        time_source.advance(Duration::from_secs(400));
868        tracing::info!("resp2 -----------------------------------------------------------");
869        let resp2 = client.get("/latest/metadata").await.expect("success");
870        time_source.advance(Duration::from_secs(150));
871        tracing::info!("resp3 -----------------------------------------------------------");
872        let resp3 = client.get("/latest/metadata").await.expect("success");
873        http_client.assert_requests_match(&[]);
874        assert_eq!("test-imds-output1", resp1.as_ref());
875        assert_eq!("test-imds-output2", resp2.as_ref());
876        assert_eq!("test-imds-output3", resp3.as_ref());
877    }
878
879    #[tokio::test]
881    #[traced_test]
882    async fn retry_500() {
883        let (client, http_client) = mock_imds_client(vec![
884            ReplayEvent::new(
885                token_request("http://169.254.169.254", 21600),
886                token_response(21600, TOKEN_A),
887            ),
888            ReplayEvent::new(
889                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
890                http::Response::builder()
891                    .status(500)
892                    .body(SdkBody::empty())
893                    .unwrap(),
894            ),
895            ReplayEvent::new(
896                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
897                imds_response("ok"),
898            ),
899        ]);
900        assert_eq!(
901            "ok",
902            client
903                .get("/latest/metadata")
904                .await
905                .expect("success")
906                .as_ref()
907        );
908        http_client.assert_requests_match(&[]);
909
910        for request in http_client.actual_requests() {
912            assert!(request.headers().get(USER_AGENT).is_some());
913        }
914    }
915
916    #[tokio::test]
918    #[traced_test]
919    async fn retry_token_failure() {
920        let (client, http_client) = mock_imds_client(vec![
921            ReplayEvent::new(
922                token_request("http://169.254.169.254", 21600),
923                http::Response::builder()
924                    .status(500)
925                    .body(SdkBody::empty())
926                    .unwrap(),
927            ),
928            ReplayEvent::new(
929                token_request("http://169.254.169.254", 21600),
930                token_response(21600, TOKEN_A),
931            ),
932            ReplayEvent::new(
933                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
934                imds_response("ok"),
935            ),
936        ]);
937        assert_eq!(
938            "ok",
939            client
940                .get("/latest/metadata")
941                .await
942                .expect("success")
943                .as_ref()
944        );
945        http_client.assert_requests_match(&[]);
946    }
947
948    #[tokio::test]
950    #[traced_test]
951    async fn retry_metadata_401() {
952        let (client, http_client) = mock_imds_client(vec![
953            ReplayEvent::new(
954                token_request("http://169.254.169.254", 21600),
955                token_response(0, TOKEN_A),
956            ),
957            ReplayEvent::new(
958                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
959                http::Response::builder()
960                    .status(401)
961                    .body(SdkBody::empty())
962                    .unwrap(),
963            ),
964            ReplayEvent::new(
965                token_request("http://169.254.169.254", 21600),
966                token_response(21600, TOKEN_B),
967            ),
968            ReplayEvent::new(
969                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
970                imds_response("ok"),
971            ),
972        ]);
973        assert_eq!(
974            "ok",
975            client
976                .get("/latest/metadata")
977                .await
978                .expect("success")
979                .as_ref()
980        );
981        http_client.assert_requests_match(&[]);
982    }
983
984    #[tokio::test]
986    #[traced_test]
987    async fn no_403_retry() {
988        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
989            token_request("http://169.254.169.254", 21600),
990            http::Response::builder()
991                .status(403)
992                .body(SdkBody::empty())
993                .unwrap(),
994        )]);
995        let err = client.get("/latest/metadata").await.expect_err("no token");
996        assert_full_error_contains!(err, "forbidden");
997        http_client.assert_requests_match(&[]);
998    }
999
1000    #[test]
1002    fn successful_response_properly_classified() {
1003        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1004        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
1005        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
1006        let classifier = ImdsResponseRetryClassifier::default();
1007        assert_eq!(
1008            RetryAction::NoActionIndicated,
1009            classifier.classify_retry(&ctx)
1010        );
1011
1012        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1014        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1015            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1016        ))));
1017        assert_eq!(
1018            RetryAction::NoActionIndicated,
1019            classifier.classify_retry(&ctx)
1020        );
1021    }
1022
1023    #[tokio::test]
1025    async fn user_provided_retry_classifier() {
1026        #[derive(Clone, Debug)]
1027        struct UserProvidedRetryClassifier;
1028
1029        impl ClassifyRetry for UserProvidedRetryClassifier {
1030            fn name(&self) -> &'static str {
1031                "UserProvidedRetryClassifier"
1032            }
1033
1034            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1036                RetryAction::RetryForbidden
1037            }
1038        }
1039
1040        let events = vec![
1041            ReplayEvent::new(
1042                token_request("http://169.254.169.254", 21600),
1043                token_response(0, TOKEN_A),
1044            ),
1045            ReplayEvent::new(
1046                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1047                http::Response::builder()
1048                    .status(401)
1049                    .body(SdkBody::empty())
1050                    .unwrap(),
1051            ),
1052            ReplayEvent::new(
1053                token_request("http://169.254.169.254", 21600),
1054                token_response(21600, TOKEN_B),
1055            ),
1056            ReplayEvent::new(
1057                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1058                imds_response("ok"),
1059            ),
1060        ];
1061        let http_client = StaticReplayClient::new(events);
1062
1063        let imds_client = super::Client::builder()
1064            .configure(
1065                &ProviderConfig::no_configuration()
1066                    .with_sleep_impl(InstantSleep::unlogged())
1067                    .with_http_client(http_client.clone()),
1068            )
1069            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1070            .build();
1071
1072        let res = imds_client
1073            .get("/latest/metadata")
1074            .await
1075            .expect_err("Client should error");
1076
1077        assert_full_error_contains!(res, "401");
1080    }
1081
1082    #[tokio::test]
1084    async fn invalid_token() {
1085        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1086            token_request("http://169.254.169.254", 21600),
1087            token_response(21600, "invalid\nheader\nvalue\0"),
1088        )]);
1089        let err = client.get("/latest/metadata").await.expect_err("no token");
1090        assert_full_error_contains!(err, "invalid token");
1091        http_client.assert_requests_match(&[]);
1092    }
1093
1094    #[tokio::test]
1095    async fn non_utf8_response() {
1096        let (client, http_client) = mock_imds_client(vec![
1097            ReplayEvent::new(
1098                token_request("http://169.254.169.254", 21600),
1099                token_response(21600, TOKEN_A).map(SdkBody::from),
1100            ),
1101            ReplayEvent::new(
1102                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1103                http::Response::builder()
1104                    .status(200)
1105                    .body(SdkBody::from(vec![0xA0, 0xA1]))
1106                    .unwrap(),
1107            ),
1108        ]);
1109        let err = client.get("/latest/metadata").await.expect_err("no token");
1110        assert_full_error_contains!(err, "invalid UTF-8");
1111        http_client.assert_requests_match(&[]);
1112    }
1113
1114    #[cfg_attr(windows, ignore)]
1116    #[tokio::test]
1118    #[cfg(feature = "default-https-client")]
1119    async fn one_second_connect_timeout() {
1120        use crate::imds::client::ImdsError;
1121        let client = Client::builder()
1122            .endpoint("http://240.0.0.0")
1124            .expect("valid uri")
1125            .build();
1126        let now = SystemTime::now();
1127        let resp = client
1128            .get("/latest/metadata")
1129            .await
1130            .expect_err("240.0.0.0 will never resolve");
1131        match resp {
1132            err @ ImdsError::FailedToLoadToken(_)
1133                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} other => panic!(
1135                "wrong error, expected construction failure with TimedOutError inside: {}",
1136                DisplayErrorContext(&other)
1137            ),
1138        }
1139        let time_elapsed = now.elapsed().unwrap();
1140        assert!(
1141            time_elapsed > Duration::from_secs(1),
1142            "time_elapsed should be greater than 1s but was {:?}",
1143            time_elapsed
1144        );
1145        assert!(
1146            time_elapsed < Duration::from_secs(2),
1147            "time_elapsed should be less than 2s but was {:?}",
1148            time_elapsed
1149        );
1150    }
1151
1152    #[tokio::test]
1154    async fn retry_connect_timeouts() {
1155        let http_client = StaticReplayClient::new(vec![]);
1156        let imds_client = super::Client::builder()
1157            .retry_classifier(SharedRetryClassifier::new(
1158                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1159            ))
1160            .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1161            .operation_timeout(Duration::from_secs(1))
1162            .endpoint("http://240.0.0.0")
1163            .expect("valid uri")
1164            .build();
1165
1166        let now = SystemTime::now();
1167        let _res = imds_client
1168            .get("/latest/metadata")
1169            .await
1170            .expect_err("240.0.0.0 will never resolve");
1171        let time_elapsed: Duration = now.elapsed().unwrap();
1172
1173        assert!(
1174            time_elapsed > Duration::from_secs(1),
1175            "time_elapsed should be greater than 1s but was {:?}",
1176            time_elapsed
1177        );
1178
1179        assert!(
1180            time_elapsed < Duration::from_secs(2),
1181            "time_elapsed should be less than 2s but was {:?}",
1182            time_elapsed
1183        );
1184    }
1185
    /// One case from `test-data/imds-config/imds-endpoint-tests.json`: the configuration
    /// an IMDS client is built with and the outcome expected from it.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables given to the client's provider config (via `Env::from`).
        env: HashMap<String, String>,
        /// Fake filesystem contents passed to `Fs::from_map`
        /// (presumably path -> file contents; confirm against `Fs`).
        fs: HashMap<String, String>,
        /// Explicit endpoint passed to the client builder's `endpoint`, if any.
        endpoint_override: Option<String>,
        /// Explicit endpoint mode, parsed and passed to `endpoint_mode`, if any.
        mode_override: Option<String>,
        /// `Ok(uri)`: the URI the captured request must have.
        /// `Err(msg)`: a substring expected in the error returned by `get`.
        result: Result<String, String>,
        /// Human-readable description of the case, shown in assertion failure messages.
        docs: String,
    }
1195
1196    #[tokio::test]
1197    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1198        let _logs = capture_test_logs();
1199
1200        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1201        #[derive(Deserialize)]
1202        struct TestCases {
1203            tests: Vec<ImdsConfigTest>,
1204        }
1205
1206        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1207        let test_cases = test_cases.tests;
1208        for test in test_cases {
1209            check(test).await;
1210        }
1211        Ok(())
1212    }
1213
1214    async fn check(test_case: ImdsConfigTest) {
1215        let (http_client, watcher) = capture_request(None);
1216        let provider_config = ProviderConfig::no_configuration()
1217            .with_sleep_impl(TokioSleep::new())
1218            .with_env(Env::from(test_case.env))
1219            .with_fs(Fs::from_map(test_case.fs))
1220            .with_http_client(http_client);
1221        let mut imds_client = Client::builder().configure(&provider_config);
1222        if let Some(endpoint_override) = test_case.endpoint_override {
1223            imds_client = imds_client
1224                .endpoint(endpoint_override)
1225                .expect("invalid URI");
1226        }
1227
1228        if let Some(mode_override) = test_case.mode_override {
1229            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1230        }
1231
1232        let imds_client = imds_client.build();
1233        match &test_case.result {
1234            Ok(uri) => {
1235                let _ = imds_client.get("/hello").await;
1237                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1238            }
1239            Err(expected) => {
1240                let err = imds_client.get("/hello").await.expect_err("it should fail");
1241                let message = format!("{}", DisplayErrorContext(&err));
1242                assert!(
1243                    message.contains(expected),
1244                    "{}\nexpected error: {expected}\nactual error: {message}",
1245                    test_case.docs
1246                );
1247            }
1248        };
1249    }
1250}