aws_config/sts/
assume_role.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Assume credentials for a role through the AWS Security Token Service (STS).
7
8use aws_credential_types::credential_feature::AwsCredentialFeature;
9use aws_credential_types::provider::{
10    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
11};
12use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
13use aws_sdk_sts::operation::assume_role::AssumeRoleError;
14use aws_sdk_sts::types::{PolicyDescriptorType, Tag};
15use aws_sdk_sts::Client as StsClient;
16use aws_smithy_runtime::client::identity::IdentityCache;
17use aws_smithy_runtime_api::client::result::SdkError;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use aws_types::region::Region;
20use aws_types::SdkConfig;
21use std::time::Duration;
22use tracing::Instrument;
23
24/// Credentials provider that uses credentials provided by another provider to assume a role
25/// through the AWS Security Token Service (STS).
26///
27/// When asked to provide credentials, this provider will first invoke the inner credentials
28/// provider to get AWS credentials for STS. Then, it will call STS to get assumed credentials for
29/// the desired role.
30///
31/// # Examples
32/// Create an AssumeRoleProvider explicitly set to us-east-2 that utilizes the default credentials chain.
33/// ```no_run
34/// use aws_config::sts::AssumeRoleProvider;
35/// use aws_types::region::Region;
36/// # async fn docs() {
37/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
38///   .region(Region::from_static("us-east-2"))
39///   .session_name("testAR")
40///   .build().await;
41/// }
42/// ```
43///
44/// Create an AssumeRoleProvider from an explicitly configured base configuration.
45/// ```no_run
46/// use aws_config::sts::AssumeRoleProvider;
47/// use aws_types::region::Region;
48/// # async fn docs() {
49/// let conf = aws_config::from_env().use_fips(true).load().await;
50/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
51///   .configure(&conf)
52///   .session_name("testAR")
53///   .build().await;
54/// }
55/// ```
56///
57/// Create an AssumeroleProvider that sources credentials from a provider credential provider:
58/// ```no_run
59/// use aws_config::sts::AssumeRoleProvider;
60/// use aws_types::region::Region;
61/// use aws_config::environment::EnvironmentVariableCredentialsProvider;
62/// # async fn docs() {
63/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
64///   .session_name("test-assume-role-session")
65///   // only consider environment variables, explicitly.
66///   .build_from_provider(EnvironmentVariableCredentialsProvider::new()).await;
67/// }
68/// ```
69///
70#[derive(Debug)]
71pub struct AssumeRoleProvider {
72    inner: Inner,
73}
74
75#[derive(Debug)]
76struct Inner {
77    fluent_builder: AssumeRoleFluentBuilder,
78}
79
80impl AssumeRoleProvider {
81    /// Build a new role-assuming provider for the given role.
82    ///
83    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
84    ///
85    /// ```text
86    /// arn:aws:iam::123456789012:role/example
87    /// ```
88    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
89        AssumeRoleProviderBuilder::new(role.into())
90    }
91}
92
93/// A builder for [`AssumeRoleProvider`].
94///
95/// Construct one through [`AssumeRoleProvider::builder`].
96#[derive(Debug)]
97pub struct AssumeRoleProviderBuilder {
98    role_arn: String,
99    external_id: Option<String>,
100    session_name: Option<String>,
101    session_length: Option<Duration>,
102    policy: Option<String>,
103    policy_arns: Option<Vec<PolicyDescriptorType>>,
104    region_override: Option<Region>,
105    sdk_config: Option<SdkConfig>,
106    tags: Option<Vec<Tag>>,
107}
108
109impl AssumeRoleProviderBuilder {
110    /// Start a new assume role builder for the given role.
111    ///
112    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
113    ///
114    /// ```text
115    /// arn:aws:iam::123456789012:role/example
116    /// ```
117    pub fn new(role: impl Into<String>) -> Self {
118        Self {
119            role_arn: role.into(),
120            external_id: None,
121            session_name: None,
122            session_length: None,
123            policy: None,
124            policy_arns: None,
125            sdk_config: None,
126            region_override: None,
127            tags: None,
128        }
129    }
130
131    /// Set a unique identifier that might be required when you assume a role in another account.
132    ///
133    /// If the administrator of the account to which the role belongs provided you with an external
134    /// ID, then provide that value in this parameter. The value can be any string, such as a
135    /// passphrase or account number.
136    pub fn external_id(mut self, id: impl Into<String>) -> Self {
137        self.external_id = Some(id.into());
138        self
139    }
140
141    /// Set an identifier for the assumed role session.
142    ///
143    /// Use the role session name to uniquely identify a session when the same role is assumed by
144    /// different principals or for different reasons. In cross-account scenarios, the role session
145    /// name is visible to, and can be logged by the account that owns the role. The role session
146    /// name is also used in the ARN of the assumed role principal.
147    pub fn session_name(mut self, name: impl Into<String>) -> Self {
148        self.session_name = Some(name.into());
149        self
150    }
151
152    /// Set an IAM policy in JSON format that you want to use as an inline session policy.
153    ///
154    /// This parameter is optional
155    /// For more information, see
156    /// [policy](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
157    pub fn policy(mut self, policy: impl Into<String>) -> Self {
158        self.policy = Some(policy.into());
159        self
160    }
161
162    /// Set the Amazon Resource Names (ARNs) of the IAM managed policies that you want to use as managed session policies.
163    ///
164    /// This parameter is optional.
165    /// For more information, see
166    /// [policy_arns](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
167    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
168        self.policy_arns = Some(
169            policy_arns
170                .into_iter()
171                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
172                .collect::<Vec<_>>(),
173        );
174        self
175    }
176
177    /// Set the expiration time of the role session.
178    ///
179    /// When unset, this value defaults to 1 hour.
180    ///
181    /// The value specified can range from 900 seconds (15 minutes) up to the maximum session duration
182    /// set for the role. The maximum session duration setting can have a value from 1 hour to 12 hours.
183    /// If you specify a value higher than this setting or the administrator setting (whichever is lower),
184    /// **you will be unable to assume the role**. For example, if you specify a session duration of 12 hours,
185    /// but your administrator set the maximum session duration to 6 hours, you cannot assume the role.
186    ///
187    /// For more information, see
188    /// [duration_seconds](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::duration_seconds)
189    pub fn session_length(mut self, length: Duration) -> Self {
190        self.session_length = Some(length);
191        self
192    }
193
194    /// Set the region to assume the role in.
195    ///
196    /// This dictates which STS endpoint the AssumeRole action is invoked on. This will override
197    /// a region set from `.configure(...)`
198    pub fn region(mut self, region: Region) -> Self {
199        self.region_override = Some(region);
200        self
201    }
202
203    /// Set the session tags
204    ///
205    /// A list of session tags that you want to pass. Each session tag consists of a key name and an associated value.
206    /// For more information, see `[Tag]`.
207    pub fn tags<K, V>(mut self, tags: impl IntoIterator<Item = (K, V)>) -> Self
208    where
209        K: Into<String>,
210        V: Into<String>,
211    {
212        self.tags = Some(
213            tags.into_iter()
214                // Unwrap won't fail as both key and value are specified.
215                // Currently Tag does not have an infallible build method.
216                .map(|(k, v)| {
217                    Tag::builder()
218                        .key(k)
219                        .value(v)
220                        .build()
221                        .expect("this is unreachable: both k and v are set")
222                })
223                .collect::<Vec<_>>(),
224        );
225        self
226    }
227
228    /// Sets the configuration used for this provider
229    ///
230    /// This enables overriding the connection used to communicate with STS in addition to other internal
231    /// fields like the time source and sleep implementation used for caching.
232    ///
233    /// If this field is not provided, configuration from [`crate::load_from_env()`] is used.
234    ///
235    /// # Examples
236    /// ```rust
237    /// # async fn docs() {
238    /// use aws_types::region::Region;
239    /// use aws_config::sts::AssumeRoleProvider;
240    /// let config = aws_config::from_env().region(Region::from_static("us-west-2")).load().await;
241    /// let assume_role_provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/example")
242    ///   .configure(&config)
243    ///   .build();
244    /// }
245    pub fn configure(mut self, conf: &SdkConfig) -> Self {
246        self.sdk_config = Some(conf.clone());
247        self
248    }
249
250    /// Build a credentials provider for this role.
251    ///
252    /// Base credentials will be used from the [`SdkConfig`] set via [`Self::configure`] or loaded
253    /// from [`aws_config::from_env`](crate::from_env) if `configure` was never called.
254    pub async fn build(self) -> AssumeRoleProvider {
255        let mut conf = match self.sdk_config {
256            Some(conf) => conf,
257            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
258        };
259        // ignore a identity cache set from SdkConfig
260        conf = conf
261            .into_builder()
262            .identity_cache(IdentityCache::no_cache())
263            .build();
264
265        // set a region override if one exists
266        if let Some(region) = self.region_override {
267            conf = conf.into_builder().region(region).build()
268        }
269
270        let config = aws_sdk_sts::config::Builder::from(&conf);
271
272        let time_source = conf.time_source().expect("A time source must be provided.");
273
274        let session_name = self.session_name.unwrap_or_else(|| {
275            super::util::default_session_name("assume-role-provider", time_source.now())
276        });
277
278        let sts_client = StsClient::from_conf(config.build());
279        let fluent_builder = sts_client
280            .assume_role()
281            .set_role_arn(Some(self.role_arn))
282            .set_external_id(self.external_id)
283            .set_role_session_name(Some(session_name))
284            .set_policy(self.policy)
285            .set_policy_arns(self.policy_arns)
286            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32))
287            .set_tags(self.tags);
288
289        AssumeRoleProvider {
290            inner: Inner { fluent_builder },
291        }
292    }
293
294    /// Build a credentials provider for this role authorized by the given `provider`.
295    pub async fn build_from_provider(
296        mut self,
297        provider: impl ProvideCredentials + 'static,
298    ) -> AssumeRoleProvider {
299        let conf = match self.sdk_config {
300            Some(conf) => conf,
301            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
302        };
303        let conf = conf
304            .into_builder()
305            .credentials_provider(SharedCredentialsProvider::new(provider))
306            .build();
307        self.sdk_config = Some(conf);
308        self.build().await
309    }
310}
311
312impl Inner {
313    async fn credentials(&self) -> provider::Result {
314        tracing::debug!("retrieving assumed credentials");
315
316        let assumed = self.fluent_builder.clone().send().in_current_span().await;
317        let assumed = match assumed {
318            Ok(assumed) => {
319                tracing::debug!(
320                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
321                    "obtained assumed credentials"
322                );
323                super::util::into_credentials(
324                    assumed.credentials,
325                    assumed.assumed_role_user,
326                    "AssumeRoleProvider",
327                )
328            }
329            Err(SdkError::ServiceError(ref context))
330                if matches!(
331                    context.err(),
332                    AssumeRoleError::RegionDisabledException(_)
333                        | AssumeRoleError::MalformedPolicyDocumentException(_)
334                ) =>
335            {
336                Err(CredentialsError::invalid_configuration(
337                    assumed.err().unwrap(),
338                ))
339            }
340            Err(SdkError::ServiceError(ref context)) => {
341                tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
342                Err(CredentialsError::provider_error(assumed.err().unwrap()))
343            }
344            Err(err) => Err(CredentialsError::provider_error(err)),
345        };
346
347        assumed.map(|mut creds| {
348            creds
349                .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
350                .push(AwsCredentialFeature::CredentialsStsAssumeRole);
351            creds
352        })
353    }
354}
355
356impl ProvideCredentials for AssumeRoleProvider {
357    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
358    where
359        Self: 'a,
360    {
361        future::ProvideCredentials::new(
362            self.inner
363                .credentials()
364                .instrument(tracing::debug_span!("assume_role")),
365        )
366    }
367}
368
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_feature::AwsCredentialFeature;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The result is not inspected; only the captured request matters.
        let _ = provider.provide_credentials().await;
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        // Match the whole query parameter: the default session name is derived from the time
        // source (1234567770 s since the epoch), so a bare "1234567" may match the session name
        // even when `DurationSeconds` is missing.
        assert!(
            str_body.contains("DurationSeconds=1234567"),
            "{}",
            str_body
        );
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The result is not inspected; only the captured request matters.
        let _ = provider.provide_credentials().await;
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// Test that `build()` where no provider is passed still works
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        // The 404 makes this call fail; only the captured request matters.
        let _ = provider.provide_credentials().await;
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        // ensure that FIPS & DualStack are also respected
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    fn create_test_http_client() -> StaticReplayClient {
        StaticReplayClient::new(vec![
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(
                        "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n",
                    ))
                    .unwrap(),
            ),
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(
                        "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>TESTSECRET</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:33:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n",
                    ))
                    .unwrap(),
            ),
        ])
    }

    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // After time has been advanced by 120 seconds, the first credentials _could_ still be valid
        // if `LazyCredentialsCache` were used, but the provider uses `NoCredentialsCache` by default
        // so the first credentials will not be used.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn credentials_feature() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let credentials = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
            "test",
        );
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(credentials)
            .await;

        let creds = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        assert_eq!(
            &vec![AwsCredentialFeature::CredentialsStsAssumeRole],
            creds.get_property::<Vec<AwsCredentialFeature>>().unwrap()
        )
    }
}