aws_config/
login.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Credentials from an AWS Console session vended by AWS Sign-In.
7
8mod cache;
9/// Utils related to [RFC 9449: OAuth 2.0 Demonstrating Proof of Possession (DPoP)](https://datatracker.ietf.org/doc/html/rfc9449)
10mod dpop;
11mod token;
12
13use crate::login::cache::{load_cached_token, save_cached_token};
14use crate::login::token::{LoginToken, LoginTokenError};
15use crate::provider_config::ProviderConfig;
16use aws_credential_types::credential_feature::AwsCredentialFeature;
17use aws_credential_types::provider;
18use aws_credential_types::provider::future;
19use aws_credential_types::provider::ProvideCredentials;
20use aws_sdk_signin::config::Builder as SignInClientConfigBuilder;
21use aws_sdk_signin::operation::create_o_auth2_token::CreateOAuth2TokenError;
22use aws_sdk_signin::types::{CreateOAuth2TokenRequestBody, OAuth2ErrorCode};
23use aws_sdk_signin::Client as SignInClient;
24use aws_smithy_async::time::SharedTimeSource;
25use aws_smithy_runtime::expiring_cache::ExpiringCache;
26use aws_types::os_shim_internal::{Env, Fs};
27use aws_types::SdkConfig;
28use std::sync::Arc;
29use std::sync::Mutex;
30use std::time::Duration;
31use std::time::SystemTime;
32
/// How close to its expiration (five minutes) a Login token may get before it is refreshed;
/// also used as the expiry buffer of the in-memory token cache.
const REFRESH_BUFFER_TIME: Duration = Duration::from_secs(300);
/// Minimum spacing between two consecutive token refreshes.
const MIN_TIME_BETWEEN_REFRESH: Duration = Duration::from_secs(30);
/// Name of this credentials provider.
// NOTE(review): not referenced in this file; presumably used by the parent module when
// naming/attributing credentials that come from this provider — confirm.
pub(super) const PROVIDER_NAME: &str = "Login";
36
/// AWS credentials provider vended by AWS Sign-In. This provider allows users to acquire and refresh
/// AWS credentials that correspond to an AWS Console session.
///
/// Construct one with [`LoginCredentialsProvider::builder`].
///
/// See the [SDK developer guide](https://docs.aws.amazon.com/sdkref/latest/guide/access-login.html)
/// for more information on getting started with console sessions and the AWS CLI.
#[derive(Debug)]
pub struct LoginCredentialsProvider {
    /// Shared state: filesystem/environment handles, client config, clock and refresh bookkeeping.
    inner: Arc<Inner>,
    /// In-memory cache of the most recently loaded (or refreshed) Login token.
    token_cache: ExpiringCache<LoginToken, LoginTokenError>,
}
47
/// State shared (via `Arc`) between a [`LoginCredentialsProvider`] and its token-loading futures.
#[derive(Debug)]
struct Inner {
    /// Filesystem passed to the on-disk Login token cache helpers.
    fs: Fs,
    /// Environment passed to the on-disk Login token cache helpers.
    env: Env,
    /// ARN of the console session; passed to the cache helpers to identify the token.
    session_arn: String,
    /// Whether the provider was enabled via a profile; selects the credential feature
    /// recorded on resolved credentials.
    enabled_from_profile: bool,
    /// Configuration used to build the Sign-In client that performs token refreshes.
    sdk_config: SdkConfig,
    /// Clock used for cache-expiry checks and refresh decisions.
    time_source: SharedTimeSource,
    /// Timestamp stored by `resolve_token` when it refreshes the token; compared against
    /// `MIN_TIME_BETWEEN_REFRESH` to rate-limit refreshes.
    last_refresh_attempt: Mutex<Option<SystemTime>>,
}
58
59impl LoginCredentialsProvider {
60    /// Create a new [`Builder`] for the given login session ARN.
61    ///
62    /// The `session_arn` argument should take the form an Amazon Resource Name (ARN) like
63    ///
64    /// ```text
65    /// arn:aws:iam::0123456789012:user/Admin
66    /// ```
67    pub fn builder(session_arn: impl Into<String>) -> Builder {
68        Builder {
69            session_arn: session_arn.into(),
70            provider_config: None,
71            enabled_from_profile: false,
72        }
73    }
74
75    async fn resolve_token(&self) -> Result<LoginToken, LoginTokenError> {
76        let token_cache = self.token_cache.clone();
77        if let Some(token) = token_cache
78            .yield_or_clear_if_expired(self.inner.time_source.now())
79            .await
80        {
81            tracing::debug!("using cached Login token");
82            return Ok(token);
83        }
84
85        let inner = self.inner.clone();
86        let token = token_cache
87            .get_or_load(|| async move {
88                tracing::debug!("expiring cache asked for an updated Login token");
89                let mut token =
90                    load_cached_token(&inner.env, &inner.fs, &inner.session_arn).await?;
91
92                tracing::debug!("loaded cached Login token");
93
94                let now = inner.time_source.now();
95                let expired = token.expires_at() <= now;
96                let expires_soon = token.expires_at() - REFRESH_BUFFER_TIME <= now;
97                let last_refresh = *inner.last_refresh_attempt.lock().unwrap();
98                let min_time_passed = last_refresh
99                    .map(|lr| {
100                        now.duration_since(lr).expect("last_refresh is in the past")
101                            >= MIN_TIME_BETWEEN_REFRESH
102                    })
103                    .unwrap_or(true);
104
105                let refreshable = min_time_passed;
106
107                tracing::debug!(
108                    expired = ?expired,
109                    expires_soon = ?expires_soon,
110                    min_time_passed = ?min_time_passed,
111                    refreshable = ?refreshable,
112                    will_refresh = ?(expires_soon && refreshable),
113                    "cached Login token refresh decision"
114                );
115
116                // Fail fast if the token has expired and we can't refresh it
117                if expired && !refreshable {
118                    tracing::debug!("cached Login token is expired and cannot be refreshed");
119                    return Err(LoginTokenError::ExpiredToken);
120                }
121
122                // Refresh the token if it is going to expire soon
123                if expires_soon && refreshable {
124                    tracing::debug!("attempting to refresh Login token");
125                    let refreshed_token = Self::refresh_cached_token(&inner, &token, now).await?;
126                    token = refreshed_token;
127                    *inner.last_refresh_attempt.lock().unwrap() = Some(now);
128                }
129
130                let expires_at = token.expires_at();
131                Ok((token, expires_at))
132            })
133            .await?;
134
135        Ok(token)
136    }
137
138    async fn refresh_cached_token(
139        inner: &Inner,
140        cached_token: &LoginToken,
141        now: SystemTime,
142    ) -> Result<LoginToken, LoginTokenError> {
143        let dpop_auth_scheme = dpop::DPoPAuthScheme::new(&cached_token.dpop_key)?;
144        let client_config = SignInClientConfigBuilder::from(&inner.sdk_config)
145            .auth_scheme_resolver(dpop::DPoPAuthSchemeOptionResolver)
146            .push_auth_scheme(dpop_auth_scheme)
147            .build();
148
149        let client = SignInClient::from_conf(client_config);
150
151        let resp = client
152            .create_o_auth2_token()
153            .token_input(
154                CreateOAuth2TokenRequestBody::builder()
155                    .client_id(&cached_token.client_id)
156                    .grant_type("refresh_token")
157                    .refresh_token(cached_token.refresh_token.as_str())
158                    .build()
159                    .expect("valid CreateOAuth2TokenRequestBody"),
160            )
161            .send()
162            .await
163            .map_err(|err| {
164                let service_err = err.into_service_error();
165                let message = match &service_err {
166                    CreateOAuth2TokenError::AccessDeniedException(e) => match e.error {
167                        OAuth2ErrorCode::InsufficientPermissions => Some("Unable to refresh credentials due to insufficient permissions. You may be missing permission for the 'CreateOAuth2Token' action.".to_string()),
168                        OAuth2ErrorCode::TokenExpired => Some("Your session has expired. Please reauthenticate.".to_string()),
169                        OAuth2ErrorCode::UserCredentialsChanged => Some("Unable to refresh credentials because of a change in your password. Please reauthenticate with your new password.".to_string()),
170                        _ => None,
171                    }
172                    _ => None,
173                };
174
175                LoginTokenError::RefreshFailed {
176                    message,
177                    source: service_err.into(),
178                }
179            })?;
180
181        let token_output = resp.token_output.expect("valid token response");
182        let new_token = LoginToken::from_refresh(cached_token, token_output, now);
183
184        match save_cached_token(&inner.env, &inner.fs, &inner.session_arn, &new_token).await {
185            Ok(_) => {}
186            Err(e) => tracing::warn!("failed to save refreshed Login token: {e}"),
187        }
188        Ok(new_token)
189    }
190
191    async fn credentials(&self) -> provider::Result {
192        let token = self.resolve_token().await?;
193
194        let feat = match self.inner.enabled_from_profile {
195            true => AwsCredentialFeature::CredentialsProfileLogin,
196            false => AwsCredentialFeature::CredentialsProfile,
197        };
198
199        let mut creds = token.access_token;
200        creds
201            .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
202            .push(feat);
203        Ok(creds)
204    }
205}
206
207impl ProvideCredentials for LoginCredentialsProvider {
208    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
209    where
210        Self: 'a,
211    {
212        future::ProvideCredentials::new(self.credentials())
213    }
214}
215
/// Builder for [`LoginCredentialsProvider`]
#[derive(Debug)]
pub struct Builder {
    /// ARN of the console session whose cached token the provider will use.
    session_arn: String,
    /// Optional configuration override; `build` falls back to the default when `None`.
    provider_config: Option<ProviderConfig>,
    /// Whether the provider was enabled via a profile (defaults to `false`).
    enabled_from_profile: bool,
}
223
224impl Builder {
225    /// Override the configuration used for this provider
226    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
227        self.provider_config = Some(provider_config.clone());
228        self
229    }
230
231    /// Set whether this provider was enabled via a profile.
232    /// Defaults to `false` (configured explicitly in user code).
233    pub(crate) fn enabled_from_profile(mut self, enabled: bool) -> Self {
234        self.enabled_from_profile = enabled;
235        self
236    }
237
238    /// Construct a [`LoginCredentialsProvider`] from the builder
239    pub fn build(self) -> LoginCredentialsProvider {
240        let provider_config = self.provider_config.unwrap_or_default();
241        let fs = provider_config.fs();
242        let env = provider_config.env();
243        let inner = Arc::new(Inner {
244            fs,
245            env,
246            session_arn: self.session_arn,
247            enabled_from_profile: self.enabled_from_profile,
248            sdk_config: provider_config.client_config(),
249            time_source: provider_config.time_source(),
250            last_refresh_attempt: Mutex::new(None),
251        });
252
253        LoginCredentialsProvider {
254            inner,
255            token_cache: ExpiringCache::new(REFRESH_BUFFER_TIME),
256        }
257    }
258}
259
#[cfg(test)]
mod test {
    //! Test suite for LoginCredentialsProvider
    //!
    //! This test module reads test cases from `test-data/login-provider-test-cases.json`
    //! and validates the behavior of the LoginCredentialsProvider against various scenarios
    //! from the SEP.
    //!
    //! Each case builds a fresh provider over an in-memory filesystem, a fixed clock and a
    //! mocked HTTP client, calls it exactly once, and checks every expected outcome against
    //! that single result.
    use super::*;
    use crate::provider_config::ProviderConfig;
    use aws_credential_types::provider::ProvideCredentials;
    use aws_sdk_signin::config::RuntimeComponents;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::time::{SharedTimeSource, StaticTimeSource};
    use aws_smithy_runtime_api::client::{
        http::{
            HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings,
            SharedHttpConnector,
        },
        orchestrator::{HttpRequest, HttpResponse},
    };
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::{Env, Fs};
    use aws_types::region::Region;
    use serde::Deserialize;
    use std::collections::HashMap;
    use std::error::Error;
    use std::time::{Duration, UNIX_EPOCH};

    /// A single scenario deserialized from the JSON test-case file.
    #[derive(Deserialize, Debug)]
    #[serde(rename_all = "camelCase")]
    struct LoginTestCase {
        /// Description printed before the case runs.
        documentation: String,
        /// Written to `/home/user/.aws/config` in the in-memory filesystem.
        config_contents: String,
        /// Login cache entries keyed by file name, written under
        /// `/home/user/.aws/login/cache/`.
        cache_contents: HashMap<String, serde_json::Value>,
        /// Mocked HTTP responses; when empty, `crate::test_case::no_traffic_client()` is used.
        #[serde(default)]
        mock_api_calls: Vec<MockApiCall>,
        /// Expectations checked against the single provider result.
        outcomes: Vec<Outcome>,
    }

    /// A mocked HTTP response served by [`TestHttpConnector`].
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct MockApiCall {
        /// Token payload returned with HTTP 200 when `response_code` is not set.
        #[serde(default)]
        response: Option<MockResponse>,
        /// When set, this status is returned with a fixed error body; takes precedence over
        /// `response`.
        #[serde(default)]
        response_code: Option<u16>,
    }

    /// Successful token response payload.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct MockResponse {
        token_output: TokenOutput,
    }

    /// Token fields rendered into the mocked JSON response body.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct TokenOutput {
        access_token: AccessToken,
        refresh_token: String,
        expires_in: u64,
    }

    /// Credential values nested inside [`TokenOutput`].
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    struct AccessToken {
        access_key_id: String,
        secret_access_key: String,
        session_token: String,
    }

    /// Expected outcome, discriminated by the JSON `result` field.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "result")]
    enum Outcome {
        /// The provider call must succeed with these credential values.
        #[serde(rename = "credentials")]
        Credentials {
            #[serde(rename = "accessKeyId")]
            access_key_id: String,
            #[serde(rename = "secretAccessKey")]
            secret_access_key: String,
            #[serde(rename = "sessionToken")]
            session_token: String,
            #[serde(rename = "accountId")]
            account_id: String,
            // Parsed from the test data but never asserted on, hence the dead_code allowance.
            #[serde(default, rename = "expiresAt")]
            #[allow(dead_code)]
            expires_at: Option<String>,
        },
        /// The provider call must fail.
        #[serde(rename = "error")]
        Error,
        /// After the call, these cache files must hold the given `accessToken` and
        /// `refreshToken` values.
        #[serde(rename = "cacheContents")]
        CacheContents(HashMap<String, serde_json::Value>),
    }

    impl LoginTestCase {
        /// Run this scenario once and assert every expected outcome against the result.
        async fn check(&self) {
            let session_arn = "arn:aws:sts::012345678910:assumed-role/Admin/admin";

            // Fixed time for testing: 2025-11-19T00:00:00Z
            let now = UNIX_EPOCH + Duration::from_secs(1763510400);
            let time_source = SharedTimeSource::new(StaticTimeSource::new(now));

            // Setup filesystem with cache and config contents
            let mut fs_map = HashMap::new();
            fs_map.insert(
                "/home/user/.aws/config".to_string(),
                self.config_contents.as_bytes().to_vec(),
            );
            for (filename, contents) in &self.cache_contents {
                let path = format!("/home/user/.aws/login/cache/{}", filename);
                // Add tokenType if missing (required by cache parser)
                // The unwraps assume every cache entry in the test data is a JSON object.
                let mut contents = contents.clone();
                if !contents.as_object().unwrap().contains_key("tokenType") {
                    contents.as_object_mut().unwrap().insert(
                        "tokenType".to_string(),
                        serde_json::Value::String("aws_sigv4".to_string()),
                    );
                }
                let json = serde_json::to_string(&contents).expect("valid json");
                fs_map.insert(path, json.into_bytes());
            }
            let fs = Fs::from_map(fs_map);

            // HOME matches the directory the config and cache files above were placed under.
            let env = Env::from_slice(&[("HOME", "/home/user")]);

            // Setup mock HTTP client
            let http_client = if self.mock_api_calls.is_empty() {
                crate::test_case::no_traffic_client()
            } else {
                aws_smithy_runtime_api::client::http::SharedHttpClient::new(TestHttpClient::new(
                    &self.mock_api_calls,
                ))
            };

            // `fs` is kept here because the `CacheContents` check below reads cache files back.
            let provider_config = ProviderConfig::empty()
                .with_env(env.clone())
                .with_fs(fs.clone())
                .with_http_client(http_client)
                .with_region(Some(Region::from_static("us-east-2")))
                .with_sleep_impl(TokioSleep::new())
                .with_time_source(time_source);

            let provider = LoginCredentialsProvider::builder(session_arn)
                .configure(&provider_config)
                .build();

            // Call provider once and validate result against all outcomes
            let result = dbg!(provider.provide_credentials().await);

            for outcome in &self.outcomes {
                match outcome {
                    Outcome::Credentials {
                        access_key_id,
                        secret_access_key,
                        session_token,
                        account_id,
                        expires_at: _,
                    } => {
                        let creds = result.as_ref().expect("credentials should succeed");
                        assert_eq!(access_key_id, creds.access_key_id());
                        assert_eq!(secret_access_key, creds.secret_access_key());
                        assert_eq!(session_token, creds.session_token().unwrap());
                        assert_eq!(account_id, creds.account_id().unwrap().as_str());
                    }
                    Outcome::Error => {
                        result.as_ref().expect_err("should fail");
                    }
                    Outcome::CacheContents(expected_cache) => {
                        // Verify cache was updated after provider call
                        for (filename, expected) in expected_cache {
                            let path = format!("/home/user/.aws/login/cache/{}", filename);
                            let actual = fs.read_to_end(&path).await.expect("cache file exists");
                            let actual: serde_json::Value =
                                serde_json::from_slice(&actual).expect("valid json");
                            // Compare only the fields that matter (ignore formatting differences)
                            assert_eq!(
                                expected.get("accessToken"),
                                actual.get("accessToken"),
                                "accessToken mismatch for {}",
                                filename
                            );
                            assert_eq!(
                                expected.get("refreshToken"),
                                actual.get("refreshToken"),
                                "refreshToken mismatch for {}",
                                filename
                            );
                        }
                    }
                }
            }
        }
    }

    /// [`HttpClient`] that returns the same shared [`TestHttpConnector`] for any settings.
    #[derive(Debug, Clone)]
    struct TestHttpClient {
        inner: SharedHttpConnector,
    }

    impl TestHttpClient {
        /// Wrap a copy of `mock_calls` in a shared connector.
        fn new(mock_calls: &[MockApiCall]) -> Self {
            Self {
                inner: SharedHttpConnector::new(TestHttpConnector {
                    mock_calls: mock_calls.to_vec(),
                }),
            }
        }
    }

    impl HttpClient for TestHttpClient {
        fn http_connector(
            &self,
            _settings: &HttpConnectorSettings,
            _components: &RuntimeComponents,
        ) -> SharedHttpConnector {
            self.inner.clone()
        }
    }

    /// Connector that answers requests from the scenario's mocked API calls.
    #[derive(Debug, Clone)]
    struct TestHttpConnector {
        mock_calls: Vec<MockApiCall>,
    }

    impl HttpConnector for TestHttpConnector {
        // NOTE: every request is answered from the *first* mock entry. The request itself is
        // not inspected, and any additional `mockApiCalls` entries are ignored.
        fn call(&self, _request: HttpRequest) -> HttpConnectorFuture {
            if let Some(mock) = self.mock_calls.first() {
                // An explicit status code wins: return it with a fixed error body.
                if let Some(code) = mock.response_code {
                    return HttpConnectorFuture::ready(Ok(HttpResponse::new(
                        code.try_into().unwrap(),
                        SdkBody::from("{\"error\":\"refresh_failed\"}"),
                    )));
                }
                // Otherwise render the mocked token output as a 200 JSON response body.
                if let Some(resp) = &mock.response {
                    let body = format!(
                        r#"{{
                            "accessToken": {{
                                "accessKeyId": "{}",
                                "secretAccessKey": "{}",
                                "sessionToken": "{}"
                            }},
                            "expiresIn": {},
                            "refreshToken": "{}"
                        }}"#,
                        resp.token_output.access_token.access_key_id,
                        resp.token_output.access_token.secret_access_key,
                        resp.token_output.access_token.session_token,
                        resp.token_output.expires_in,
                        resp.token_output.refresh_token
                    );
                    return HttpConnectorFuture::ready(Ok(HttpResponse::new(
                        200.try_into().unwrap(),
                        SdkBody::from(body),
                    )));
                }
            }
            // No usable mock entry: fail the request with a 500.
            HttpConnectorFuture::ready(Ok(HttpResponse::new(
                500.try_into().unwrap(),
                SdkBody::from("{\"error\":\"no_mock\"}"),
            )))
        }
    }

    /// Run every scenario from the SEP test-case file, in file order.
    #[tokio::test]
    async fn run_login_tests() -> Result<(), Box<dyn Error>> {
        // Relative path: resolved against the working directory (the package root under `cargo test`).
        let test_cases = std::fs::read_to_string("test-data/login-provider-test-cases.json")?;
        let test_cases: Vec<LoginTestCase> = serde_json::from_str(&test_cases)?;

        for (idx, test) in test_cases.iter().enumerate() {
            println!("Running test {}: {}", idx, test.documentation);
            test.check().await;
        }
        Ok(())
    }
}