aws_config/sts/
assume_role.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Assume credentials for a role through the AWS Security Token Service (STS).
7
8use aws_credential_types::credential_feature::AwsCredentialFeature;
9use aws_credential_types::provider::{
10    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
11};
12use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
13use aws_sdk_sts::operation::assume_role::AssumeRoleError;
14use aws_sdk_sts::types::PolicyDescriptorType;
15use aws_sdk_sts::Client as StsClient;
16use aws_smithy_runtime::client::identity::IdentityCache;
17use aws_smithy_runtime_api::client::result::SdkError;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use aws_types::region::Region;
20use aws_types::SdkConfig;
21use std::time::Duration;
22use tracing::Instrument;
23
24/// Credentials provider that uses credentials provided by another provider to assume a role
25/// through the AWS Security Token Service (STS).
26///
27/// When asked to provide credentials, this provider will first invoke the inner credentials
28/// provider to get AWS credentials for STS. Then, it will call STS to get assumed credentials for
29/// the desired role.
30///
31/// # Examples
32/// Create an AssumeRoleProvider explicitly set to us-east-2 that utilizes the default credentials chain.
33/// ```no_run
34/// use aws_config::sts::AssumeRoleProvider;
35/// use aws_types::region::Region;
36/// # async fn docs() {
37/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
38///   .region(Region::from_static("us-east-2"))
39///   .session_name("testAR")
40///   .build().await;
41/// }
42/// ```
43///
44/// Create an AssumeRoleProvider from an explicitly configured base configuration.
45/// ```no_run
46/// use aws_config::sts::AssumeRoleProvider;
47/// use aws_types::region::Region;
48/// # async fn docs() {
49/// let conf = aws_config::from_env().use_fips(true).load().await;
50/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
51///   .configure(&conf)
52///   .session_name("testAR")
53///   .build().await;
54/// }
55/// ```
56///
57/// Create an AssumeroleProvider that sources credentials from a provider credential provider:
58/// ```no_run
59/// use aws_config::sts::AssumeRoleProvider;
60/// use aws_types::region::Region;
61/// use aws_config::environment::EnvironmentVariableCredentialsProvider;
62/// # async fn docs() {
63/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
64///   .session_name("test-assume-role-session")
65///   // only consider environment variables, explicitly.
66///   .build_from_provider(EnvironmentVariableCredentialsProvider::new()).await;
67/// }
68/// ```
69///
70#[derive(Debug)]
71pub struct AssumeRoleProvider {
72    inner: Inner,
73}
74
75#[derive(Debug)]
76struct Inner {
77    fluent_builder: AssumeRoleFluentBuilder,
78}
79
80impl AssumeRoleProvider {
81    /// Build a new role-assuming provider for the given role.
82    ///
83    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
84    ///
85    /// ```text
86    /// arn:aws:iam::123456789012:role/example
87    /// ```
88    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
89        AssumeRoleProviderBuilder::new(role.into())
90    }
91}
92
93/// A builder for [`AssumeRoleProvider`].
94///
95/// Construct one through [`AssumeRoleProvider::builder`].
96#[derive(Debug)]
97pub struct AssumeRoleProviderBuilder {
98    role_arn: String,
99    external_id: Option<String>,
100    session_name: Option<String>,
101    session_length: Option<Duration>,
102    policy: Option<String>,
103    policy_arns: Option<Vec<PolicyDescriptorType>>,
104    region_override: Option<Region>,
105    sdk_config: Option<SdkConfig>,
106}
107
108impl AssumeRoleProviderBuilder {
109    /// Start a new assume role builder for the given role.
110    ///
111    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
112    ///
113    /// ```text
114    /// arn:aws:iam::123456789012:role/example
115    /// ```
116    pub fn new(role: impl Into<String>) -> Self {
117        Self {
118            role_arn: role.into(),
119            external_id: None,
120            session_name: None,
121            session_length: None,
122            policy: None,
123            policy_arns: None,
124            sdk_config: None,
125            region_override: None,
126        }
127    }
128
129    /// Set a unique identifier that might be required when you assume a role in another account.
130    ///
131    /// If the administrator of the account to which the role belongs provided you with an external
132    /// ID, then provide that value in this parameter. The value can be any string, such as a
133    /// passphrase or account number.
134    pub fn external_id(mut self, id: impl Into<String>) -> Self {
135        self.external_id = Some(id.into());
136        self
137    }
138
139    /// Set an identifier for the assumed role session.
140    ///
141    /// Use the role session name to uniquely identify a session when the same role is assumed by
142    /// different principals or for different reasons. In cross-account scenarios, the role session
143    /// name is visible to, and can be logged by the account that owns the role. The role session
144    /// name is also used in the ARN of the assumed role principal.
145    pub fn session_name(mut self, name: impl Into<String>) -> Self {
146        self.session_name = Some(name.into());
147        self
148    }
149
150    /// Set an IAM policy in JSON format that you want to use as an inline session policy.
151    ///
152    /// This parameter is optional
153    /// For more information, see
154    /// [policy](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
155    pub fn policy(mut self, policy: impl Into<String>) -> Self {
156        self.policy = Some(policy.into());
157        self
158    }
159
160    /// Set the Amazon Resource Names (ARNs) of the IAM managed policies that you want to use as managed session policies.
161    ///
162    /// This parameter is optional.
163    /// For more information, see
164    /// [policy_arns](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
165    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
166        self.policy_arns = Some(
167            policy_arns
168                .into_iter()
169                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
170                .collect::<Vec<_>>(),
171        );
172        self
173    }
174
175    /// Set the expiration time of the role session.
176    ///
177    /// When unset, this value defaults to 1 hour.
178    ///
179    /// The value specified can range from 900 seconds (15 minutes) up to the maximum session duration
180    /// set for the role. The maximum session duration setting can have a value from 1 hour to 12 hours.
181    /// If you specify a value higher than this setting or the administrator setting (whichever is lower),
182    /// **you will be unable to assume the role**. For example, if you specify a session duration of 12 hours,
183    /// but your administrator set the maximum session duration to 6 hours, you cannot assume the role.
184    ///
185    /// For more information, see
186    /// [duration_seconds](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::duration_seconds)
187    pub fn session_length(mut self, length: Duration) -> Self {
188        self.session_length = Some(length);
189        self
190    }
191
192    /// Set the region to assume the role in.
193    ///
194    /// This dictates which STS endpoint the AssumeRole action is invoked on. This will override
195    /// a region set from `.configure(...)`
196    pub fn region(mut self, region: Region) -> Self {
197        self.region_override = Some(region);
198        self
199    }
200
201    /// Sets the configuration used for this provider
202    ///
203    /// This enables overriding the connection used to communicate with STS in addition to other internal
204    /// fields like the time source and sleep implementation used for caching.
205    ///
206    /// If this field is not provided, configuration from [`aws_config::load_from_env().await`] is used.
207    ///
208    /// # Examples
209    /// ```rust
210    /// # async fn docs() {
211    /// use aws_types::region::Region;
212    /// use aws_config::sts::AssumeRoleProvider;
213    /// let config = aws_config::from_env().region(Region::from_static("us-west-2")).load().await;
214    /// let assume_role_provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/example")
215    ///   .configure(&config)
216    ///   .build();
217    /// }
218    pub fn configure(mut self, conf: &SdkConfig) -> Self {
219        self.sdk_config = Some(conf.clone());
220        self
221    }
222
223    /// Build a credentials provider for this role.
224    ///
225    /// Base credentials will be used from the [`SdkConfig`] set via [`Self::configure`] or loaded
226    /// from [`aws_config::from_env`](crate::from_env) if `configure` was never called.
227    pub async fn build(self) -> AssumeRoleProvider {
228        let mut conf = match self.sdk_config {
229            Some(conf) => conf,
230            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
231        };
232        // ignore a identity cache set from SdkConfig
233        conf = conf
234            .into_builder()
235            .identity_cache(IdentityCache::no_cache())
236            .build();
237
238        // set a region override if one exists
239        if let Some(region) = self.region_override {
240            conf = conf.into_builder().region(region).build()
241        }
242
243        let config = aws_sdk_sts::config::Builder::from(&conf);
244
245        let time_source = conf.time_source().expect("A time source must be provided.");
246
247        let session_name = self.session_name.unwrap_or_else(|| {
248            super::util::default_session_name("assume-role-provider", time_source.now())
249        });
250
251        let sts_client = StsClient::from_conf(config.build());
252        let fluent_builder = sts_client
253            .assume_role()
254            .set_role_arn(Some(self.role_arn))
255            .set_external_id(self.external_id)
256            .set_role_session_name(Some(session_name))
257            .set_policy(self.policy)
258            .set_policy_arns(self.policy_arns)
259            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
260
261        AssumeRoleProvider {
262            inner: Inner { fluent_builder },
263        }
264    }
265
266    /// Build a credentials provider for this role authorized by the given `provider`.
267    pub async fn build_from_provider(
268        mut self,
269        provider: impl ProvideCredentials + 'static,
270    ) -> AssumeRoleProvider {
271        let conf = match self.sdk_config {
272            Some(conf) => conf,
273            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
274        };
275        let conf = conf
276            .into_builder()
277            .credentials_provider(SharedCredentialsProvider::new(provider))
278            .build();
279        self.sdk_config = Some(conf);
280        self.build().await
281    }
282}
283
284impl Inner {
285    async fn credentials(&self) -> provider::Result {
286        tracing::debug!("retrieving assumed credentials");
287
288        let assumed = self.fluent_builder.clone().send().in_current_span().await;
289        let assumed = match assumed {
290            Ok(assumed) => {
291                tracing::debug!(
292                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
293                    "obtained assumed credentials"
294                );
295                super::util::into_credentials(
296                    assumed.credentials,
297                    assumed.assumed_role_user,
298                    "AssumeRoleProvider",
299                )
300            }
301            Err(SdkError::ServiceError(ref context))
302                if matches!(
303                    context.err(),
304                    AssumeRoleError::RegionDisabledException(_)
305                        | AssumeRoleError::MalformedPolicyDocumentException(_)
306                ) =>
307            {
308                Err(CredentialsError::invalid_configuration(
309                    assumed.err().unwrap(),
310                ))
311            }
312            Err(SdkError::ServiceError(ref context)) => {
313                tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
314                Err(CredentialsError::provider_error(assumed.err().unwrap()))
315            }
316            Err(err) => Err(CredentialsError::provider_error(err)),
317        };
318
319        assumed.map(|mut creds| {
320            creds
321                .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
322                .push(AwsCredentialFeature::CredentialsStsAssumeRole);
323            creds
324        })
325    }
326}
327
328impl ProvideCredentials for AssumeRoleProvider {
329    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
330    where
331        Self: 'a,
332    {
333        future::ProvideCredentials::new(
334            self.inner
335                .credentials()
336                .instrument(tracing::debug_span!("assume_role")),
337        )
338    }
339}
340
341#[cfg(test)]
342mod test {
343    use crate::sts::AssumeRoleProvider;
344    use aws_credential_types::credential_feature::AwsCredentialFeature;
345    use aws_credential_types::credential_fn::provide_credentials_fn;
346    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
347    use aws_credential_types::Credentials;
348    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
349    use aws_smithy_async::test_util::instant_time_and_sleep;
350    use aws_smithy_async::time::StaticTimeSource;
351    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
352    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
353    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
354    use aws_smithy_types::body::SdkBody;
355    use aws_types::os_shim_internal::Env;
356    use aws_types::region::Region;
357    use aws_types::SdkConfig;
358    use http::header::AUTHORIZATION;
359    use std::time::{Duration, UNIX_EPOCH};
360
361    #[tokio::test]
362    async fn configures_session_length() {
363        let (http_client, request) = capture_request(None);
364        let sdk_config = SdkConfig::builder()
365            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
366            .time_source(StaticTimeSource::new(
367                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
368            ))
369            .http_client(http_client)
370            .region(Region::from_static("this-will-be-overridden"))
371            .behavior_version(crate::BehaviorVersion::latest())
372            .build();
373        let provider = AssumeRoleProvider::builder("myrole")
374            .configure(&sdk_config)
375            .region(Region::new("us-east-1"))
376            .session_length(Duration::from_secs(1234567))
377            .build_from_provider(provide_credentials_fn(|| async {
378                Ok(Credentials::for_tests())
379            }))
380            .await;
381        let _ = dbg!(provider.provide_credentials().await);
382        let req = request.expect_request();
383        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
384        assert!(str_body.contains("1234567"), "{}", str_body);
385        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
386    }
387
388    #[tokio::test]
389    async fn loads_region_from_sdk_config() {
390        let (http_client, request) = capture_request(None);
391        let sdk_config = SdkConfig::builder()
392            .behavior_version(crate::BehaviorVersion::latest())
393            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
394            .time_source(StaticTimeSource::new(
395                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
396            ))
397            .http_client(http_client)
398            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
399                || async {
400                    panic!("don't call me — will be overridden");
401                },
402            )))
403            .region(Region::from_static("us-west-2"))
404            .build();
405        let provider = AssumeRoleProvider::builder("myrole")
406            .configure(&sdk_config)
407            .session_length(Duration::from_secs(1234567))
408            .build_from_provider(provide_credentials_fn(|| async {
409                Ok(Credentials::for_tests())
410            }))
411            .await;
412        let _ = dbg!(provider.provide_credentials().await);
413        let req = request.expect_request();
414        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
415    }
416
417    /// Test that `build()` where no provider is passed still works
418    #[tokio::test]
419    async fn build_method_from_sdk_config() {
420        let _guard = capture_test_logs();
421        let (http_client, request) = capture_request(Some(
422            http::Response::builder()
423                .status(404)
424                .body(SdkBody::from(""))
425                .unwrap(),
426        ));
427        let conf = crate::defaults(BehaviorVersion::latest())
428            .env(Env::from_slice(&[
429                ("AWS_ACCESS_KEY_ID", "123-key"),
430                ("AWS_SECRET_ACCESS_KEY", "456"),
431                ("AWS_REGION", "us-west-17"),
432            ]))
433            .use_dual_stack(true)
434            .use_fips(true)
435            .time_source(StaticTimeSource::from_secs(1234567890))
436            .http_client(http_client)
437            .load()
438            .await;
439        let provider = AssumeRoleProvider::builder("role")
440            .configure(&conf)
441            .build()
442            .await;
443        let _ = dbg!(provider.provide_credentials().await);
444        let req = request.expect_request();
445        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
446        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
447        assert!(
448            auth_header.contains(expect),
449            "Expected header to contain {expect} but it was {auth_header}"
450        );
451        // ensure that FIPS & DualStack are also respected
452        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
453    }
454
455    fn create_test_http_client() -> StaticReplayClient {
456        StaticReplayClient::new(vec![
457            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
458            http::Response::builder().status(200).body(SdkBody::from(
459                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
460            )).unwrap()),
461            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
462            http::Response::builder().status(200).body(SdkBody::from(
463                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>TESTSECRET</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:33:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
464            )).unwrap()),
465        ])
466    }
467
468    #[tokio::test]
469    async fn provider_does_not_cache_credentials_by_default() {
470        let http_client = create_test_http_client();
471
472        let (testing_time_source, sleep) = instant_time_and_sleep(
473            UNIX_EPOCH + Duration::from_secs(1234567890 - 120), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
474        );
475
476        let sdk_config = SdkConfig::builder()
477            .sleep_impl(SharedAsyncSleep::new(sleep))
478            .time_source(testing_time_source.clone())
479            .http_client(http_client)
480            .behavior_version(crate::BehaviorVersion::latest())
481            .build();
482        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
483            Credentials::new(
484                "test",
485                "test",
486                None,
487                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
488                "test",
489            ),
490            Credentials::new(
491                "test",
492                "test",
493                None,
494                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
495                "test",
496            ),
497        ]));
498        let credentials_list_cloned = credentials_list.clone();
499        let provider = AssumeRoleProvider::builder("myrole")
500            .configure(&sdk_config)
501            .region(Region::new("us-east-1"))
502            .build_from_provider(provide_credentials_fn(move || {
503                let list = credentials_list.clone();
504                async move {
505                    let next = list.lock().unwrap().remove(0);
506                    Ok(next)
507                }
508            }))
509            .await;
510
511        let creds_first = provider
512            .provide_credentials()
513            .await
514            .expect("should return valid credentials");
515
516        // After time has been advanced by 120 seconds, the first credentials _could_ still be valid
517        // if `LazyCredentialsCache` were used, but the provider uses `NoCredentialsCache` by default
518        // so the first credentials will not be used.
519        testing_time_source.advance(Duration::from_secs(120));
520
521        let creds_second = provider
522            .provide_credentials()
523            .await
524            .expect("should return the second credentials");
525        assert_ne!(creds_first, creds_second);
526        assert!(credentials_list_cloned.lock().unwrap().is_empty());
527    }
528
529    #[tokio::test]
530    async fn credentials_feature() {
531        let http_client = create_test_http_client();
532
533        let (testing_time_source, sleep) = instant_time_and_sleep(
534            UNIX_EPOCH + Duration::from_secs(1234567890), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
535        );
536
537        let sdk_config = SdkConfig::builder()
538            .sleep_impl(SharedAsyncSleep::new(sleep))
539            .time_source(testing_time_source.clone())
540            .http_client(http_client)
541            .behavior_version(crate::BehaviorVersion::latest())
542            .build();
543        let credentials = Credentials::new(
544            "test",
545            "test",
546            None,
547            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
548            "test",
549        );
550        let provider = AssumeRoleProvider::builder("myrole")
551            .configure(&sdk_config)
552            .region(Region::new("us-east-1"))
553            .build_from_provider(credentials)
554            .await;
555
556        let creds = provider
557            .provide_credentials()
558            .await
559            .expect("should return valid credentials");
560
561        assert_eq!(
562            &vec![AwsCredentialFeature::CredentialsStsAssumeRole],
563            creds.get_property::<Vec<AwsCredentialFeature>>().unwrap()
564        )
565    }
566}