1use aws_credential_types::credential_feature::AwsCredentialFeature;
9use aws_credential_types::provider::{
10    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
11};
12use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
13use aws_sdk_sts::operation::assume_role::AssumeRoleError;
14use aws_sdk_sts::types::{PolicyDescriptorType, Tag};
15use aws_sdk_sts::Client as StsClient;
16use aws_smithy_runtime::client::identity::IdentityCache;
17use aws_smithy_runtime_api::client::result::SdkError;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use aws_types::region::Region;
20use aws_types::SdkConfig;
21use std::time::Duration;
22use tracing::Instrument;
23
24#[derive(Debug)]
71pub struct AssumeRoleProvider {
72    inner: Inner,
73}
74
75#[derive(Debug)]
76struct Inner {
77    fluent_builder: AssumeRoleFluentBuilder,
78}
79
80impl AssumeRoleProvider {
81    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
89        AssumeRoleProviderBuilder::new(role.into())
90    }
91}
92
93#[derive(Debug)]
97pub struct AssumeRoleProviderBuilder {
98    role_arn: String,
99    external_id: Option<String>,
100    session_name: Option<String>,
101    session_length: Option<Duration>,
102    policy: Option<String>,
103    policy_arns: Option<Vec<PolicyDescriptorType>>,
104    region_override: Option<Region>,
105    sdk_config: Option<SdkConfig>,
106    tags: Option<Vec<Tag>>,
107}
108
109impl AssumeRoleProviderBuilder {
110    pub fn new(role: impl Into<String>) -> Self {
118        Self {
119            role_arn: role.into(),
120            external_id: None,
121            session_name: None,
122            session_length: None,
123            policy: None,
124            policy_arns: None,
125            sdk_config: None,
126            region_override: None,
127            tags: None,
128        }
129    }
130
131    pub fn external_id(mut self, id: impl Into<String>) -> Self {
137        self.external_id = Some(id.into());
138        self
139    }
140
141    pub fn session_name(mut self, name: impl Into<String>) -> Self {
148        self.session_name = Some(name.into());
149        self
150    }
151
152    pub fn policy(mut self, policy: impl Into<String>) -> Self {
158        self.policy = Some(policy.into());
159        self
160    }
161
162    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
168        self.policy_arns = Some(
169            policy_arns
170                .into_iter()
171                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
172                .collect::<Vec<_>>(),
173        );
174        self
175    }
176
177    pub fn session_length(mut self, length: Duration) -> Self {
190        self.session_length = Some(length);
191        self
192    }
193
194    pub fn region(mut self, region: Region) -> Self {
199        self.region_override = Some(region);
200        self
201    }
202
203    pub fn tags<K, V>(mut self, tags: impl IntoIterator<Item = (K, V)>) -> Self
208    where
209        K: Into<String>,
210        V: Into<String>,
211    {
212        self.tags = Some(
213            tags.into_iter()
214                .map(|(k, v)| {
217                    Tag::builder()
218                        .key(k)
219                        .value(v)
220                        .build()
221                        .expect("this is unreachable: both k and v are set")
222                })
223                .collect::<Vec<_>>(),
224        );
225        self
226    }
227
228    pub fn configure(mut self, conf: &SdkConfig) -> Self {
246        self.sdk_config = Some(conf.clone());
247        self
248    }
249
250    pub async fn build(self) -> AssumeRoleProvider {
255        let mut conf = match self.sdk_config {
256            Some(conf) => conf,
257            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
258        };
259        conf = conf
261            .into_builder()
262            .identity_cache(IdentityCache::no_cache())
263            .build();
264
265        if let Some(region) = self.region_override {
267            conf = conf.into_builder().region(region).build()
268        }
269
270        let config = aws_sdk_sts::config::Builder::from(&conf);
271
272        let time_source = conf.time_source().expect("A time source must be provided.");
273
274        let session_name = self.session_name.unwrap_or_else(|| {
275            super::util::default_session_name("assume-role-provider", time_source.now())
276        });
277
278        let sts_client = StsClient::from_conf(config.build());
279        let fluent_builder = sts_client
280            .assume_role()
281            .set_role_arn(Some(self.role_arn))
282            .set_external_id(self.external_id)
283            .set_role_session_name(Some(session_name))
284            .set_policy(self.policy)
285            .set_policy_arns(self.policy_arns)
286            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32))
287            .set_tags(self.tags);
288
289        AssumeRoleProvider {
290            inner: Inner { fluent_builder },
291        }
292    }
293
294    pub async fn build_from_provider(
296        mut self,
297        provider: impl ProvideCredentials + 'static,
298    ) -> AssumeRoleProvider {
299        let conf = match self.sdk_config {
300            Some(conf) => conf,
301            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
302        };
303        let conf = conf
304            .into_builder()
305            .credentials_provider(SharedCredentialsProvider::new(provider))
306            .build();
307        self.sdk_config = Some(conf);
308        self.build().await
309    }
310}
311
312impl Inner {
313    async fn credentials(&self) -> provider::Result {
314        tracing::debug!("retrieving assumed credentials");
315
316        let assumed = self.fluent_builder.clone().send().in_current_span().await;
317        let assumed = match assumed {
318            Ok(assumed) => {
319                tracing::debug!(
320                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
321                    "obtained assumed credentials"
322                );
323                super::util::into_credentials(
324                    assumed.credentials,
325                    assumed.assumed_role_user,
326                    "AssumeRoleProvider",
327                )
328            }
329            Err(SdkError::ServiceError(ref context))
330                if matches!(
331                    context.err(),
332                    AssumeRoleError::RegionDisabledException(_)
333                        | AssumeRoleError::MalformedPolicyDocumentException(_)
334                ) =>
335            {
336                Err(CredentialsError::invalid_configuration(
337                    assumed.err().unwrap(),
338                ))
339            }
340            Err(SdkError::ServiceError(ref context)) => {
341                tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
342                Err(CredentialsError::provider_error(assumed.err().unwrap()))
343            }
344            Err(err) => Err(CredentialsError::provider_error(err)),
345        };
346
347        assumed.map(|mut creds| {
348            creds
349                .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
350                .push(AwsCredentialFeature::CredentialsStsAssumeRole);
351            creds
352        })
353    }
354}
355
356impl ProvideCredentials for AssumeRoleProvider {
357    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
358    where
359        Self: 'a,
360    {
361        future::ProvideCredentials::new(
362            self.inner
363                .credentials()
364                .instrument(tracing::debug_span!("assume_role")),
365        )
366    }
367}
368
// Unit tests for `AssumeRoleProvider`. HTTP traffic is captured (`capture_request`) or
// replayed (`StaticReplayClient`), so no real STS endpoint is contacted.
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_feature::AwsCredentialFeature;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    // Checks that the configured session length appears in the request body, and that
    // `.region(...)` overrides the region taken from the SdkConfig (the request goes to the
    // us-east-1 endpoint rather than "this-will-be-overridden").
    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The call's result is irrelevant here; only the captured request is inspected.
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    // Without a region override, the SdkConfig region is used. The SdkConfig's own
    // credentials provider panics if invoked, so this also checks that
    // `build_from_provider` replaces it.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    // `build()` (not `build_from_provider`) signs with the credentials and region from a
    // config loaded from the environment, and honours its FIPS and dual-stack settings when
    // resolving the endpoint. The 404 response only ends the call; the request is inspected.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        // Credential scope: access key / date of time source (2009-02-13) / region / service.
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    // Replays two successful AssumeRole responses. They differ in secret key, expiration,
    // and request id, so credentials from the first and second call compare unequal.
    fn create_test_http_client() -> StaticReplayClient {
        StaticReplayClient::new(vec![
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>TESTSECRET</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:33:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
        ])
    }

    // Two calls, with the clock advanced in between, must yield different assumed
    // credentials. The source provider must also be consulted twice: both queued source
    // credentials are consumed, so nothing is served from a cache.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120), );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        // Source credentials handed out in order, one per call to the source provider.
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }

    // Successful assumed-role credentials carry exactly the `CredentialsStsAssumeRole`
    // feature marker.
    #[tokio::test]
    async fn credentials_feature() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890), );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let credentials = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
            "test",
        );
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(credentials)
            .await;

        let creds = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        assert_eq!(
            &vec![AwsCredentialFeature::CredentialsStsAssumeRole],
            creds.get_property::<Vec<AwsCredentialFeature>>().unwrap()
        )
    }
}