1use aws_credential_types::provider::{
9 self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
10};
11use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
12use aws_sdk_sts::operation::assume_role::AssumeRoleError;
13use aws_sdk_sts::types::PolicyDescriptorType;
14use aws_sdk_sts::Client as StsClient;
15use aws_smithy_runtime::client::identity::IdentityCache;
16use aws_smithy_runtime_api::client::result::SdkError;
17use aws_smithy_types::error::display::DisplayErrorContext;
18use aws_types::region::Region;
19use aws_types::SdkConfig;
20use std::time::Duration;
21use tracing::Instrument;
22
23#[derive(Debug)]
70pub struct AssumeRoleProvider {
71 inner: Inner,
72}
73
74#[derive(Debug)]
75struct Inner {
76 fluent_builder: AssumeRoleFluentBuilder,
77}
78
79impl AssumeRoleProvider {
80 pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
88 AssumeRoleProviderBuilder::new(role.into())
89 }
90}
91
92#[derive(Debug)]
96pub struct AssumeRoleProviderBuilder {
97 role_arn: String,
98 external_id: Option<String>,
99 session_name: Option<String>,
100 session_length: Option<Duration>,
101 policy: Option<String>,
102 policy_arns: Option<Vec<PolicyDescriptorType>>,
103 region_override: Option<Region>,
104 sdk_config: Option<SdkConfig>,
105}
106
107impl AssumeRoleProviderBuilder {
108 pub fn new(role: impl Into<String>) -> Self {
116 Self {
117 role_arn: role.into(),
118 external_id: None,
119 session_name: None,
120 session_length: None,
121 policy: None,
122 policy_arns: None,
123 sdk_config: None,
124 region_override: None,
125 }
126 }
127
128 pub fn external_id(mut self, id: impl Into<String>) -> Self {
134 self.external_id = Some(id.into());
135 self
136 }
137
138 pub fn session_name(mut self, name: impl Into<String>) -> Self {
145 self.session_name = Some(name.into());
146 self
147 }
148
149 pub fn policy(mut self, policy: impl Into<String>) -> Self {
155 self.policy = Some(policy.into());
156 self
157 }
158
159 pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
165 self.policy_arns = Some(
166 policy_arns
167 .into_iter()
168 .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
169 .collect::<Vec<_>>(),
170 );
171 self
172 }
173
174 pub fn session_length(mut self, length: Duration) -> Self {
187 self.session_length = Some(length);
188 self
189 }
190
191 pub fn region(mut self, region: Region) -> Self {
196 self.region_override = Some(region);
197 self
198 }
199
200 pub fn configure(mut self, conf: &SdkConfig) -> Self {
218 self.sdk_config = Some(conf.clone());
219 self
220 }
221
222 pub async fn build(self) -> AssumeRoleProvider {
227 let mut conf = match self.sdk_config {
228 Some(conf) => conf,
229 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
230 };
231 conf = conf
233 .into_builder()
234 .identity_cache(IdentityCache::no_cache())
235 .build();
236
237 if let Some(region) = self.region_override {
239 conf = conf.into_builder().region(region).build()
240 }
241
242 let config = aws_sdk_sts::config::Builder::from(&conf);
243
244 let time_source = conf.time_source().expect("A time source must be provided.");
245
246 let session_name = self.session_name.unwrap_or_else(|| {
247 super::util::default_session_name("assume-role-provider", time_source.now())
248 });
249
250 let sts_client = StsClient::from_conf(config.build());
251 let fluent_builder = sts_client
252 .assume_role()
253 .set_role_arn(Some(self.role_arn))
254 .set_external_id(self.external_id)
255 .set_role_session_name(Some(session_name))
256 .set_policy(self.policy)
257 .set_policy_arns(self.policy_arns)
258 .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
259
260 AssumeRoleProvider {
261 inner: Inner { fluent_builder },
262 }
263 }
264
265 pub async fn build_from_provider(
267 mut self,
268 provider: impl ProvideCredentials + 'static,
269 ) -> AssumeRoleProvider {
270 let conf = match self.sdk_config {
271 Some(conf) => conf,
272 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
273 };
274 let conf = conf
275 .into_builder()
276 .credentials_provider(SharedCredentialsProvider::new(provider))
277 .build();
278 self.sdk_config = Some(conf);
279 self.build().await
280 }
281}
282
283impl Inner {
284 async fn credentials(&self) -> provider::Result {
285 tracing::debug!("retrieving assumed credentials");
286
287 let assumed = self.fluent_builder.clone().send().in_current_span().await;
288 match assumed {
289 Ok(assumed) => {
290 tracing::debug!(
291 access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
292 "obtained assumed credentials"
293 );
294 super::util::into_credentials(
295 assumed.credentials,
296 assumed.assumed_role_user,
297 "AssumeRoleProvider",
298 )
299 }
300 Err(SdkError::ServiceError(ref context))
301 if matches!(
302 context.err(),
303 AssumeRoleError::RegionDisabledException(_)
304 | AssumeRoleError::MalformedPolicyDocumentException(_)
305 ) =>
306 {
307 Err(CredentialsError::invalid_configuration(
308 assumed.err().unwrap(),
309 ))
310 }
311 Err(SdkError::ServiceError(ref context)) => {
312 tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
313 Err(CredentialsError::provider_error(assumed.err().unwrap()))
314 }
315 Err(err) => Err(CredentialsError::provider_error(err)),
316 }
317 }
318}
319
320impl ProvideCredentials for AssumeRoleProvider {
321 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
322 where
323 Self: 'a,
324 {
325 future::ProvideCredentials::new(
326 self.inner
327 .credentials()
328 .instrument(tracing::debug_span!("assume_role")),
329 )
330 }
331}
332
// Tests for `AssumeRoleProvider`. Each test drives the provider against a
// test HTTP client and inspects either the captured outgoing request or the
// returned credentials.
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    // `session_length` must appear in the request body, and `region()` must win
    // over the region of the supplied `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        // `request` gives access to the request sent through `http_client`.
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The result is only printed; the assertions inspect the captured request.
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    // Without `region()`, the STS endpoint uses the region of the supplied `SdkConfig`.
    // The config's own credentials provider panics if called, so this also checks
    // that `build_from_provider` replaces it.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    // `build()` (not `build_from_provider`) signs the STS call with the config's
    // credentials and honours the FIPS / dual-stack endpoint settings. The canned
    // 404 response makes the provider fail, but only the request is asserted on.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        // 20090213 is the UTC date of the static time source (1234567890).
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    // Two consecutive `provide_credentials` calls must each hit STS. The replayed
    // responses differ (SecretAccessKey), so equal results would indicate caching.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let http_client = StaticReplayClient::new(vec![
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:31:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>TESTSECRET</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:33:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
        ]);

        // Controllable clock that starts 120s before the first response's Expiration.
        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        // Source credentials handed out one per call by the provider below. If
        // it were called a third time, `remove(0)` on the empty Vec would panic.
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // Advance to 1234567890 (2009-02-13T23:31:30Z), the first response's Expiration.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        // Both source credentials were consumed: one source fetch per STS call.
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }
}