1mod cache;
9mod dpop;
11mod token;
12
13use crate::login::cache::{load_cached_token, save_cached_token};
14use crate::login::token::{LoginToken, LoginTokenError};
15use crate::provider_config::ProviderConfig;
16use aws_credential_types::credential_feature::AwsCredentialFeature;
17use aws_credential_types::provider;
18use aws_credential_types::provider::future;
19use aws_credential_types::provider::ProvideCredentials;
20use aws_sdk_signin::config::Builder as SignInClientConfigBuilder;
21use aws_sdk_signin::operation::create_o_auth2_token::CreateOAuth2TokenError;
22use aws_sdk_signin::types::{CreateOAuth2TokenRequestBody, OAuth2ErrorCode};
23use aws_sdk_signin::Client as SignInClient;
24use aws_smithy_async::time::SharedTimeSource;
25use aws_smithy_runtime::expiring_cache::ExpiringCache;
26use aws_types::os_shim_internal::{Env, Fs};
27use aws_types::SdkConfig;
28use std::sync::Arc;
29use std::sync::Mutex;
30use std::time::Duration;
31use std::time::SystemTime;
32
/// How far ahead of a token's expiry it is considered "expiring soon" and refreshed
/// (five minutes). Also used as the buffer time of the in-memory token cache.
const REFRESH_BUFFER_TIME: Duration = Duration::from_secs(300);
/// Minimum spacing between two consecutive refresh attempts.
const MIN_TIME_BETWEEN_REFRESH: Duration = Duration::from_secs(30);
35pub(super) const PROVIDER_NAME: &str = "Login";
36
37#[derive(Debug)]
43pub struct LoginCredentialsProvider {
44 inner: Arc<Inner>,
45 token_cache: ExpiringCache<LoginToken, LoginTokenError>,
46}
47
48#[derive(Debug)]
49struct Inner {
50 fs: Fs,
51 env: Env,
52 session_arn: String,
53 enabled_from_profile: bool,
54 sdk_config: SdkConfig,
55 time_source: SharedTimeSource,
56 last_refresh_attempt: Mutex<Option<SystemTime>>,
57}
58
59impl LoginCredentialsProvider {
60 pub fn builder(session_arn: impl Into<String>) -> Builder {
68 Builder {
69 session_arn: session_arn.into(),
70 provider_config: None,
71 enabled_from_profile: false,
72 }
73 }
74
75 async fn resolve_token(&self) -> Result<LoginToken, LoginTokenError> {
76 let token_cache = self.token_cache.clone();
77 if let Some(token) = token_cache
78 .yield_or_clear_if_expired(self.inner.time_source.now())
79 .await
80 {
81 tracing::debug!("using cached Login token");
82 return Ok(token);
83 }
84
85 let inner = self.inner.clone();
86 let token = token_cache
87 .get_or_load(|| async move {
88 tracing::debug!("expiring cache asked for an updated Login token");
89 let mut token =
90 load_cached_token(&inner.env, &inner.fs, &inner.session_arn).await?;
91
92 tracing::debug!("loaded cached Login token");
93
94 let now = inner.time_source.now();
95 let expired = token.expires_at() <= now;
96 let expires_soon = token.expires_at() - REFRESH_BUFFER_TIME <= now;
97 let last_refresh = *inner.last_refresh_attempt.lock().unwrap();
98 let min_time_passed = last_refresh
99 .map(|lr| {
100 now.duration_since(lr).expect("last_refresh is in the past")
101 >= MIN_TIME_BETWEEN_REFRESH
102 })
103 .unwrap_or(true);
104
105 let refreshable = min_time_passed;
106
107 tracing::debug!(
108 expired = ?expired,
109 expires_soon = ?expires_soon,
110 min_time_passed = ?min_time_passed,
111 refreshable = ?refreshable,
112 will_refresh = ?(expires_soon && refreshable),
113 "cached Login token refresh decision"
114 );
115
116 if expired && !refreshable {
118 tracing::debug!("cached Login token is expired and cannot be refreshed");
119 return Err(LoginTokenError::ExpiredToken);
120 }
121
122 if expires_soon && refreshable {
124 tracing::debug!("attempting to refresh Login token");
125 let refreshed_token = Self::refresh_cached_token(&inner, &token, now).await?;
126 token = refreshed_token;
127 *inner.last_refresh_attempt.lock().unwrap() = Some(now);
128 }
129
130 let expires_at = token.expires_at();
131 Ok((token, expires_at))
132 })
133 .await?;
134
135 Ok(token)
136 }
137
138 async fn refresh_cached_token(
139 inner: &Inner,
140 cached_token: &LoginToken,
141 now: SystemTime,
142 ) -> Result<LoginToken, LoginTokenError> {
143 let dpop_auth_scheme = dpop::DPoPAuthScheme::new(&cached_token.dpop_key)?;
144 let client_config = SignInClientConfigBuilder::from(&inner.sdk_config)
145 .auth_scheme_resolver(dpop::DPoPAuthSchemeOptionResolver)
146 .push_auth_scheme(dpop_auth_scheme)
147 .build();
148
149 let client = SignInClient::from_conf(client_config);
150
151 let resp = client
152 .create_o_auth2_token()
153 .token_input(
154 CreateOAuth2TokenRequestBody::builder()
155 .client_id(&cached_token.client_id)
156 .grant_type("refresh_token")
157 .refresh_token(cached_token.refresh_token.as_str())
158 .build()
159 .expect("valid CreateOAuth2TokenRequestBody"),
160 )
161 .send()
162 .await
163 .map_err(|err| {
164 let service_err = err.into_service_error();
165 let message = match &service_err {
166 CreateOAuth2TokenError::AccessDeniedException(e) => match e.error {
167 OAuth2ErrorCode::InsufficientPermissions => Some("Unable to refresh credentials due to insufficient permissions. You may be missing permission for the 'CreateOAuth2Token' action.".to_string()),
168 OAuth2ErrorCode::TokenExpired => Some("Your session has expired. Please reauthenticate.".to_string()),
169 OAuth2ErrorCode::UserCredentialsChanged => Some("Unable to refresh credentials because of a change in your password. Please reauthenticate with your new password.".to_string()),
170 _ => None,
171 }
172 _ => None,
173 };
174
175 LoginTokenError::RefreshFailed {
176 message,
177 source: service_err.into(),
178 }
179 })?;
180
181 let token_output = resp.token_output.expect("valid token response");
182 let new_token = LoginToken::from_refresh(cached_token, token_output, now);
183
184 match save_cached_token(&inner.env, &inner.fs, &inner.session_arn, &new_token).await {
185 Ok(_) => {}
186 Err(e) => tracing::warn!("failed to save refreshed Login token: {e}"),
187 }
188 Ok(new_token)
189 }
190
191 async fn credentials(&self) -> provider::Result {
192 let token = self.resolve_token().await?;
193
194 let feat = match self.inner.enabled_from_profile {
195 true => AwsCredentialFeature::CredentialsProfileLogin,
196 false => AwsCredentialFeature::CredentialsProfile,
197 };
198
199 let mut creds = token.access_token;
200 creds
201 .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
202 .push(feat);
203 Ok(creds)
204 }
205}
206
207impl ProvideCredentials for LoginCredentialsProvider {
208 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
209 where
210 Self: 'a,
211 {
212 future::ProvideCredentials::new(self.credentials())
213 }
214}
215
216#[derive(Debug)]
218pub struct Builder {
219 session_arn: String,
220 provider_config: Option<ProviderConfig>,
221 enabled_from_profile: bool,
222}
223
224impl Builder {
225 pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
227 self.provider_config = Some(provider_config.clone());
228 self
229 }
230
231 pub(crate) fn enabled_from_profile(mut self, enabled: bool) -> Self {
234 self.enabled_from_profile = enabled;
235 self
236 }
237
238 pub fn build(self) -> LoginCredentialsProvider {
240 let provider_config = self.provider_config.unwrap_or_default();
241 let fs = provider_config.fs();
242 let env = provider_config.env();
243 let inner = Arc::new(Inner {
244 fs,
245 env,
246 session_arn: self.session_arn,
247 enabled_from_profile: self.enabled_from_profile,
248 sdk_config: provider_config.client_config(),
249 time_source: provider_config.time_source(),
250 last_refresh_attempt: Mutex::new(None),
251 });
252
253 LoginCredentialsProvider {
254 inner,
255 token_cache: ExpiringCache::new(REFRESH_BUFFER_TIME),
256 }
257 }
258}
259
#[cfg(test)]
mod test {
    use super::*;
    use crate::provider_config::ProviderConfig;
    use aws_credential_types::provider::ProvideCredentials;
    use aws_sdk_signin::config::RuntimeComponents;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::time::{SharedTimeSource, StaticTimeSource};
    use aws_smithy_runtime_api::client::{
        http::{
            HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings,
            SharedHttpConnector,
        },
        orchestrator::{HttpRequest, HttpResponse},
    };
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::{Env, Fs};
    use aws_types::region::Region;
    use serde::Deserialize;
    use std::collections::HashMap;
    use std::error::Error;
    use std::time::{Duration, UNIX_EPOCH};

    /// One scenario from `test-data/login-provider-test-cases.json`.
    #[derive(Deserialize, Debug)]
    #[serde(rename_all = "camelCase")]
    struct LoginTestCase {
        /// Description printed before the scenario runs.
        documentation: String,
        /// Contents written to `/home/user/.aws/config`.
        config_contents: String,
        /// Login cache files, keyed by file name under `/home/user/.aws/login/cache/`.
        cache_contents: HashMap<String, serde_json::Value>,
        /// Canned HTTP responses. When empty, a client that forbids traffic is used.
        #[serde(default)]
        mock_api_calls: Vec<MockApiCall>,
        /// Expectations checked against a single credentials resolution.
        outcomes: Vec<Outcome>,
    }

    /// A canned HTTP response: either an error status or a successful token payload.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct MockApiCall {
        /// Successful response body, used when `response_code` is not set.
        #[serde(default)]
        response: Option<MockResponse>,
        /// When set, takes precedence: respond with this status and an error body.
        #[serde(default)]
        response_code: Option<u16>,
    }

    /// Successful token response as described in the test data.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct MockResponse {
        token_output: TokenOutput,
    }

    /// Token payload serialized into the mock HTTP response body.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct TokenOutput {
        access_token: AccessToken,
        refresh_token: String,
        expires_in: u64,
    }

    /// Credentials portion of the mock token payload.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct AccessToken {
        access_key_id: String,
        secret_access_key: String,
        session_token: String,
    }

    /// An expectation about the resolution result, discriminated by the `result` field.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "result")]
    enum Outcome {
        /// Resolution succeeds with exactly these credential values.
        #[serde(rename = "credentials")]
        Credentials {
            #[serde(rename = "accessKeyId")]
            access_key_id: String,
            #[serde(rename = "secretAccessKey")]
            secret_access_key: String,
            #[serde(rename = "sessionToken")]
            session_token: String,
            #[serde(rename = "accountId")]
            account_id: String,
            // Present in the test data but not asserted on, hence never read.
            #[serde(default, rename = "expiresAt")]
            #[allow(dead_code)]
            expires_at: Option<String>,
        },
        /// Resolution fails.
        #[serde(rename = "error")]
        Error,
        /// After resolution, these cache files exist. Their `accessToken` and
        /// `refreshToken` fields match the expected values.
        #[serde(rename = "cacheContents")]
        CacheContents(HashMap<String, serde_json::Value>),
    }

    impl LoginTestCase {
        /// Runs one scenario.
        ///
        /// Seeds an in-memory filesystem from the scenario and builds a provider over it
        /// with a fixed clock. Credentials are resolved once and every expected outcome is
        /// asserted.
        async fn check(&self) {
            let session_arn = "arn:aws:sts::012345678910:assumed-role/Admin/admin";

            // Fixed clock so that expiry and refresh decisions are deterministic.
            let now = UNIX_EPOCH + Duration::from_secs(1763510400);
            let time_source = SharedTimeSource::new(StaticTimeSource::new(now));

            let mut fs_map = HashMap::new();
            fs_map.insert(
                "/home/user/.aws/config".to_string(),
                self.config_contents.as_bytes().to_vec(),
            );
            for (filename, contents) in &self.cache_contents {
                let path = format!("/home/user/.aws/login/cache/{}", filename);
                let mut contents = contents.clone();
                // The test data may omit `tokenType`. Default it to `aws_sigv4`.
                if !contents.as_object().unwrap().contains_key("tokenType") {
                    contents.as_object_mut().unwrap().insert(
                        "tokenType".to_string(),
                        serde_json::Value::String("aws_sigv4".to_string()),
                    );
                }
                let json = serde_json::to_string(&contents).expect("valid json");
                fs_map.insert(path, json.into_bytes());
            }
            let fs = Fs::from_map(fs_map);

            let env = Env::from_slice(&[("HOME", "/home/user")]);

            let http_client = if self.mock_api_calls.is_empty() {
                // No API calls expected: any HTTP traffic is a test failure.
                crate::test_case::no_traffic_client()
            } else {
                aws_smithy_runtime_api::client::http::SharedHttpClient::new(TestHttpClient::new(
                    &self.mock_api_calls,
                ))
            };

            let provider_config = ProviderConfig::empty()
                .with_env(env.clone())
                .with_fs(fs.clone())
                .with_http_client(http_client)
                .with_region(Some(Region::from_static("us-east-2")))
                .with_sleep_impl(TokioSleep::new())
                .with_time_source(time_source);

            let provider = LoginCredentialsProvider::builder(session_arn)
                .configure(&provider_config)
                .build();

            // No `dbg!` here: `expect` / `expect_err` below already print the result
            // when an expectation fails.
            let result = provider.provide_credentials().await;

            for outcome in &self.outcomes {
                match outcome {
                    Outcome::Credentials {
                        access_key_id,
                        secret_access_key,
                        session_token,
                        account_id,
                        expires_at: _,
                    } => {
                        let creds = result.as_ref().expect("credentials should succeed");
                        assert_eq!(access_key_id, creds.access_key_id());
                        assert_eq!(secret_access_key, creds.secret_access_key());
                        assert_eq!(session_token, creds.session_token().unwrap());
                        assert_eq!(account_id, creds.account_id().unwrap().as_str());
                    }
                    Outcome::Error => {
                        result.as_ref().expect_err("should fail");
                    }
                    Outcome::CacheContents(expected_cache) => {
                        for (filename, expected) in expected_cache {
                            // Only the token fields are compared. Other fields may legitimately
                            // differ after a refresh is written back.
                            let path = format!("/home/user/.aws/login/cache/{}", filename);
                            let actual = fs.read_to_end(&path).await.expect("cache file exists");
                            let actual: serde_json::Value =
                                serde_json::from_slice(&actual).expect("valid json");
                            assert_eq!(
                                expected.get("accessToken"),
                                actual.get("accessToken"),
                                "accessToken mismatch for {}",
                                filename
                            );
                            assert_eq!(
                                expected.get("refreshToken"),
                                actual.get("refreshToken"),
                                "refreshToken mismatch for {}",
                                filename
                            );
                        }
                    }
                }
            }
        }
    }

    /// HTTP client that hands out a single shared [`TestHttpConnector`].
    #[derive(Debug, Clone)]
    struct TestHttpClient {
        inner: SharedHttpConnector,
    }

    impl TestHttpClient {
        /// Creates a client whose connector serves responses from `mock_calls`.
        fn new(mock_calls: &[MockApiCall]) -> Self {
            Self {
                inner: SharedHttpConnector::new(TestHttpConnector {
                    mock_calls: mock_calls.to_vec(),
                }),
            }
        }
    }

    impl HttpClient for TestHttpClient {
        fn http_connector(
            &self,
            _settings: &HttpConnectorSettings,
            _components: &RuntimeComponents,
        ) -> SharedHttpConnector {
            self.inner.clone()
        }
    }

    /// Connector that answers every request from the FIRST mock call only.
    ///
    /// Later entries in `mock_calls` are never served. When there is no usable mock, it
    /// answers with a 500.
    #[derive(Debug, Clone)]
    struct TestHttpConnector {
        mock_calls: Vec<MockApiCall>,
    }

    impl HttpConnector for TestHttpConnector {
        fn call(&self, _request: HttpRequest) -> HttpConnectorFuture {
            if let Some(mock) = self.mock_calls.first() {
                if let Some(code) = mock.response_code {
                    return HttpConnectorFuture::ready(Ok(HttpResponse::new(
                        code.try_into().unwrap(),
                        SdkBody::from("{\"error\":\"refresh_failed\"}"),
                    )));
                }
                if let Some(resp) = &mock.response {
                    let body = format!(
                        r#"{{
                    "accessToken": {{
                        "accessKeyId": "{}",
                        "secretAccessKey": "{}",
                        "sessionToken": "{}"
                    }},
                    "expiresIn": {},
                    "refreshToken": "{}"
                }}"#,
                        resp.token_output.access_token.access_key_id,
                        resp.token_output.access_token.secret_access_key,
                        resp.token_output.access_token.session_token,
                        resp.token_output.expires_in,
                        resp.token_output.refresh_token
                    );
                    return HttpConnectorFuture::ready(Ok(HttpResponse::new(
                        200.try_into().unwrap(),
                        SdkBody::from(body),
                    )));
                }
            }
            HttpConnectorFuture::ready(Ok(HttpResponse::new(
                500.try_into().unwrap(),
                SdkBody::from("{\"error\":\"no_mock\"}"),
            )))
        }
    }

    /// Runs every scenario in `test-data/login-provider-test-cases.json`, in order.
    #[tokio::test]
    async fn run_login_tests() -> Result<(), Box<dyn Error>> {
        let test_cases = std::fs::read_to_string("test-data/login-provider-test-cases.json")?;
        let test_cases: Vec<LoginTestCase> = serde_json::from_str(&test_cases)?;

        for (idx, test) in test_cases.iter().enumerate() {
            println!("Running test {}: {}", idx, test.documentation);
            test.check().await;
        }
        Ok(())
    }
}