aws_config/
web_identity_token.rs1use crate::provider_config::ProviderConfig;
65use crate::sts;
66use aws_credential_types::credential_feature::AwsCredentialFeature;
67use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
68use aws_sdk_sts::{types::PolicyDescriptorType, Client as StsClient};
69use aws_smithy_async::time::SharedTimeSource;
70use aws_smithy_types::error::display::DisplayErrorContext;
71use aws_types::os_shim_internal::{Env, Fs};
72
73use std::borrow::Cow;
74use std::path::{Path, PathBuf};
75
/// Environment variable naming the file that contains the web identity token.
const ENV_VAR_TOKEN_FILE: &str = "AWS_WEB_IDENTITY_TOKEN_FILE";
/// Environment variable holding the ARN of the role to assume.
const ENV_VAR_ROLE_ARN: &str = "AWS_ROLE_ARN";
/// Optional environment variable holding the role session name; when it is unset a
/// session name is generated instead.
const ENV_VAR_SESSION_NAME: &str = "AWS_ROLE_SESSION_NAME";
79
/// Credentials provider that reads a web identity token from a file and exchanges it for
/// AWS credentials by calling STS `AssumeRoleWithWebIdentity`.
///
/// Construct it with [`WebIdentityTokenCredentialsProvider::builder`].
#[derive(Debug)]
pub struct WebIdentityTokenCredentialsProvider {
    // Where the token-file path, role ARN and session name are taken from.
    source: Source,
    // Used to generate a default session name when none is configured.
    time_source: SharedTimeSource,
    // Filesystem abstraction used to read the token file.
    fs: Fs,
    // Client used to call `AssumeRoleWithWebIdentity`.
    sts_client: StsClient,
    // Optional inline session policy forwarded to STS.
    policy: Option<String>,
    // Optional managed session policy ARNs forwarded to STS.
    policy_arns: Option<Vec<PolicyDescriptorType>>,
}
92
93impl WebIdentityTokenCredentialsProvider {
94    pub fn builder() -> Builder {
96        Builder::default()
97    }
98}
99
100#[derive(Debug)]
101enum Source {
102    Env(Env),
103    Static(StaticConfiguration),
104}
105
/// Explicit web identity configuration, used instead of reading environment variables.
#[derive(Debug, Clone)]
pub struct StaticConfiguration {
    /// Path of the file containing the web identity token.
    pub web_identity_token_file: PathBuf,

    /// ARN of the role to assume.
    pub role_arn: String,

    /// Session name passed to STS when assuming the role.
    pub session_name: String,
}
118
119impl ProvideCredentials for WebIdentityTokenCredentialsProvider {
120    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
121    where
122        Self: 'a,
123    {
124        future::ProvideCredentials::new(self.credentials())
125    }
126}
127
128impl WebIdentityTokenCredentialsProvider {
129    fn source(&self) -> Result<Cow<'_, StaticConfiguration>, CredentialsError> {
130        match &self.source {
131            Source::Env(env) => {
132                let token_file = env.get(ENV_VAR_TOKEN_FILE).map_err(|_| {
133                    CredentialsError::not_loaded(format!("${} was not set", ENV_VAR_TOKEN_FILE))
134                })?;
135                let role_arn = env.get(ENV_VAR_ROLE_ARN).map_err(|_| {
136                    CredentialsError::invalid_configuration(
137                        "AWS_ROLE_ARN environment variable must be set",
138                    )
139                })?;
140                let session_name = env.get(ENV_VAR_SESSION_NAME).unwrap_or_else(|_| {
141                    sts::util::default_session_name("web-identity-token", self.time_source.now())
142                });
143                Ok(Cow::Owned(StaticConfiguration {
144                    web_identity_token_file: token_file.into(),
145                    role_arn,
146                    session_name,
147                }))
148            }
149            Source::Static(conf) => Ok(Cow::Borrowed(conf)),
150        }
151    }
152    async fn credentials(&self) -> provider::Result {
153        let conf = self.source()?;
154        load_credentials(
155            &self.fs,
156            &self.sts_client,
157            self.policy.clone(),
158            self.policy_arns.clone(),
159            &conf.web_identity_token_file,
160            &conf.role_arn,
161            &conf.session_name,
162        )
163        .await
164        .map(|mut creds| {
165            creds
166                .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
167                .push(AwsCredentialFeature::CredentialsProfileStsWebIdToken);
168            creds
169        })
170    }
171}
172
/// Builder for [`WebIdentityTokenCredentialsProvider`].
#[derive(Debug, Default)]
pub struct Builder {
    // Explicit configuration source; `None` means "read from the environment".
    source: Option<Source>,
    // Provider configuration; `None` means `ProviderConfig::default()` is used at build time.
    config: Option<ProviderConfig>,
    // Inline session policy forwarded to STS.
    policy: Option<String>,
    // Managed session policy descriptors forwarded to STS.
    policy_arns: Option<Vec<PolicyDescriptorType>>,
}
181
182impl Builder {
183    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
196        self.config = Some(provider_config.clone());
197        self
198    }
199
200    pub fn static_configuration(mut self, config: StaticConfiguration) -> Self {
206        self.source = Some(Source::Static(config));
207        self
208    }
209
210    pub fn policy(mut self, policy: impl Into<String>) -> Self {
216        self.policy = Some(policy.into());
217        self
218    }
219
220    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
226        self.policy_arns = Some(
227            policy_arns
228                .into_iter()
229                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
230                .collect::<Vec<_>>(),
231        );
232        self
233    }
234
235    pub fn build(self) -> WebIdentityTokenCredentialsProvider {
241        let conf = self.config.unwrap_or_default();
242        let source = self.source.unwrap_or_else(|| Source::Env(conf.env()));
243        WebIdentityTokenCredentialsProvider {
244            source,
245            fs: conf.fs(),
246            sts_client: StsClient::new(&conf.client_config()),
247            time_source: conf.time_source(),
248            policy: self.policy,
249            policy_arns: self.policy_arns,
250        }
251    }
252}
253
254async fn load_credentials(
255    fs: &Fs,
256    sts_client: &StsClient,
257    policy: Option<String>,
258    policy_arns: Option<Vec<PolicyDescriptorType>>,
259    token_file: impl AsRef<Path>,
260    role_arn: &str,
261    session_name: &str,
262) -> provider::Result {
263    let token = fs
264        .read_to_end(token_file)
265        .await
266        .map_err(CredentialsError::provider_error)?;
267    let token = String::from_utf8(token).map_err(|_utf_8_error| {
268        CredentialsError::unhandled("WebIdentityToken was not valid UTF-8")
269    })?;
270
271    let resp = sts_client.assume_role_with_web_identity()
272        .role_arn(role_arn)
273        .role_session_name(session_name)
274        .set_policy(policy)
275        .set_policy_arns(policy_arns)
276        .web_identity_token(token)
277        .send()
278        .await
279        .map_err(|sdk_error| {
280            tracing::warn!(error = %DisplayErrorContext(&sdk_error), "STS returned an error assuming web identity role");
281            CredentialsError::provider_error(sdk_error)
282        })?;
283    sts::util::into_credentials(resp.credentials, resp.assumed_role_user, "WebIdentityToken")
284}
285
#[cfg(test)]
mod test {
    use crate::provider_config::ProviderConfig;
    use crate::test_case::no_traffic_client;
    use crate::web_identity_token::{
        Builder, ENV_VAR_ROLE_ARN, ENV_VAR_SESSION_NAME, ENV_VAR_TOKEN_FILE,
    };
    use aws_credential_types::provider::error::CredentialsError;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use aws_types::os_shim_internal::{Env, Fs};
    use aws_types::region::Region;
    use std::collections::HashMap;

    /// With no web identity variables in the environment, the provider reports
    /// `CredentialsNotLoaded`.
    #[tokio::test]
    async fn unloaded_provider() {
        let conf = ProviderConfig::empty()
            .with_sleep_impl(TokioSleep::new())
            .with_env(Env::from_slice(&[]))
            .with_http_client(no_traffic_client())
            .with_region(Some(Region::from_static("us-east-1")));

        let provider = Builder::default().configure(&conf).build();
        let err = provider
            .credentials()
            .await
            .expect_err("should fail, provider not loaded");
        // Report the actual error on failure instead of a bare "incorrect error variant".
        assert!(
            matches!(err, CredentialsError::CredentialsNotLoaded { .. }),
            "incorrect error variant: {:?}",
            err
        );
    }

    /// A token file without `AWS_ROLE_ARN` is an invalid configuration whose message names
    /// the missing variable.
    #[tokio::test]
    async fn missing_env_var() {
        let env = Env::from_slice(&[(ENV_VAR_TOKEN_FILE, "/token.jwt")]);
        let region = Some(Region::new("us-east-1"));
        let provider = Builder::default()
            .configure(
                &ProviderConfig::empty()
                    .with_sleep_impl(TokioSleep::new())
                    .with_region(region)
                    .with_env(env)
                    .with_http_client(no_traffic_client()),
            )
            .build();
        let err = provider
            .credentials()
            .await
            .expect_err("should fail, AWS_ROLE_ARN is not set");
        assert!(
            format!("{}", DisplayErrorContext(&err)).contains("AWS_ROLE_ARN"),
            "`{}` did not contain expected string",
            err
        );
        assert!(
            matches!(err, CredentialsError::InvalidConfiguration { .. }),
            "incorrect error variant: {:?}",
            err
        );
    }

    /// A fully configured environment whose token file does not exist yields a
    /// `ProviderError`.
    #[tokio::test]
    async fn fs_missing_file() {
        let env = Env::from_slice(&[
            (ENV_VAR_TOKEN_FILE, "/token.jwt"),
            (ENV_VAR_ROLE_ARN, "arn:aws:iam::123456789123:role/test-role"),
            (ENV_VAR_SESSION_NAME, "test-session"),
        ]);
        let fs = Fs::from_raw_map(HashMap::new());
        let provider = Builder::default()
            .configure(
                &ProviderConfig::empty()
                    .with_sleep_impl(TokioSleep::new())
                    .with_http_client(no_traffic_client())
                    .with_region(Some(Region::new("us-east-1")))
                    .with_env(env)
                    .with_fs(fs),
            )
            .build();
        let err = provider.credentials().await.expect_err("no JWT token");
        assert!(
            matches!(err, CredentialsError::ProviderError { .. }),
            "incorrect error variant: {:?}",
            err
        );
    }
}