1use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::metrics::MetricsRuntimePlugin;
16use aws_smithy_runtime::client::orchestrator::operation::Operation;
17use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
18use aws_smithy_runtime_api::box_error::BoxError;
19use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
20use aws_smithy_runtime_api::client::endpoint::{
21 EndpointFuture, EndpointResolverParams, ResolveEndpoint,
22};
23use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
24use aws_smithy_runtime_api::client::orchestrator::{
25 HttpRequest, Metadata, OrchestratorError, SensitiveOutput,
26};
27use aws_smithy_runtime_api::client::result::ConnectorError;
28use aws_smithy_runtime_api::client::result::SdkError;
29use aws_smithy_runtime_api::client::retries::classifiers::{
30 ClassifyRetry, RetryAction, SharedRetryClassifier,
31};
32use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
33use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
34use aws_smithy_types::body::SdkBody;
35use aws_smithy_types::config_bag::{FrozenLayer, Layer};
36use aws_smithy_types::endpoint::Endpoint;
37use aws_smithy_types::retry::RetryConfig;
38use aws_smithy_types::timeout::TimeoutConfig;
39use aws_types::os_shim_internal::Env;
40use http::Uri;
41use std::borrow::Cow;
42use std::error::Error as _;
43use std::fmt;
44use std::str::FromStr;
45use std::sync::Arc;
46use std::time::Duration;
47
48pub mod error;
49mod token;
50
/// Default TTL requested for IMDS session tokens: six hours (21,600 seconds).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default number of attempts (initial try plus retries) per IMDS request.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default timeout for establishing a connection to IMDS.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default timeout for reading from an established IMDS connection.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default upper bound on a whole operation, across all of its attempts.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Default upper bound on a single attempt of an operation.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
58
59fn user_agent() -> AwsUserAgent {
60 AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
61}
62
/// Client for the EC2 Instance Metadata Service (IMDS).
///
/// Construct one with [`Client::builder`]. Internally it holds a single configured
/// orchestrator [`Operation`] that takes a request path (`String`) and yields the response
/// body as a [`SensitiveString`].
#[derive(Clone, Debug)]
pub struct Client {
    // Input: request path; output: redacted response body; error: IMDS-specific failures.
    operation: Operation<String, SensitiveString, InnerImdsError>,
}
135
impl Client {
    /// Create a new [`Builder`] with every setting unset (defaults are applied at build time).
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Fetch the contents stored at `path` on IMDS, e.g. `"/latest/metadata"`.
    ///
    /// The path is appended to the resolved IMDS endpoint. In the tests,
    /// `"/latest/metadata"` is sent to `http://169.254.169.254/latest/metadata`. A session
    /// token header is attached by the token runtime plugin (see the `token` module).
    ///
    /// The body is returned as a [`SensitiveString`] so it is redacted from `Debug` output.
    ///
    /// # Errors
    ///
    /// - An [`ImdsError`] surfaced from construction or dispatch (for example a failure to
    ///   load a session token) is returned unchanged.
    /// - A non-success HTTP status yields [`ImdsError::error_response`].
    /// - A body that is not valid UTF-8 yields an unexpected error.
    /// - Timeouts and response errors yield [`ImdsError::io_error`].
    /// - Anything else yields an unexpected error.
    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
        self.operation
            .invoke(path.into())
            .await
            .map_err(|err| match err {
                // A construction failure may wrap an `ImdsError`. The binding name suggests
                // this is how token-loading failures surface (see the `token` module).
                // Unwrap it so callers see the original error.
                SdkError::ConstructionFailure(_) if err.source().is_some() => {
                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
                        Ok(Ok(token_failure)) => *token_failure,
                        Ok(Err(err)) => ImdsError::unexpected(err),
                        Err(err) => ImdsError::unexpected(err),
                    }
                }
                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
                SdkError::ServiceError(context) => match context.err() {
                    InnerImdsError::InvalidUtf8 => {
                        ImdsError::unexpected("IMDS returned invalid UTF-8")
                    }
                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
                },
                // During dispatch, an `ImdsError` may arrive wrapped in a `ConnectorError`.
                // Peel both layers off when present; otherwise report the error as unexpected.
                err @ SdkError::DispatchFailure(_) => match err.into_source() {
                    Ok(source) => match source.downcast::<ConnectorError>() {
                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
                            Ok(source) => *source,
                            Err(err) => ImdsError::unexpected(err),
                        },
                        Err(err) => ImdsError::unexpected(err),
                    },
                    Err(err) => ImdsError::unexpected(err),
                },
                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
                _ => ImdsError::unexpected(err),
            })
    }
}
199
/// An owned string whose `Debug` output is redacted.
///
/// IMDS responses are returned in this wrapper so their contents (which may be sensitive)
/// do not leak into logs through `Debug` formatting. Access the value with `as_ref()` or
/// convert it into a `String`.
#[derive(Clone)]
pub struct SensitiveString(String);
203
204impl fmt::Debug for SensitiveString {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 f.debug_tuple("SensitiveString")
207 .field(&"** redacted **")
208 .finish()
209 }
210}
211
212impl AsRef<str> for SensitiveString {
213 fn as_ref(&self) -> &str {
214 &self.0
215 }
216}
217
218impl From<String> for SensitiveString {
219 fn from(value: String) -> Self {
220 Self(value)
221 }
222}
223
224impl From<SensitiveString> for String {
225 fn from(value: SensitiveString) -> Self {
226 value.0
227 }
228}
229
/// Runtime plugin holding the configuration layer and runtime components shared by IMDS
/// requests: endpoint resolver, HTTP client, retry strategy/classifier, timeouts, time
/// source, sleep implementation and user agent.
#[derive(Debug)]
struct ImdsCommonRuntimePlugin {
    // Frozen config bag layer built once in `new`.
    config: FrozenLayer,
    // Prebuilt components handed to the orchestrator unchanged.
    components: RuntimeComponentsBuilder,
}
238
239impl ImdsCommonRuntimePlugin {
240 fn new(
241 config: &ProviderConfig,
242 endpoint_resolver: ImdsEndpointResolver,
243 retry_config: RetryConfig,
244 retry_classifier: SharedRetryClassifier,
245 timeout_config: TimeoutConfig,
246 ) -> Self {
247 let mut layer = Layer::new("ImdsCommonRuntimePlugin");
248 layer.store_put(AuthSchemeOptionResolverParams::new(()));
249 layer.store_put(EndpointResolverParams::new(()));
250 layer.store_put(SensitiveOutput);
251 layer.store_put(retry_config);
252 layer.store_put(timeout_config);
253 layer.store_put(user_agent());
254
255 Self {
256 config: layer.freeze(),
257 components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
258 .with_http_client(config.http_client())
259 .with_endpoint_resolver(Some(endpoint_resolver))
260 .with_interceptor(UserAgentInterceptor::new())
261 .with_retry_classifier(retry_classifier)
262 .with_retry_strategy(Some(StandardRetryStrategy::new()))
263 .with_time_source(Some(config.time_source()))
264 .with_sleep_impl(config.sleep_impl()),
265 }
266 }
267}
268
269impl RuntimePlugin for ImdsCommonRuntimePlugin {
270 fn config(&self) -> Option<FrozenLayer> {
271 Some(self.config.clone())
272 }
273
274 fn runtime_components(
275 &self,
276 _current_components: &RuntimeComponentsBuilder,
277 ) -> Cow<'_, RuntimeComponentsBuilder> {
278 Cow::Borrowed(&self.components)
279 }
280}
281
/// Address family used to reach IMDS when no endpoint URI is configured explicitly,
/// via the environment, or via the profile.
///
/// Parsed case-insensitively from `"ipv4"` / `"ipv6"` (see the `FromStr` impl).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EndpointMode {
    /// IPv4 mode, the fallback when nothing else selects a mode. Uses `http://169.254.169.254`.
    IpV4,
    /// IPv6 mode. Uses `http://[fd00:ec2::254]`.
    IpV6,
}
297
298impl FromStr for EndpointMode {
299 type Err = InvalidEndpointMode;
300
301 fn from_str(value: &str) -> Result<Self, Self::Err> {
302 match value {
303 _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
304 _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
305 other => Err(InvalidEndpointMode::new(other.to_owned())),
306 }
307 }
308}
309
310impl EndpointMode {
311 fn endpoint(&self) -> Uri {
313 match self {
314 EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
315 EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
316 }
317 }
318}
319
/// Builder for an IMDS [`Client`].
///
/// Every setting is optional; values left unset fall back to the defaults applied in
/// [`Builder::build`].
#[derive(Default, Debug, Clone)]
pub struct Builder {
    // Attempts per request (initial try included).
    max_attempts: Option<u32>,
    // Explicit endpoint; when `None`, the endpoint is resolved from env/profile.
    endpoint: Option<EndpointSource>,
    // IPv4/IPv6 selection used when no endpoint URI is configured anywhere.
    mode_override: Option<EndpointMode>,
    // TTL requested for session tokens.
    token_ttl: Option<Duration>,
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
    operation_timeout: Option<Duration>,
    operation_attempt_timeout: Option<Duration>,
    // Source of HTTP client, time source, sleep impl, env, profile and behavior version.
    config: Option<ProviderConfig>,
    // Replaces the default `ImdsResponseRetryClassifier` when set.
    retry_classifier: Option<SharedRetryClassifier>,
}
334
335impl Builder {
336 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
340 self.max_attempts = Some(max_attempts);
341 self
342 }
343
344 pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
358 self.config = Some(provider_config.clone());
359 self
360 }
361
362 pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
368 let uri: Uri = endpoint.as_ref().parse()?;
369 self.endpoint = Some(EndpointSource::Explicit(uri));
370 Ok(self)
371 }
372
373 pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
378 self.mode_override = Some(mode);
379 self
380 }
381
382 pub fn token_ttl(mut self, ttl: Duration) -> Self {
388 self.token_ttl = Some(ttl);
389 self
390 }
391
392 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
396 self.connect_timeout = Some(timeout);
397 self
398 }
399
400 pub fn read_timeout(mut self, timeout: Duration) -> Self {
404 self.read_timeout = Some(timeout);
405 self
406 }
407
408 pub fn operation_timeout(mut self, timeout: Duration) -> Self {
412 self.operation_timeout = Some(timeout);
413 self
414 }
415
416 pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
420 self.operation_attempt_timeout = Some(timeout);
421 self
422 }
423
424 pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
430 self.retry_classifier = Some(retry_classifier);
431 self
432 }
433
434 pub fn build(self) -> Client {
443 let config = self.config.unwrap_or_default();
444 let timeout_config = TimeoutConfig::builder()
445 .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
446 .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
447 .operation_attempt_timeout(
448 self.operation_attempt_timeout
449 .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
450 )
451 .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
452 .build();
453 let endpoint_source = self
454 .endpoint
455 .unwrap_or_else(|| EndpointSource::Env(config.clone()));
456 let endpoint_resolver = ImdsEndpointResolver {
457 endpoint_source: Arc::new(endpoint_source),
458 mode_override: self.mode_override,
459 };
460 let retry_config = RetryConfig::standard()
461 .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
462 let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
463 ImdsResponseRetryClassifier::default(),
464 ));
465 let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
466 &config,
467 endpoint_resolver,
468 retry_config,
469 retry_classifier,
470 timeout_config,
471 ));
472 let operation = Operation::builder()
473 .service_name("imds")
474 .operation_name("get")
475 .runtime_plugin(common_plugin.clone())
476 .runtime_plugin(TokenRuntimePlugin::new(
477 common_plugin,
478 self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
479 ))
480 .runtime_plugin(
481 MetricsRuntimePlugin::builder()
482 .with_scope("aws_config::imds_credentials")
483 .with_time_source(config.time_source())
484 .with_metadata(Metadata::new("get_credentials", "imds"))
485 .build()
486 .expect("All required fields have been set"),
487 )
488 .with_connection_poisoning()
489 .serializer(|path| {
490 Ok(HttpRequest::try_from(
491 http::Request::builder()
492 .uri(path)
493 .body(SdkBody::empty())
494 .expect("valid request"),
495 )
496 .unwrap())
497 })
498 .deserializer(|response| {
499 if response.status().is_success() {
500 std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
501 .map(|data| SensitiveString::from(data.to_string()))
502 .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
503 } else {
504 Err(OrchestratorError::operation(InnerImdsError::BadStatus))
505 }
506 });
507 let operation = if let Some(bv) = config.behavior_version() {
508 operation.behavior_version(bv).build()
509 } else {
510 operation.build()
511 };
512 Client { operation }
513 }
514}
515
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Explicit IMDS endpoint URI; takes precedence over the profile and endpoint mode.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// Endpoint mode (`ipv4` / `ipv6`, case-insensitive).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
520
/// Profile keys consulted when the corresponding environment variables are unset.
mod profile_keys {
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// Endpoint mode (`ipv4` / `ipv6`, case-insensitive).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
525
/// Where the IMDS endpoint comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
    /// A URI set explicitly through the builder.
    Explicit(Uri),
    /// Resolve from the environment and profile of this configuration, each time an
    /// endpoint is requested.
    Env(ProviderConfig),
}
532
533impl EndpointSource {
534 async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
535 match self {
536 EndpointSource::Explicit(uri) => {
537 if mode_override.is_some() {
538 tracing::warn!(endpoint = ?uri, mode = ?mode_override,
539 "Endpoint mode override was set in combination with an explicit endpoint. \
540 The mode override will be ignored.")
541 }
542 Ok(uri.clone())
543 }
544 EndpointSource::Env(conf) => {
545 let env = conf.env();
546 let profile = conf.profile().await;
548 let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
549 Some(Cow::Owned(uri))
550 } else {
551 profile
552 .and_then(|profile| profile.get(profile_keys::ENDPOINT))
553 .map(Cow::Borrowed)
554 };
555 if let Some(uri) = uri_override {
556 return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
557 }
558
559 let mode = if let Some(mode) = mode_override {
561 mode
562 } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
563 mode.parse::<EndpointMode>()
564 .map_err(BuildError::invalid_endpoint_mode)?
565 } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
566 {
567 mode.parse::<EndpointMode>()
568 .map_err(BuildError::invalid_endpoint_mode)?
569 } else {
570 EndpointMode::IpV4
571 };
572
573 Ok(mode.endpoint())
574 }
575 }
576 }
577}
578
/// Endpoint resolver that derives the IMDS endpoint from an [`EndpointSource`] on every
/// resolution, applying the optional endpoint-mode override.
#[derive(Clone, Debug)]
struct ImdsEndpointResolver {
    endpoint_source: Arc<EndpointSource>,
    mode_override: Option<EndpointMode>,
}
584
585impl ResolveEndpoint for ImdsEndpointResolver {
586 fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
587 EndpointFuture::new(async move {
588 self.endpoint_source
589 .endpoint(self.mode_override.clone())
590 .await
591 .map(|uri| Endpoint::builder().url(uri.to_string()).build())
592 .map_err(|err| err.into())
593 })
594 }
595}
596
/// Retry classifier for IMDS responses.
///
/// Server errors (5xx) and `401 Unauthorized` responses are classified as retryable.
/// Attempts that produced no HTTP response (such as connect timeouts) are retried only when
/// enabled via [`ImdsResponseRetryClassifier::with_retry_connect_timeouts`]; this is off by
/// default.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct ImdsResponseRetryClassifier {
    retry_connect_timeouts: bool,
}
612
613impl ImdsResponseRetryClassifier {
614 pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
616 self.retry_connect_timeouts = retry_connect_timeouts;
617 self
618 }
619}
620
621impl ClassifyRetry for ImdsResponseRetryClassifier {
622 fn name(&self) -> &'static str {
623 "ImdsResponseRetryClassifier"
624 }
625
626 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
627 if let Some(response) = ctx.response() {
628 let status = response.status();
629 match status {
630 _ if status.is_server_error() => RetryAction::server_error(),
631 _ if status.as_u16() == 401 => RetryAction::server_error(),
633 _ => RetryAction::NoActionIndicated,
635 }
636 } else if self.retry_connect_timeouts {
637 RetryAction::server_error()
638 } else {
639 RetryAction::NoActionIndicated
644 }
645 }
646}
647
#[cfg(test)]
pub(crate) mod test {
    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
    use crate::provider_config::ProviderConfig;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::interceptors::context::{
        Input, InterceptorContext, Output,
    };
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use aws_types::os_shim_internal::{Env, Fs};
    use http::header::USER_AGENT;
    use http::Uri;
    use serde::Deserialize;
    use std::collections::HashMap;
    use std::error::Error;
    use std::io;
    use std::time::{Duration, UNIX_EPOCH};
    use tracing_test::traced_test;

    // Assert that the full error chain of `$err` (rendered with `DisplayErrorContext`)
    // contains the substring `$contains`.
    macro_rules! assert_full_error_contains {
        ($err:expr, $contains:expr) => {
            let err = $err;
            let message = format!(
                "{}",
                aws_smithy_types::error::display::DisplayErrorContext(&err)
            );
            assert!(
                message.contains($contains),
                "Error message '{message}' didn't contain text '{}'",
                $contains
            );
        };
    }

    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
    const TOKEN_B: &str = "alternatetoken==";

    /// Expected token request: `PUT {base}/latest/api/token` with the requested TTL header.
    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
        http::Request::builder()
            .uri(format!("{}/latest/api/token", base))
            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
            .method("PUT")
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// Successful token response carrying `token` and its TTL header.
    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
        HttpResponse::try_from(
            http::Response::builder()
                .status(200)
                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
                .body(SdkBody::from(token))
                .unwrap(),
        )
        .unwrap()
    }

    /// Expected metadata request: `GET path` with the session token header.
    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
        http::Request::builder()
            .uri(Uri::from_static(path))
            .method("GET")
            .header("x-aws-ec2-metadata-token", token)
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// Successful (200) metadata response with the given body.
    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
        HttpResponse::try_from(
            http::Response::builder()
                .status(200)
                .body(SdkBody::from(body))
                .unwrap(),
        )
        .unwrap()
    }

    /// Client wired to `http_client` with an instant (non-sleeping) sleep implementation.
    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
        tokio::time::pause();
        super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_sleep_impl(InstantSleep::unlogged())
                    .with_http_client(http_client.clone()),
            )
            .build()
    }

    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
        let http_client = StaticReplayClient::new(events);
        let client = make_imds_client(&http_client);
        (client, http_client)
    }

    #[tokio::test]
    async fn client_caches_token() {
        // A single token request serves both metadata requests.
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output"#),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
                imds_response("output2"),
            ),
        ]);
        let metadata = client.get("/latest/metadata").await.expect("failed");
        assert_eq!("test-imds-output", metadata.as_ref());
        let metadata = client.get("/latest/metadata2").await.expect("failed");
        assert_eq!("output2", metadata.as_ref());
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn token_can_expire() {
        // The client from `mock_imds_client` is discarded; only its replay client is reused
        // so a custom time source can be configured.
        let (_, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output2"#),
            ),
        ]);
        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_client(http_client.clone())
                    .with_time_source(time_source.clone())
                    .with_sleep_impl(sleep),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build();

        let resp1 = client.get("/latest/metadata").await.expect("success");
        // Advance by the full TTL so the first token is expired.
        time_source.advance(Duration::from_secs(600));
        let resp2 = client.get("/latest/metadata").await.expect("success");
        http_client.assert_requests_match(&[]);
        assert_eq!("test-imds-output1", resp1.as_ref());
        assert_eq!("test-imds-output2", resp2.as_ref());
    }

    #[tokio::test]
    // The token is refreshed ahead of its nominal expiry: it is reused at 400s but replaced
    // at 550s, although its TTL is 600s.
    async fn token_refresh_buffer() {
        let _logs = capture_test_logs();
        let (_, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            ReplayEvent::new(
                // t = 0
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            ReplayEvent::new(
                // t = 400: token A is still reused
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output2"#),
            ),
            ReplayEvent::new(
                // t = 550: a fresh token is fetched before the 600s TTL runs out
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output3"#),
            ),
        ]);
        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_sleep_impl(sleep)
                    .with_http_client(http_client.clone())
                    .with_time_source(time_source.clone()),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build();

        tracing::info!("resp1 -----------------------------------------------------------");
        let resp1 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(400));
        tracing::info!("resp2 -----------------------------------------------------------");
        let resp2 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(150));
        tracing::info!("resp3 -----------------------------------------------------------");
        let resp3 = client.get("/latest/metadata").await.expect("success");
        http_client.assert_requests_match(&[]);
        assert_eq!("test-imds-output1", resp1.as_ref());
        assert_eq!("test-imds-output2", resp2.as_ref());
        assert_eq!("test-imds-output3", resp3.as_ref());
    }

    #[tokio::test]
    // A 500 on the metadata request is retried with the same token.
    #[traced_test]
    async fn retry_500() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(500)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);

        for request in http_client.actual_requests() {
            // Every request, token and metadata alike, carries a User-Agent header.
            assert!(request.headers().get(USER_AGENT).is_some());
        }
    }

    #[tokio::test]
    // A 500 on the token request is retried.
    #[traced_test]
    async fn retry_token_failure() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                http::Response::builder()
                    .status(500)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    // A 401 on the metadata request (here with a zero-TTL token) is retried, and the
    // retry fetches a new token.
    #[traced_test]
    async fn retry_metadata_401() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(0, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(401)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    // A 403 on the token request is not retried and surfaces as a "forbidden" error.
    #[traced_test]
    async fn no_403_retry() {
        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            http::Response::builder()
                .status(403)
                .body(SdkBody::empty())
                .unwrap(),
        )]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "forbidden");
        http_client.assert_requests_match(&[]);
    }

    #[test]
    // The default classifier asks for no retry on a successful response or on a
    // connector error without a response.
    fn successful_response_properly_classified() {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
        let classifier = ImdsResponseRetryClassifier::default();
        assert_eq!(
            RetryAction::NoActionIndicated,
            classifier.classify_retry(&ctx)
        );

        // No response set: an I/O connector error alone does not trigger a retry.
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
        ))));
        assert_eq!(
            RetryAction::NoActionIndicated,
            classifier.classify_retry(&ctx)
        );
    }

    #[tokio::test]
    // A user-supplied classifier replaces the default one.
    async fn user_provided_retry_classifier() {
        #[derive(Clone, Debug)]
        struct UserProvidedRetryClassifier;

        impl ClassifyRetry for UserProvidedRetryClassifier {
            fn name(&self) -> &'static str {
                "UserProvidedRetryClassifier"
            }

            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
                // Never retry anything.
                RetryAction::RetryForbidden
            }
        }

        let events = vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(0, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(401)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
                imds_response("ok"),
            ),
        ];
        let http_client = StaticReplayClient::new(events);

        let imds_client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_sleep_impl(InstantSleep::unlogged())
                    .with_http_client(http_client.clone()),
            )
            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
            .build();

        let res = imds_client
            .get("/latest/metadata")
            .await
            .expect_err("Client should error");

        // With retries forbidden, the 401 that the default classifier would retry is returned.
        assert_full_error_contains!(res, "401");
    }

    #[tokio::test]
    // A token containing characters that are not valid in a header value is rejected.
    async fn invalid_token() {
        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            token_response(21600, "invalid\nheader\nvalue\0"),
        )]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "invalid token");
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn non_utf8_response() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A).map(SdkBody::from),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(vec![0xA0, 0xA1]))
                    .unwrap(),
            ),
        ]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "invalid UTF-8");
        http_client.assert_requests_match(&[]);
    }

    #[cfg_attr(windows, ignore)]
    // Uses a real HTTP client against an unroutable address to check the default 1s
    // connect timeout; ignored on Windows.
    #[tokio::test]
    // Requires the default HTTPS client feature.
    #[cfg(feature = "default-https-client")]
    async fn one_second_connect_timeout() {
        use crate::imds::client::ImdsError;
        use std::time::SystemTime;
        let client = Client::builder()
            .endpoint("http://240.0.0.0")
            // 240.0.0.0 is in a reserved range, so the connection attempt never completes.
            .expect("valid uri")
            .build();
        let now = SystemTime::now();
        let resp = client
            .get("/latest/metadata")
            .await
            .expect_err("240.0.0.0 will never resolve");
        match resp {
            err @ ImdsError::FailedToLoadToken(_)
                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {}
            other => panic!(
                "wrong error, expected construction failure with TimedOutError inside: {}",
                DisplayErrorContext(&other)
            ),
        }
        let time_elapsed = now.elapsed().unwrap();
        assert!(
            time_elapsed > Duration::from_secs(1),
            "time_elapsed should be greater than 1s but was {:?}",
            time_elapsed
        );
        assert!(
            time_elapsed < Duration::from_secs(2),
            "time_elapsed should be less than 2s but was {:?}",
            time_elapsed
        );
    }

    /// With connect-timeout retries enabled and a 1s operation timeout, check that a request
    /// to an unroutable address fails within `(min_elapsed, max_elapsed)` wall-clock time
    /// for the given behavior version.
    async fn retry_connect_timeouts_for_bv(
        behavior_version: aws_smithy_runtime_api::client::behavior_version::BehaviorVersion,
        min_elapsed: Duration,
        max_elapsed: Duration,
    ) {
        use std::time::SystemTime;
        let http_client = StaticReplayClient::new(vec![]);
        let imds_client = super::Client::builder()
            .retry_classifier(SharedRetryClassifier::new(
                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
            ))
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_client(http_client.clone())
                    .with_behavior_version(Some(behavior_version)),
            )
            .operation_timeout(Duration::from_secs(1))
            .endpoint("http://240.0.0.0")
            .expect("valid uri")
            .build();

        let now = SystemTime::now();
        let _res = imds_client
            .get("/latest/metadata")
            .await
            .expect_err("240.0.0.0 will never resolve");
        let time_elapsed = now.elapsed().unwrap();

        assert!(
            time_elapsed > min_elapsed,
            "time_elapsed should be greater than {min_elapsed:?} but was {time_elapsed:?}",
        );
        assert!(
            time_elapsed < max_elapsed,
            "time_elapsed should be less than {max_elapsed:?} but was {time_elapsed:?}",
        );
    }

    #[tokio::test]
    // The expected elapsed-time window depends on the behavior version.
    async fn retry_connect_timeouts() {
        use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
        #[allow(deprecated)]
        // An older, deprecated behavior version.
        retry_connect_timeouts_for_bv(
            BehaviorVersion::v2024_03_28(),
            Duration::from_secs(1),
            Duration::from_secs(2),
        )
        .await;

        retry_connect_timeouts_for_bv(
            // The latest behavior version allows a lower minimum elapsed time.
            BehaviorVersion::latest(),
            Duration::from_millis(500),
            Duration::from_secs(2),
        )
        .await;
    }

    /// One case from `imds-endpoint-tests.json`: environment, filesystem (profile) contents,
    /// optional builder overrides, and either the expected request URI or an expected
    /// error substring.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        env: HashMap<String, String>,
        fs: HashMap<String, String>,
        endpoint_override: Option<String>,
        mode_override: Option<String>,
        result: Result<String, String>,
        docs: String,
    }

    #[tokio::test]
    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
        let _logs = capture_test_logs();

        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
        #[derive(Deserialize)]
        struct TestCases {
            tests: Vec<ImdsConfigTest>,
        }

        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
        let test_cases = test_cases.tests;
        for test in test_cases {
            check(test).await;
        }
        Ok(())
    }

    /// Build a client from `test_case` and check the first request URI or the error text.
    async fn check(test_case: ImdsConfigTest) {
        let (http_client, watcher) = capture_request(None);
        let provider_config = ProviderConfig::no_configuration()
            .with_sleep_impl(TokioSleep::new())
            .with_env(Env::from(test_case.env))
            .with_fs(Fs::from_map(test_case.fs))
            .with_http_client(http_client);
        let mut imds_client = Client::builder().configure(&provider_config);
        if let Some(endpoint_override) = test_case.endpoint_override {
            imds_client = imds_client
                .endpoint(endpoint_override)
                .expect("invalid URI");
        }

        if let Some(mode_override) = test_case.mode_override {
            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
        }

        let imds_client = imds_client.build();
        match &test_case.result {
            Ok(uri) => {
                // The request outcome is irrelevant; only the captured (first) request URI is checked.
                let _ = imds_client.get("/hello").await;
                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
            }
            Err(expected) => {
                let err = imds_client.get("/hello").await.expect_err("it should fail");
                let message = format!("{}", DisplayErrorContext(&err));
                assert!(
                    message.contains(expected),
                    "{}\nexpected error: {expected}\nactual error: {message}",
                    test_case.docs
                );
            }
        };
    }
}