1use aws_credential_types::credential_feature::AwsCredentialFeature;
9use aws_credential_types::provider::{
10 self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
11};
12use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
13use aws_sdk_sts::operation::assume_role::AssumeRoleError;
14use aws_sdk_sts::types::PolicyDescriptorType;
15use aws_sdk_sts::Client as StsClient;
16use aws_smithy_runtime::client::identity::IdentityCache;
17use aws_smithy_runtime_api::client::result::SdkError;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use aws_types::region::Region;
20use aws_types::SdkConfig;
21use std::time::Duration;
22use tracing::Instrument;
23
/// Credentials provider that obtains credentials by calling the STS `AssumeRole` API.
///
/// Construct one with [`AssumeRoleProvider::builder`]. Every call to `provide_credentials`
/// sends a new `AssumeRole` request; this type does not cache the resulting credentials itself.
#[derive(Debug)]
pub struct AssumeRoleProvider {
    inner: Inner,
}
74
/// Internal state of [`AssumeRoleProvider`]: a fully populated `AssumeRole` request.
#[derive(Debug)]
struct Inner {
    // Request template (role ARN, session name, policies, duration, ...); it is cloned and
    // sent on every credentials fetch.
    fluent_builder: AssumeRoleFluentBuilder,
}
79
80impl AssumeRoleProvider {
81 pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
89 AssumeRoleProviderBuilder::new(role.into())
90 }
91}
92
/// Builder for [`AssumeRoleProvider`].
///
/// Every setting except the role ARN is optional; unset values are simply not sent to STS (or,
/// for the session name, are generated at build time).
#[derive(Debug)]
pub struct AssumeRoleProviderBuilder {
    // ARN of the role to assume (`RoleArn`).
    role_arn: String,
    // Optional `ExternalId` for the request.
    external_id: Option<String>,
    // Optional `RoleSessionName`; a default is generated at build time when unset.
    session_name: Option<String>,
    // Optional requested session duration, sent as whole seconds (`DurationSeconds`).
    session_length: Option<Duration>,
    // Optional inline session policy (`Policy`).
    policy: Option<String>,
    // Optional managed session policies (`PolicyArns`).
    policy_arns: Option<Vec<PolicyDescriptorType>>,
    // Region that replaces the SdkConfig region for the STS client.
    region_override: Option<Region>,
    // SdkConfig used to build the STS client; loaded via `crate::load_defaults` when unset.
    sdk_config: Option<SdkConfig>,
}
107
108impl AssumeRoleProviderBuilder {
109 pub fn new(role: impl Into<String>) -> Self {
117 Self {
118 role_arn: role.into(),
119 external_id: None,
120 session_name: None,
121 session_length: None,
122 policy: None,
123 policy_arns: None,
124 sdk_config: None,
125 region_override: None,
126 }
127 }
128
129 pub fn external_id(mut self, id: impl Into<String>) -> Self {
135 self.external_id = Some(id.into());
136 self
137 }
138
139 pub fn session_name(mut self, name: impl Into<String>) -> Self {
146 self.session_name = Some(name.into());
147 self
148 }
149
150 pub fn policy(mut self, policy: impl Into<String>) -> Self {
156 self.policy = Some(policy.into());
157 self
158 }
159
160 pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
166 self.policy_arns = Some(
167 policy_arns
168 .into_iter()
169 .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
170 .collect::<Vec<_>>(),
171 );
172 self
173 }
174
175 pub fn session_length(mut self, length: Duration) -> Self {
188 self.session_length = Some(length);
189 self
190 }
191
192 pub fn region(mut self, region: Region) -> Self {
197 self.region_override = Some(region);
198 self
199 }
200
201 pub fn configure(mut self, conf: &SdkConfig) -> Self {
219 self.sdk_config = Some(conf.clone());
220 self
221 }
222
223 pub async fn build(self) -> AssumeRoleProvider {
228 let mut conf = match self.sdk_config {
229 Some(conf) => conf,
230 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
231 };
232 conf = conf
234 .into_builder()
235 .identity_cache(IdentityCache::no_cache())
236 .build();
237
238 if let Some(region) = self.region_override {
240 conf = conf.into_builder().region(region).build()
241 }
242
243 let config = aws_sdk_sts::config::Builder::from(&conf);
244
245 let time_source = conf.time_source().expect("A time source must be provided.");
246
247 let session_name = self.session_name.unwrap_or_else(|| {
248 super::util::default_session_name("assume-role-provider", time_source.now())
249 });
250
251 let sts_client = StsClient::from_conf(config.build());
252 let fluent_builder = sts_client
253 .assume_role()
254 .set_role_arn(Some(self.role_arn))
255 .set_external_id(self.external_id)
256 .set_role_session_name(Some(session_name))
257 .set_policy(self.policy)
258 .set_policy_arns(self.policy_arns)
259 .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
260
261 AssumeRoleProvider {
262 inner: Inner { fluent_builder },
263 }
264 }
265
266 pub async fn build_from_provider(
268 mut self,
269 provider: impl ProvideCredentials + 'static,
270 ) -> AssumeRoleProvider {
271 let conf = match self.sdk_config {
272 Some(conf) => conf,
273 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
274 };
275 let conf = conf
276 .into_builder()
277 .credentials_provider(SharedCredentialsProvider::new(provider))
278 .build();
279 self.sdk_config = Some(conf);
280 self.build().await
281 }
282}
283
284impl Inner {
285 async fn credentials(&self) -> provider::Result {
286 tracing::debug!("retrieving assumed credentials");
287
288 let assumed = self.fluent_builder.clone().send().in_current_span().await;
289 let assumed = match assumed {
290 Ok(assumed) => {
291 tracing::debug!(
292 access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
293 "obtained assumed credentials"
294 );
295 super::util::into_credentials(
296 assumed.credentials,
297 assumed.assumed_role_user,
298 "AssumeRoleProvider",
299 )
300 }
301 Err(SdkError::ServiceError(ref context))
302 if matches!(
303 context.err(),
304 AssumeRoleError::RegionDisabledException(_)
305 | AssumeRoleError::MalformedPolicyDocumentException(_)
306 ) =>
307 {
308 Err(CredentialsError::invalid_configuration(
309 assumed.err().unwrap(),
310 ))
311 }
312 Err(SdkError::ServiceError(ref context)) => {
313 tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
314 Err(CredentialsError::provider_error(assumed.err().unwrap()))
315 }
316 Err(err) => Err(CredentialsError::provider_error(err)),
317 };
318
319 assumed.map(|mut creds| {
320 creds
321 .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
322 .push(AwsCredentialFeature::CredentialsStsAssumeRole);
323 creds
324 })
325 }
326}
327
328impl ProvideCredentials for AssumeRoleProvider {
329 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
330 where
331 Self: 'a,
332 {
333 future::ProvideCredentials::new(
334 self.inner
335 .credentials()
336 .instrument(tracing::debug_span!("assume_role")),
337 )
338 }
339}
340
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_feature::AwsCredentialFeature;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    // The explicit `region` override must win over the SdkConfig region, and the configured
    // session length must appear in the request body.
    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The result is discarded; only the captured outgoing request is checked.
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    // Without a region override, the SdkConfig region picks the STS endpoint. The SdkConfig's
    // own credentials provider panics if it is called, which proves that
    // `build_from_provider` replaced it.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    // `build()` (not `build_from_provider`) uses the SdkConfig's credentials, region and
    // endpoint settings. Here they come from env vars plus FIPS/dual-stack, and the test checks
    // both the SigV4 credential scope and the resolved endpoint. 1234567890 seconds after the
    // epoch is 2009-02-13, hence the `20090213` date in the expected scope.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    // Replays two successful AssumeRole responses in order. The first expires at
    // 2009-02-13T23:31:30Z with secret `secretkeycorrect`; the second expires at
    // 2009-02-13T23:33:30Z with secret `TESTSECRET`.
    fn create_test_http_client() -> StaticReplayClient {
        StaticReplayClient::new(vec![
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:31:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>TESTSECRET</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:33:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
        ])
    }

    // Two fetches separated by a 120s time advance must return different credentials. They
    // must also consume both queued base credentials, i.e. the base provider is called once
    // per AssumeRole request rather than being cached.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        // Base credentials handed out in order. The first expires 1s after 1234567890 and the
        // second 120s after it.
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // Move past the first base credentials' expiry before fetching again.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }

    // Assumed-role credentials carry exactly one feature marker:
    // `CredentialsStsAssumeRole`.
    #[tokio::test]
    async fn credentials_feature() {
        let http_client = create_test_http_client();

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890),
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let credentials = Credentials::new(
            "test",
            "test",
            None,
            Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
            "test",
        );
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(credentials)
            .await;

        let creds = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        assert_eq!(
            &vec![AwsCredentialFeature::CredentialsStsAssumeRole],
            creds.get_property::<Vec<AwsCredentialFeature>>().unwrap()
        )
    }
}