1use aws_credential_types::credential_feature::AwsCredentialFeature;
9use aws_credential_types::provider::{
10 self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
11};
12use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
13use aws_sdk_sts::operation::assume_role::AssumeRoleError;
14use aws_sdk_sts::types::{PolicyDescriptorType, Tag};
15use aws_sdk_sts::Client as StsClient;
16use aws_smithy_runtime::client::identity::IdentityCache;
17use aws_smithy_runtime_api::client::result::SdkError;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use aws_types::region::Region;
20use aws_types::SdkConfig;
21use std::time::Duration;
22use tracing::Instrument;
23
24#[derive(Debug)]
71pub struct AssumeRoleProvider {
72 inner: Inner,
73}
74
75#[derive(Debug)]
76struct Inner {
77 fluent_builder: AssumeRoleFluentBuilder,
78}
79
80impl AssumeRoleProvider {
81 pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
89 AssumeRoleProviderBuilder::new(role.into())
90 }
91}
92
93#[derive(Debug)]
97pub struct AssumeRoleProviderBuilder {
98 role_arn: String,
99 external_id: Option<String>,
100 session_name: Option<String>,
101 session_length: Option<Duration>,
102 policy: Option<String>,
103 policy_arns: Option<Vec<PolicyDescriptorType>>,
104 region_override: Option<Region>,
105 sdk_config: Option<SdkConfig>,
106 tags: Option<Vec<Tag>>,
107}
108
109impl AssumeRoleProviderBuilder {
110 pub fn new(role: impl Into<String>) -> Self {
118 Self {
119 role_arn: role.into(),
120 external_id: None,
121 session_name: None,
122 session_length: None,
123 policy: None,
124 policy_arns: None,
125 sdk_config: None,
126 region_override: None,
127 tags: None,
128 }
129 }
130
131 pub fn external_id(mut self, id: impl Into<String>) -> Self {
137 self.external_id = Some(id.into());
138 self
139 }
140
141 pub fn session_name(mut self, name: impl Into<String>) -> Self {
148 self.session_name = Some(name.into());
149 self
150 }
151
152 pub fn policy(mut self, policy: impl Into<String>) -> Self {
158 self.policy = Some(policy.into());
159 self
160 }
161
162 pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
168 self.policy_arns = Some(
169 policy_arns
170 .into_iter()
171 .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
172 .collect::<Vec<_>>(),
173 );
174 self
175 }
176
177 pub fn session_length(mut self, length: Duration) -> Self {
190 self.session_length = Some(length);
191 self
192 }
193
194 pub fn region(mut self, region: Region) -> Self {
199 self.region_override = Some(region);
200 self
201 }
202
203 pub fn tags<K, V>(mut self, tags: impl IntoIterator<Item = (K, V)>) -> Self
208 where
209 K: Into<String>,
210 V: Into<String>,
211 {
212 self.tags = Some(
213 tags.into_iter()
214 .map(|(k, v)| {
217 Tag::builder()
218 .key(k)
219 .value(v)
220 .build()
221 .expect("this is unreachable: both k and v are set")
222 })
223 .collect::<Vec<_>>(),
224 );
225 self
226 }
227
228 pub fn configure(mut self, conf: &SdkConfig) -> Self {
246 self.sdk_config = Some(conf.clone());
247 self
248 }
249
250 pub async fn build(self) -> AssumeRoleProvider {
255 let mut conf = match self.sdk_config {
256 Some(conf) => conf,
257 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
258 };
259 conf = conf
261 .into_builder()
262 .identity_cache(IdentityCache::no_cache())
263 .build();
264
265 if let Some(region) = self.region_override {
267 conf = conf.into_builder().region(region).build()
268 }
269
270 let config = aws_sdk_sts::config::Builder::from(&conf);
271
272 let time_source = conf.time_source().expect("A time source must be provided.");
273
274 let session_name = self.session_name.unwrap_or_else(|| {
275 super::util::default_session_name("assume-role-provider", time_source.now())
276 });
277
278 let sts_client = StsClient::from_conf(config.build());
279 let fluent_builder = sts_client
280 .assume_role()
281 .set_role_arn(Some(self.role_arn))
282 .set_external_id(self.external_id)
283 .set_role_session_name(Some(session_name))
284 .set_policy(self.policy)
285 .set_policy_arns(self.policy_arns)
286 .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32))
287 .set_tags(self.tags);
288
289 AssumeRoleProvider {
290 inner: Inner { fluent_builder },
291 }
292 }
293
294 pub async fn build_from_provider(
296 mut self,
297 provider: impl ProvideCredentials + 'static,
298 ) -> AssumeRoleProvider {
299 let conf = match self.sdk_config {
300 Some(conf) => conf,
301 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
302 };
303 let conf = conf
304 .into_builder()
305 .credentials_provider(SharedCredentialsProvider::new(provider))
306 .build();
307 self.sdk_config = Some(conf);
308 self.build().await
309 }
310}
311
312impl Inner {
313 async fn credentials(&self) -> provider::Result {
314 tracing::debug!("retrieving assumed credentials");
315
316 let assumed = self.fluent_builder.clone().send().in_current_span().await;
317 let assumed = match assumed {
318 Ok(assumed) => {
319 tracing::debug!(
320 access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
321 "obtained assumed credentials"
322 );
323 super::util::into_credentials(
324 assumed.credentials,
325 assumed.assumed_role_user,
326 "AssumeRoleProvider",
327 )
328 }
329 Err(SdkError::ServiceError(ref context))
330 if matches!(
331 context.err(),
332 AssumeRoleError::RegionDisabledException(_)
333 | AssumeRoleError::MalformedPolicyDocumentException(_)
334 ) =>
335 {
336 Err(CredentialsError::invalid_configuration(
337 assumed.err().unwrap(),
338 ))
339 }
340 Err(SdkError::ServiceError(ref context)) => {
341 tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
342 Err(CredentialsError::provider_error(assumed.err().unwrap()))
343 }
344 Err(err) => Err(CredentialsError::provider_error(err)),
345 };
346
347 assumed.map(|mut creds| {
348 creds
349 .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
350 .push(AwsCredentialFeature::CredentialsStsAssumeRole);
351 creds
352 })
353 }
354}
355
356impl ProvideCredentials for AssumeRoleProvider {
357 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
358 where
359 Self: 'a,
360 {
361 future::ProvideCredentials::new(
362 self.inner
363 .credentials()
364 .instrument(tracing::debug_span!("assume_role")),
365 )
366 }
367}
368
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_feature::AwsCredentialFeature;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    /// The configured session length must show up in the request body, and the builder's
    /// region must win over the region of the `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        let (client, captured) = capture_request(None);
        let frozen_clock =
            StaticTimeSource::new(UNIX_EPOCH + Duration::from_secs(1234567890 - 120));
        let base_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(frozen_clock)
            .http_client(client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let source = provide_credentials_fn(|| async { Ok(Credentials::for_tests()) });
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(source)
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let sent = captured.expect_request();
        let body = std::str::from_utf8(sent.body().bytes().unwrap()).unwrap();
        assert!(body.contains("1234567"), "{}", body);
        assert_eq!(sent.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    /// Without a region override the region of the `SdkConfig` is used, and the credentials
    /// provider of the `SdkConfig` is replaced by the one given to `build_from_provider`.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (client, captured) = capture_request(None);
        let frozen_clock =
            StaticTimeSource::new(UNIX_EPOCH + Duration::from_secs(1234567890 - 120));
        // This provider panics when invoked, so the test fails if it is ever consulted.
        let must_not_be_called = SharedCredentialsProvider::new(provide_credentials_fn(
            || async {
                panic!("don't call me — will be overridden");
            },
        ));
        let base_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(frozen_clock)
            .http_client(client)
            .credentials_provider(must_not_be_called)
            .region(Region::from_static("us-west-2"))
            .build();
        let source = provide_credentials_fn(|| async { Ok(Credentials::for_tests()) });
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(source)
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let sent = captured.expect_request();
        assert_eq!(sent.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// Building via `build()` from a loaded configuration: the request must be signed with
    /// the credentials and region from the environment and target the FIPS dual-stack
    /// endpoint.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        // The canned response is an empty 404; only the outgoing request is inspected.
        let canned_response = http::Response::builder()
            .status(404)
            .body(SdkBody::from(""))
            .unwrap();
        let (client, captured) = capture_request(Some(canned_response));
        let environment = Env::from_slice(&[
            ("AWS_ACCESS_KEY_ID", "123-key"),
            ("AWS_SECRET_ACCESS_KEY", "456"),
            ("AWS_REGION", "us-west-17"),
        ]);
        let loaded = crate::defaults(BehaviorVersion::latest())
            .env(environment)
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&loaded)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let sent = captured.expect_request();
        let auth_header = sent.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", sent.uri())
    }

    /// Replay client that answers two consecutive requests with two `AssumeRole` responses
    /// differing in secret access key, expiration and request ID.
    fn create_test_http_client() -> StaticReplayClient {
        let first_response = http::Response::builder()
            .status(200)
            .body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:31:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            ))
            .unwrap();
        let second_response = http::Response::builder()
            .status(200)
            .body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>TESTSECRET</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:33:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            ))
            .unwrap();
        let events = vec![
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                first_response,
            ),
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                second_response,
            ),
        ];
        StaticReplayClient::new(events)
    }

    /// Two calls to `provide_credentials` must yield different credentials and consume both
    /// entries of the source credentials list, i.e. nothing is served from a cache.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let replay_client = create_test_http_client();

        let start = UNIX_EPOCH + Duration::from_secs(1234567890 - 120);
        let (clock, sleep) = instant_time_and_sleep(start);

        let base_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(clock.clone())
            .http_client(replay_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();

        // Source credentials handed out front-to-back, one per invocation of the source
        // provider.
        let early_expiry = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
            "test",
        );
        let late_expiry = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
            "test",
        );
        let pending = std::sync::Arc::new(std::sync::Mutex::new(vec![early_expiry, late_expiry]));
        let pending_observer = pending.clone();

        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let queue = pending.clone();
                async move {
                    let head = queue.lock().unwrap().remove(0);
                    Ok(head)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        clock.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        assert!(pending_observer.lock().unwrap().is_empty());
    }

    /// Credentials returned by the provider must carry exactly the
    /// `CredentialsStsAssumeRole` feature.
    #[tokio::test]
    async fn credentials_feature() {
        let replay_client = create_test_http_client();

        let start = UNIX_EPOCH + Duration::from_secs(1234567890);
        let (clock, sleep) = instant_time_and_sleep(start);

        let base_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(clock.clone())
            .http_client(replay_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let source_credentials = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
            "test",
        );
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(source_credentials)
            .await;

        let creds = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        let features = creds.get_property::<Vec<AwsCredentialFeature>>().unwrap();
        assert_eq!(
            &vec![AwsCredentialFeature::CredentialsStsAssumeRole],
            features
        )
    }
}