aws_smithy_http_client/client/tls/
rustls_provider.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5use crate::client::tls::Provider;
6use rustls::crypto::CryptoProvider;
7
/// Selects which cryptography backend rustls uses.
///
/// This setting is only meaningful for the rustls TLS provider. Each variant is
/// available only when its corresponding cargo feature is enabled.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CryptoMode {
    /// Use the [ring](https://github.com/briansmith/ring) crypto library.
    #[cfg(feature = "rustls-ring")]
    Ring,
    /// Use the [aws-lc](https://github.com/aws/aws-lc-rs) crypto library.
    #[cfg(feature = "rustls-aws-lc")]
    AwsLc,
    /// Use the FIPS-compliant build of [aws-lc](https://github.com/aws/aws-lc-rs).
    #[cfg(feature = "rustls-aws-lc-fips")]
    AwsLcFips,
}
22
23impl CryptoMode {
24    fn provider(self) -> CryptoProvider {
25        match self {
26            #[cfg(feature = "rustls-aws-lc")]
27            CryptoMode::AwsLc => rustls::crypto::aws_lc_rs::default_provider(),
28
29            #[cfg(feature = "rustls-ring")]
30            CryptoMode::Ring => rustls::crypto::ring::default_provider(),
31
32            #[cfg(feature = "rustls-aws-lc-fips")]
33            CryptoMode::AwsLcFips => {
34                let provider = rustls::crypto::default_fips_provider();
35                assert!(
36                    provider.fips(),
37                    "FIPS was requested but the provider did not support FIPS"
38                );
39                provider
40            }
41        }
42    }
43}
44
45impl Provider {
46    /// Create a TLS provider based on [rustls](https://github.com/rustls/rustls)
47    /// and the given [`CryptoMode`]
48    pub fn rustls(mode: CryptoMode) -> Provider {
49        Provider::Rustls(mode)
50    }
51}
52
53pub(crate) mod build_connector {
54    use crate::client::tls::rustls_provider::CryptoMode;
55    use crate::tls::TlsContext;
56    use client::connect::HttpConnector;
57    use hyper_util::client::legacy as client;
58    use rustls::crypto::CryptoProvider;
59    use rustls_native_certs::CertificateResult;
60    use rustls_pki_types::pem::PemObject;
61    use rustls_pki_types::CertificateDer;
62    use std::sync::Arc;
63    use std::sync::LazyLock;
64
65    /// Cached native certificates
66    ///
67    /// Creating a `with_native_roots()` hyper_rustls client re-loads system certs
68    /// each invocation (which can take 300ms on OSx). Cache the loaded certs
69    /// to avoid repeatedly incurring that cost.
70    pub(crate) static NATIVE_ROOTS: LazyLock<Vec<CertificateDer<'static>>> = LazyLock::new(|| {
71        let CertificateResult { certs, errors, .. } = rustls_native_certs::load_native_certs();
72        if !errors.is_empty() {
73            tracing::warn!("native root CA certificate loading errors: {errors:?}")
74        }
75
76        if certs.is_empty() {
77            tracing::warn!("no native root CA certificates found!");
78        }
79
80        // NOTE: unlike hyper-rustls::with_native_roots we don't validate here, we'll do that later
81        // for now we have a collection of certs that may or may not be valid.
82        certs
83    });
84
85    pub(crate) fn restrict_ciphers(base: CryptoProvider) -> CryptoProvider {
86        let suites = &[
87            rustls::CipherSuite::TLS13_AES_256_GCM_SHA384,
88            rustls::CipherSuite::TLS13_AES_128_GCM_SHA256,
89            // TLS1.2 suites
90            rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
91            rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
92            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
93            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
94            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
95        ];
96        let supported_suites = suites
97            .iter()
98            .flat_map(|suite| {
99                base.cipher_suites
100                    .iter()
101                    .find(|s| &s.suite() == suite)
102                    .cloned()
103            })
104            .collect::<Vec<_>>();
105        CryptoProvider {
106            cipher_suites: supported_suites,
107            ..base
108        }
109    }
110
111    impl TlsContext {
112        pub(crate) fn rustls_root_certs(&self) -> rustls::RootCertStore {
113            let mut roots = rustls::RootCertStore::empty();
114            if self.trust_store.enable_native_roots {
115                let (valid, _invalid) = roots.add_parsable_certificates(NATIVE_ROOTS.clone());
116                debug_assert!(valid > 0, "TrustStore configured to enable native roots but no valid root certificates parsed!");
117            }
118
119            for pem_cert in &self.trust_store.custom_certs {
120                let ders = CertificateDer::pem_slice_iter(&pem_cert.0)
121                    .collect::<Result<Vec<_>, _>>()
122                    .expect("valid PEM certificate");
123                for cert in ders {
124                    roots.add(cert).expect("cert parsable")
125                }
126            }
127
128            roots
129        }
130    }
131
132    /// Create a rustls ClientConfig with smithy-rs defaults
133    ///
134    /// This centralizes the rustls ClientConfig creation logic to ensure
135    /// consistency between the main HTTPS connector and tunnel handlers.
136    pub(crate) fn create_rustls_client_config(
137        crypto_mode: CryptoMode,
138        tls_context: &TlsContext,
139    ) -> rustls::ClientConfig {
140        let root_certs = tls_context.rustls_root_certs();
141        rustls::ClientConfig::builder_with_provider(Arc::new(restrict_ciphers(crypto_mode.provider())))
142            .with_safe_default_protocol_versions()
143            .expect("Error with the TLS configuration. Please file a bug report under https://github.com/smithy-lang/smithy-rs/issues.")
144            .with_root_certificates(root_certs)
145            .with_no_client_auth()
146    }
147
148    pub(crate) fn wrap_connector<R>(
149        mut conn: HttpConnector<R>,
150        crypto_mode: CryptoMode,
151        tls_context: &TlsContext,
152        proxy_config: crate::client::proxy::ProxyConfig,
153    ) -> super::connect::RustTlsConnector<R> {
154        let client_config = create_rustls_client_config(crypto_mode, tls_context);
155        conn.enforce_http(false);
156        let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
157            .with_tls_config(client_config.clone())
158            .https_or_http()
159            .enable_http1()
160            .enable_http2()
161            .wrap_connector(conn);
162
163        super::connect::RustTlsConnector::new(https_connector, client_config, proxy_config)
164    }
165}
166
167pub(crate) mod connect {
168    use crate::client::connect::{Conn, Connecting};
169    use crate::client::proxy::ProxyConfig;
170    use aws_smithy_runtime_api::box_error::BoxError;
171    use http_1x::uri::Scheme;
172    use http_1x::Uri;
173    use hyper::rt::{Read, ReadBufCursor, Write};
174    use hyper_rustls::MaybeHttpsStream;
175    use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
176    use hyper_util::client::proxy::matcher::Matcher;
177    use hyper_util::rt::TokioIo;
178    use pin_project_lite::pin_project;
179    use std::error::Error;
180    use std::sync::Arc;
181    use std::{
182        io::{self, IoSlice},
183        pin::Pin,
184        task::{Context, Poll},
185    };
186    use tokio::io::{AsyncRead, AsyncWrite};
187    use tokio::net::TcpStream;
188    use tokio_rustls::client::TlsStream;
189    use tower::Service;
190
191    #[derive(Debug, Clone)]
192    pub(crate) struct RustTlsConnector<R> {
193        https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
194        tls_config: Arc<rustls::ClientConfig>,
195        proxy_matcher: Option<Arc<Matcher>>, // Pre-computed for performance
196    }
197
198    impl<R> RustTlsConnector<R> {
199        pub(super) fn new(
200            https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
201            tls_config: rustls::ClientConfig,
202            proxy_config: ProxyConfig,
203        ) -> Self {
204            // Pre-compute the proxy matcher once during construction
205            let proxy_matcher = if proxy_config.is_disabled() {
206                None
207            } else {
208                Some(Arc::new(proxy_config.into_hyper_util_matcher()))
209            };
210
211            Self {
212                https,
213                tls_config: Arc::new(tls_config),
214                proxy_matcher,
215            }
216        }
217    }
218
219    impl<R> Service<Uri> for RustTlsConnector<R>
220    where
221        R: Clone + Send + Sync + 'static,
222        R: Service<hyper_util::client::legacy::connect::dns::Name>,
223        R::Response: Iterator<Item = std::net::SocketAddr>,
224        R::Future: Send,
225        R::Error: Into<Box<dyn Error + Send + Sync>>,
226    {
227        type Response = Conn;
228        type Error = BoxError;
229        type Future = Connecting;
230
231        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
232            self.https.poll_ready(cx).map_err(Into::into)
233        }
234
235        fn call(&mut self, dst: Uri) -> Self::Future {
236            // Check if this request should be proxied using pre-computed matcher
237            let proxy_intercept = if let Some(ref matcher) = self.proxy_matcher {
238                matcher.intercept(&dst)
239            } else {
240                None
241            };
242
243            if let Some(intercept) = proxy_intercept {
244                if dst.scheme() == Some(&Scheme::HTTPS) {
245                    // HTTPS through HTTP proxy: Use CONNECT tunneling + manual TLS
246                    self.handle_https_through_proxy(dst, intercept)
247                } else {
248                    // HTTP through proxy: Direct connection to proxy
249                    self.handle_http_through_proxy(dst, intercept)
250                }
251            } else {
252                // Direct connection: Use the existing HTTPS connector
253                self.handle_direct_connection(dst)
254            }
255        }
256    }
257
258    impl<R> RustTlsConnector<R>
259    where
260        R: Clone + Send + Sync + 'static,
261        R: Service<hyper_util::client::legacy::connect::dns::Name>,
262        R::Response: Iterator<Item = std::net::SocketAddr>,
263        R::Future: Send,
264        R::Error: Into<Box<dyn Error + Send + Sync>>,
265    {
266        fn handle_direct_connection(&mut self, dst: Uri) -> Connecting {
267            let fut = self.https.call(dst);
268            Box::pin(async move {
269                let conn = fut.await?;
270                Ok(Conn {
271                    inner: Box::new(conn),
272                    is_proxy: false,
273                })
274            })
275        }
276
277        fn handle_http_through_proxy(
278            &mut self,
279            _dst: Uri,
280            intercept: hyper_util::client::proxy::matcher::Intercept,
281        ) -> Connecting {
282            // For HTTP through proxy, connect to the proxy and let it handle the request
283            let proxy_uri = intercept.uri().clone();
284            let fut = self.https.call(proxy_uri);
285            Box::pin(async move {
286                let conn = fut.await?;
287                Ok(Conn {
288                    inner: Box::new(conn),
289                    is_proxy: true,
290                })
291            })
292        }
293
294        fn handle_https_through_proxy(
295            &mut self,
296            dst: Uri,
297            intercept: hyper_util::client::proxy::matcher::Intercept,
298        ) -> Connecting {
299            use rustls_pki_types::ServerName;
300            // For HTTPS through HTTP proxy, we need to:
301            // 1. Establish CONNECT tunnel using the HTTPS connector
302            // 2. Perform manual TLS handshake over the tunneled stream
303
304            let tunnel = hyper_util::client::legacy::connect::proxy::Tunnel::new(
305                intercept.uri().clone(),
306                self.https.clone(),
307            );
308
309            // Configure tunnel with authentication if present
310            let mut tunnel = if let Some(auth) = intercept.basic_auth() {
311                tunnel.with_auth(auth.clone())
312            } else {
313                tunnel
314            };
315
316            let tls_config = self.tls_config.clone();
317            let dst_clone = dst.clone();
318
319            Box::pin(async move {
320                // Establish CONNECT tunnel
321                tracing::trace!("tunneling HTTPS over proxy");
322                let tunneled = tunnel
323                    .call(dst_clone.clone())
324                    .await
325                    .map_err(|e| BoxError::from(format!("CONNECT tunnel failed: {}", e)))?;
326
327                // Stage 2: Manual TLS handshake over tunneled stream
328                let host = dst_clone
329                    .host()
330                    .ok_or("missing host in URI for TLS handshake")?;
331
332                let server_name = ServerName::try_from(host.to_owned()).map_err(|e| {
333                    BoxError::from(format!("invalid server name for TLS handshake: {}", e))
334                })?;
335
336                let tls_connector = tokio_rustls::TlsConnector::from(tls_config)
337                    .connect(server_name, TokioIo::new(tunneled))
338                    .await?;
339
340                Ok(Conn {
341                    inner: Box::new(RustTlsConn {
342                        inner: TokioIo::new(tls_connector),
343                    }),
344                    is_proxy: true,
345                })
346            })
347        }
348    }
349
350    pin_project! {
351        pub(crate) struct RustTlsConn<T> {
352            #[pin] pub(super) inner: TokioIo<TlsStream<T>>
353        }
354    }
355
356    impl Connection for RustTlsConn<TokioIo<TokioIo<TcpStream>>> {
357        fn connected(&self) -> Connected {
358            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
359                self.inner
360                    .inner()
361                    .get_ref()
362                    .0
363                    .inner()
364                    .connected()
365                    .negotiated_h2()
366            } else {
367                self.inner.inner().get_ref().0.inner().connected()
368            }
369        }
370    }
371
372    impl Connection for RustTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
373        fn connected(&self) -> Connected {
374            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
375                self.inner
376                    .inner()
377                    .get_ref()
378                    .0
379                    .inner()
380                    .connected()
381                    .negotiated_h2()
382            } else {
383                self.inner.inner().get_ref().0.inner().connected()
384            }
385        }
386    }
387    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustTlsConn<T> {
388        fn poll_read(
389            self: Pin<&mut Self>,
390            cx: &mut Context<'_>,
391            buf: ReadBufCursor<'_>,
392        ) -> Poll<tokio::io::Result<()>> {
393            let this = self.project();
394            Read::poll_read(this.inner, cx, buf)
395        }
396    }
397
398    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustTlsConn<T> {
399        fn poll_write(
400            self: Pin<&mut Self>,
401            cx: &mut Context<'_>,
402            buf: &[u8],
403        ) -> Poll<Result<usize, tokio::io::Error>> {
404            let this = self.project();
405            Write::poll_write(this.inner, cx, buf)
406        }
407
408        fn poll_write_vectored(
409            self: Pin<&mut Self>,
410            cx: &mut Context<'_>,
411            bufs: &[IoSlice<'_>],
412        ) -> Poll<Result<usize, io::Error>> {
413            let this = self.project();
414            Write::poll_write_vectored(this.inner, cx, bufs)
415        }
416
417        fn is_write_vectored(&self) -> bool {
418            self.inner.is_write_vectored()
419        }
420
421        fn poll_flush(
422            self: Pin<&mut Self>,
423            cx: &mut Context<'_>,
424        ) -> Poll<Result<(), tokio::io::Error>> {
425            let this = self.project();
426            Write::poll_flush(this.inner, cx)
427        }
428
429        fn poll_shutdown(
430            self: Pin<&mut Self>,
431            cx: &mut Context<'_>,
432        ) -> Poll<Result<(), tokio::io::Error>> {
433            let this = self.project();
434            Write::poll_shutdown(this.inner, cx)
435        }
436    }
437}