aws_smithy_http_client/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6mod timeout;
7/// TLS connector(s)
8pub mod tls;
9
10use crate::cfg::cfg_tls;
11use crate::tls::TlsContext;
12use aws_smithy_async::future::timeout::TimedOutError;
13use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep, SharedAsyncSleep};
14use aws_smithy_runtime_api::box_error::BoxError;
15use aws_smithy_runtime_api::client::connection::CaptureSmithyConnection;
16use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
17use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
18use aws_smithy_runtime_api::client::http::{
19    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpClient,
20    SharedHttpConnector,
21};
22use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
23use aws_smithy_runtime_api::client::result::ConnectorError;
24use aws_smithy_runtime_api::client::runtime_components::{
25    RuntimeComponents, RuntimeComponentsBuilder,
26};
27use aws_smithy_runtime_api::shared::IntoShared;
28use aws_smithy_types::body::SdkBody;
29use aws_smithy_types::config_bag::ConfigBag;
30use aws_smithy_types::error::display::DisplayErrorContext;
31use aws_smithy_types::retry::ErrorKind;
32use client::connect::Connection;
33use h2::Reason;
34use http_1x::{Extensions, Uri};
35use hyper::rt::{Read, Write};
36use hyper_util::client::legacy as client;
37use hyper_util::client::legacy::connect::dns::GaiResolver;
38use hyper_util::client::legacy::connect::{
39    capture_connection, CaptureConnection, Connect, HttpConnector as HyperHttpConnector, HttpInfo,
40};
41use hyper_util::rt::TokioExecutor;
42use std::borrow::Cow;
43use std::collections::HashMap;
44use std::error::Error;
45use std::fmt;
46use std::sync::RwLock;
47use std::time::Duration;
48
49/// Given `HttpConnectorSettings` and an `SharedAsyncSleep`, create a `SharedHttpConnector` from defaults depending on what cargo features are activated.
50pub fn default_connector(
51    settings: &HttpConnectorSettings,
52    sleep: Option<SharedAsyncSleep>,
53) -> Option<SharedHttpConnector> {
54    #[cfg(feature = "rustls-aws-lc")]
55    {
56        tracing::trace!(settings = ?settings, sleep = ?sleep, "creating a new default connector");
57        let mut conn_builder = Connector::builder().connector_settings(settings.clone());
58
59        if let Some(sleep) = sleep {
60            conn_builder = conn_builder.sleep_impl(sleep);
61        }
62
63        let conn = conn_builder
64            .tls_provider(tls::Provider::Rustls(
65                tls::rustls_provider::CryptoMode::AwsLc,
66            ))
67            .build();
68        Some(SharedHttpConnector::new(conn))
69    }
70    #[cfg(not(feature = "rustls-aws-lc"))]
71    {
72        tracing::trace!(settings = ?settings, sleep = ?sleep, "no default connector available");
73        None
74    }
75}
76
77/// [`HttpConnector`] used to make HTTP requests.
78///
79/// This connector also implements socket connect and read timeouts.
80///
81/// This shouldn't be used directly in most cases.
82/// See the docs on [`Builder`] for examples of how to customize the HTTP client.
83#[derive(Debug)]
84pub struct Connector {
85    adapter: Box<dyn HttpConnector>,
86}
87
88impl Connector {
89    /// Builder for an HTTP connector.
90    pub fn builder() -> ConnectorBuilder {
91        ConnectorBuilder {
92            enable_tcp_nodelay: true,
93            ..Default::default()
94        }
95    }
96}
97
98impl HttpConnector for Connector {
99    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
100        self.adapter.call(request)
101    }
102}
103
/// Builder for [`Connector`].
///
/// The `Tls` type parameter records whether a TLS provider has been chosen. It starts as
/// [`TlsUnset`] and becomes [`TlsProviderSelected`] once
/// [`tls_provider`](ConnectorBuilder::tls_provider) is called.
#[derive(Default, Debug)]
pub struct ConnectorBuilder<Tls = TlsUnset> {
    // Source of the connect/read timeouts; `None` means neither timeout is applied.
    connector_settings: Option<HttpConnectorSettings>,
    // Sleep used to implement timeouts; `default_async_sleep()` is used when this is `None`.
    sleep_impl: Option<SharedAsyncSleep>,
    // Hyper client builder override; a default `TokioExecutor`-backed builder is used when `None`.
    client_builder: Option<hyper_util::client::legacy::Builder>,
    // Value passed to `set_nodelay` on the underlying TCP connector.
    enable_tcp_nodelay: bool,
    // Interface to bind sockets to (`SO_BINDTODEVICE`); only applied on Android, Fuchsia, and Linux.
    interface: Option<String>,
    // Only read inside `cfg_tls!` items, so it can be unused when those are compiled out.
    #[allow(unused)]
    tls: Tls,
}
115
/// Initial builder state: a TLS provider must still be chosen.
// `Debug` and `Clone` are derived because the builders that carry this state
// (`ConnectorBuilder`, `Builder`) derive `Debug`/`Clone`. Those derives only apply when the
// TLS state type also implements the trait.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TlsUnset {}
120
121/// TLS implementation selected
122pub struct TlsProviderSelected {
123    #[allow(unused)]
124    provider: tls::Provider,
125    #[allow(unused)]
126    context: TlsContext,
127}
128
129impl ConnectorBuilder<TlsUnset> {
130    /// Set the TLS implementation to use for this connector
131    pub fn tls_provider(self, provider: tls::Provider) -> ConnectorBuilder<TlsProviderSelected> {
132        ConnectorBuilder {
133            connector_settings: self.connector_settings,
134            sleep_impl: self.sleep_impl,
135            client_builder: self.client_builder,
136            enable_tcp_nodelay: self.enable_tcp_nodelay,
137            interface: self.interface,
138            tls: TlsProviderSelected {
139                provider,
140                context: TlsContext::default(),
141            },
142        }
143    }
144
145    /// Build an HTTP connector sans TLS
146    #[doc(hidden)]
147    pub fn build_http(self) -> Connector {
148        let base = self.base_connector();
149        self.wrap_connector(base)
150    }
151}
152
153impl<Any> ConnectorBuilder<Any> {
154    /// Create a [`Connector`] from this builder and a given connector.
155    pub(crate) fn wrap_connector<C>(self, tcp_connector: C) -> Connector
156    where
157        C: Send + Sync + 'static,
158        C: Clone,
159        C: tower::Service<Uri>,
160        C::Response: Read + Write + Connection + Send + Sync + Unpin,
161        C: Connect,
162        C::Future: Unpin + Send + 'static,
163        C::Error: Into<BoxError>,
164    {
165        let client_builder =
166            self.client_builder
167                .unwrap_or(hyper_util::client::legacy::Builder::new(
168                    TokioExecutor::new(),
169                ));
170        let sleep_impl = self.sleep_impl.or_else(default_async_sleep);
171        let (connect_timeout, read_timeout) = self
172            .connector_settings
173            .map(|c| (c.connect_timeout(), c.read_timeout()))
174            .unwrap_or((None, None));
175
176        let connector = match connect_timeout {
177            Some(duration) => timeout::ConnectTimeout::new(
178                tcp_connector,
179                sleep_impl
180                    .clone()
181                    .expect("a sleep impl must be provided in order to have a connect timeout"),
182                duration,
183            ),
184            None => timeout::ConnectTimeout::no_timeout(tcp_connector),
185        };
186        let base = client_builder.build(connector);
187        let read_timeout = match read_timeout {
188            Some(duration) => timeout::HttpReadTimeout::new(
189                base,
190                sleep_impl.expect("a sleep impl must be provided in order to have a read timeout"),
191                duration,
192            ),
193            None => timeout::HttpReadTimeout::no_timeout(base),
194        };
195        Connector {
196            adapter: Box::new(Adapter {
197                client: read_timeout,
198            }),
199        }
200    }
201
202    /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
203    /// (which is a base TCP connector with no TLS or any wrapping)
204    fn base_connector(&self) -> HyperHttpConnector {
205        self.base_connector_with_resolver(GaiResolver::new())
206    }
207
208    /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
209    /// using the given resolver `R`
210    fn base_connector_with_resolver<R>(&self, resolver: R) -> HyperHttpConnector<R> {
211        let mut conn = HyperHttpConnector::new_with_resolver(resolver);
212        conn.set_nodelay(self.enable_tcp_nodelay);
213        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
214        if let Some(interface) = &self.interface {
215            conn.set_interface(interface);
216        }
217        conn
218    }
219
220    /// Set the async sleep implementation used for timeouts
221    ///
222    /// Calling this is only necessary for testing or to use something other than
223    /// [`default_async_sleep`].
224    pub fn sleep_impl(mut self, sleep_impl: impl AsyncSleep + 'static) -> Self {
225        self.sleep_impl = Some(sleep_impl.into_shared());
226        self
227    }
228
229    /// Set the async sleep implementation used for timeouts
230    ///
231    /// Calling this is only necessary for testing or to use something other than
232    /// [`default_async_sleep`].
233    pub fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) -> &mut Self {
234        self.sleep_impl = sleep_impl;
235        self
236    }
237
238    /// Configure the HTTP settings for the `HyperAdapter`
239    pub fn connector_settings(mut self, connector_settings: HttpConnectorSettings) -> Self {
240        self.connector_settings = Some(connector_settings);
241        self
242    }
243
244    /// Configure the HTTP settings for the `HyperAdapter`
245    pub fn set_connector_settings(
246        &mut self,
247        connector_settings: Option<HttpConnectorSettings>,
248    ) -> &mut Self {
249        self.connector_settings = connector_settings;
250        self
251    }
252
253    /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
254    pub fn enable_tcp_nodelay(mut self, nodelay: bool) -> Self {
255        self.enable_tcp_nodelay = nodelay;
256        self
257    }
258
259    /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
260    pub fn set_enable_tcp_nodelay(&mut self, nodelay: bool) -> &mut Self {
261        self.enable_tcp_nodelay = nodelay;
262        self
263    }
264
265    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
266    ///
267    /// If a socket is bound to an interface, only packets received from that particular
268    /// interface are processed by the socket. Note that this only works for some socket
269    /// types (e.g. `AF_INET` sockets).
270    ///
271    /// On Linux it can be used to specify a [VRF], but the binary needs to either have
272    /// `CAP_NET_RAW` capability set or be run as root.
273    ///
274    /// This function is only available on Android, Fuchsia, and Linux.
275    ///
276    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
277    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
278    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
279        self.interface = Some(interface.into());
280        self
281    }
282
283    /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
284    ///
285    /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
286    pub(crate) fn hyper_builder(
287        mut self,
288        hyper_builder: hyper_util::client::legacy::Builder,
289    ) -> Self {
290        self.set_hyper_builder(Some(hyper_builder));
291        self
292    }
293
294    /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
295    ///
296    /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
297    pub(crate) fn set_hyper_builder(
298        &mut self,
299        hyper_builder: Option<hyper_util::client::legacy::Builder>,
300    ) -> &mut Self {
301        self.client_builder = hyper_builder;
302        self
303    }
304}
305
306/// Adapter to use a Hyper 1.0-based Client as an `HttpConnector`
307///
308/// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`Connector::builder`].
309struct Adapter<C> {
310    client: timeout::HttpReadTimeout<
311        hyper_util::client::legacy::Client<timeout::ConnectTimeout<C>, SdkBody>,
312    >,
313}
314
315impl<C> fmt::Debug for Adapter<C> {
316    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        f.debug_struct("Adapter")
318            .field("client", &"** hyper client **")
319            .finish()
320    }
321}
322
323/// Extract a smithy connection from a hyper CaptureConnection
324fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
325    let capture_conn = capture_conn.clone();
326    if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() {
327        let mut extensions = Extensions::new();
328        conn.get_extras(&mut extensions);
329        let http_info = extensions.get::<HttpInfo>();
330        let mut builder = ConnectionMetadata::builder()
331            .proxied(conn.is_proxied())
332            .poison_fn(move || match capture_conn.connection_metadata().as_ref() {
333                Some(conn) => conn.poison(),
334                None => tracing::trace!("no connection existed to poison"),
335            });
336
337        builder
338            .set_local_addr(http_info.map(|info| info.local_addr()))
339            .set_remote_addr(http_info.map(|info| info.remote_addr()));
340
341        let smithy_connection = builder.build();
342
343        Some(smithy_connection)
344    } else {
345        None
346    }
347}
348
349impl<C> HttpConnector for Adapter<C>
350where
351    C: Clone + Send + Sync + 'static,
352    C: tower::Service<Uri>,
353    C::Response: Connection + Read + Write + Unpin + 'static,
354    timeout::ConnectTimeout<C>: Connect,
355    C::Future: Unpin + Send + 'static,
356    C::Error: Into<BoxError>,
357{
358    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
359        let mut request = match request.try_into_http1x() {
360            Ok(request) => request,
361            Err(err) => {
362                return HttpConnectorFuture::ready(Err(ConnectorError::user(err.into())));
363            }
364        };
365        let capture_connection = capture_connection(&mut request);
366        if let Some(capture_smithy_connection) =
367            request.extensions().get::<CaptureSmithyConnection>()
368        {
369            capture_smithy_connection
370                .set_connection_retriever(move || extract_smithy_connection(&capture_connection));
371        }
372        let mut client = self.client.clone();
373        use tower::Service;
374        let fut = client.call(request);
375        HttpConnectorFuture::new(async move {
376            let response = fut
377                .await
378                .map_err(downcast_error)?
379                .map(SdkBody::from_body_1_x);
380            match HttpResponse::try_from(response) {
381                Ok(response) => Ok(response),
382                Err(err) => Err(ConnectorError::other(err.into(), None)),
383            }
384        })
385    }
386}
387
388/// Downcast errors coming out of hyper into an appropriate `ConnectorError`
389fn downcast_error(err: BoxError) -> ConnectorError {
390    // is a `TimedOutError` (from aws_smithy_async::timeout) in the chain? if it is, this is a timeout
391    if find_source::<TimedOutError>(err.as_ref()).is_some() {
392        return ConnectorError::timeout(err);
393    }
394    // is the top of chain error actually already a `ConnectorError`? return that directly
395    let err = match err.downcast::<ConnectorError>() {
396        Ok(connector_error) => return *connector_error,
397        Err(box_error) => box_error,
398    };
399    // generally, the top of chain will probably be a hyper error. Go through a set of hyper specific
400    // error classifications
401    let err = match find_source::<hyper::Error>(err.as_ref()) {
402        Some(hyper_error) => return to_connector_error(hyper_error)(err),
403        None => err,
404    };
405
406    // otherwise, we have no idea!
407    ConnectorError::other(err, None)
408}
409
410/// Convert a [`hyper::Error`] into a [`ConnectorError`]
411fn to_connector_error(err: &hyper::Error) -> fn(BoxError) -> ConnectorError {
412    if err.is_timeout() || find_source::<timeout::HttpTimeoutError>(err).is_some() {
413        return ConnectorError::timeout;
414    }
415    if err.is_user() {
416        return ConnectorError::user;
417    }
418    if err.is_closed() || err.is_canceled() || find_source::<std::io::Error>(err).is_some() {
419        return ConnectorError::io;
420    }
421    // We sometimes receive this from S3: hyper::Error(IncompleteMessage)
422    if err.is_incomplete_message() {
423        return |err: BoxError| ConnectorError::other(err, Some(ErrorKind::TransientError));
424    }
425
426    if let Some(h2_err) = find_source::<h2::Error>(err) {
427        if h2_err.is_go_away()
428            || (h2_err.is_reset() && h2_err.reason() == Some(Reason::REFUSED_STREAM))
429        {
430            return ConnectorError::io;
431        }
432    }
433
434    tracing::warn!(err = %DisplayErrorContext(&err), "unrecognized error from Hyper. If this error should be retried, please file an issue.");
435    |err: BoxError| ConnectorError::other(err, None)
436}
437
/// Walk `err` and its chain of `source()`s, returning the first error of type `E`.
///
/// `err` itself is checked first. Returns `None` if no error in the chain is an `E`.
fn find_source<'a, E: Error + 'static>(err: &'a (dyn Error + 'static)) -> Option<&'a E> {
    std::iter::successors(Some(err), |&current| current.source())
        .find_map(|current| current.downcast_ref::<E>())
}
448
449// TODO(https://github.com/awslabs/aws-sdk-rust/issues/1090): CacheKey must also include ptr equality to any
450// runtime components that are used—sleep_impl as a base (unless we prohibit overriding sleep impl)
451// If we decide to put a DnsResolver in RuntimeComponents, then we'll need to handle that as well.
452#[derive(Clone, Debug, Eq, PartialEq, Hash)]
453struct CacheKey {
454    connect_timeout: Option<Duration>,
455    read_timeout: Option<Duration>,
456}
457
458impl From<&HttpConnectorSettings> for CacheKey {
459    fn from(value: &HttpConnectorSettings) -> Self {
460        Self {
461            connect_timeout: value.connect_timeout(),
462            read_timeout: value.read_timeout(),
463        }
464    }
465}
466
467struct HyperClient<F> {
468    connector_cache: RwLock<HashMap<CacheKey, SharedHttpConnector>>,
469    client_builder: hyper_util::client::legacy::Builder,
470    connector_fn: F,
471}
472
473impl<F> fmt::Debug for HyperClient<F> {
474    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475        f.debug_struct("HyperClient")
476            .field("connector_cache", &self.connector_cache)
477            .field("client_builder", &self.client_builder)
478            .finish()
479    }
480}
481
482impl<F> HttpClient for HyperClient<F>
483where
484    F: Fn(
485            hyper_util::client::legacy::Builder,
486            Option<&HttpConnectorSettings>,
487            Option<&RuntimeComponents>,
488        ) -> Connector
489        + Send
490        + Sync
491        + 'static,
492{
493    fn http_connector(
494        &self,
495        settings: &HttpConnectorSettings,
496        components: &RuntimeComponents,
497    ) -> SharedHttpConnector {
498        let key = CacheKey::from(settings);
499        let mut connector = self.connector_cache.read().unwrap().get(&key).cloned();
500        if connector.is_none() {
501            let mut cache = self.connector_cache.write().unwrap();
502            // Short-circuit if another thread already wrote a connector to the cache for this key
503            if !cache.contains_key(&key) {
504                let start = components.time_source().map(|ts| ts.now());
505                let connector = (self.connector_fn)(
506                    self.client_builder.clone(),
507                    Some(settings),
508                    Some(components),
509                );
510                let end = components.time_source().map(|ts| ts.now());
511                if let (Some(start), Some(end)) = (start, end) {
512                    if let Ok(elapsed) = end.duration_since(start) {
513                        tracing::debug!("new connector created in {:?}", elapsed);
514                    }
515                }
516                let connector = SharedHttpConnector::new(connector);
517                cache.insert(key.clone(), connector);
518            }
519            connector = cache.get(&key).cloned();
520        }
521
522        connector.expect("cache populated above")
523    }
524
525    fn validate_base_client_config(
526        &self,
527        _: &RuntimeComponentsBuilder,
528        _: &ConfigBag,
529    ) -> Result<(), BoxError> {
530        // Initialize the TCP connector at this point so that native certs load
531        // at client initialization time instead of upon first request. We do it
532        // here rather than at construction so that it won't run if this is not
533        // the selected HTTP client for the base config (for example, if this was
534        // the default HTTP client, and it was overridden by a later plugin).
535        let _ = (self.connector_fn)(self.client_builder.clone(), None, None);
536        Ok(())
537    }
538
539    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
540        Some(ConnectorMetadata::new("hyper", Some(Cow::Borrowed("1.x"))))
541    }
542}
543
/// Builder for a hyper-backed [`HttpClient`] implementation.
///
/// This builder can be used to customize the underlying TCP connector used, as well as
/// hyper client configuration. An HTTPS client can only be built after a TLS provider has
/// been chosen with [`Builder::tls_provider`].
///
/// # Examples
///
/// Construct a Hyper client with the RusTLS TLS implementation.
/// This can be useful when you want to share a Hyper connector between multiple
/// generated Smithy clients.
///
/// ```rust,ignore
/// // Requires a rustls cargo feature, e.g. `rustls-aws-lc`.
/// let http_client = Builder::new()
///     .tls_provider(tls::Provider::Rustls(
///         tls::rustls_provider::CryptoMode::AwsLc,
///     ))
///     .build_https();
/// ```
#[derive(Clone, Default, Debug)]
pub struct Builder<Tls = TlsUnset> {
    // Hyper client builder override; a default `TokioExecutor`-backed builder is used when `None`.
    client_builder: Option<hyper_util::client::legacy::Builder>,
    // Only read inside `cfg_tls!` items, so it can be unused when those are compiled out.
    #[allow(unused)]
    tls_provider: Tls,
}
560
cfg_tls! {
    // Provides `HyperUtilResolver`, used by `ConnectorBuilder::build_with_resolver` below.
    mod dns;
    use aws_smithy_runtime_api::client::dns::ResolveDns;

    impl ConnectorBuilder<TlsProviderSelected> {
        /// Build a [`Connector`] that will use the default DNS resolver implementation.
        ///
        /// The default resolver is hyper-util's `GaiResolver`.
        pub fn build(self) -> Connector {
            let http_connector = self.base_connector();
            self.build_https(http_connector)
        }

        /// Configure the TLS context
        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
            self.tls.context = ctx;
            self
        }

        /// Configure the TLS context
        pub fn set_tls_context(&mut self, ctx: TlsContext) -> &mut Self {
            self.tls.context = ctx;
            self
        }

        /// Build a [`Connector`] that will use the given DNS resolver implementation.
        pub fn build_with_resolver<R: ResolveDns + Clone + 'static>(self, resolver: R) -> Connector {
            use crate::client::dns::HyperUtilResolver;
            // Wrap `resolver` in `HyperUtilResolver` so the hyper TCP connector can use it.
            let http_connector = self.base_connector_with_resolver(HyperUtilResolver { resolver });
            self.build_https(http_connector)
        }

        // Wrap the plain TCP connector with the selected TLS provider's connector, then add the
        // timeout layers and the hyper client via `wrap_connector`.
        fn build_https<R>(self, http_connector: HyperHttpConnector<R>) -> Connector
        where
            R: Clone + Send + Sync + 'static,
            R: tower::Service<hyper_util::client::legacy::connect::dns::Name>,
            R::Response: Iterator<Item = std::net::SocketAddr>,
            R::Future: Send,
            R::Error: Into<Box<dyn Error + Send + Sync>>,
        {
            match &self.tls.provider {
                // TODO(hyper1) - fix cfg_rustls! to allow matching on patterns so we can re-use it and not duplicate these cfg matches everywhere
                #[cfg(any(
                    feature = "rustls-aws-lc",
                    feature = "rustls-aws-lc-fips",
                    feature = "rustls-ring"
                ))]
                tls::Provider::Rustls(crypto_mode) => {
                    let https_connector = tls::rustls_provider::build_connector::wrap_connector(
                        http_connector,
                        crypto_mode.clone(),
                        &self.tls.context,
                    );
                    self.wrap_connector(https_connector)
                },
                #[cfg(feature = "s2n-tls")]
                tls::Provider::S2nTls  => {
                    let https_connector = tls::s2n_tls_provider::build_connector::wrap_connector(http_connector, &self.tls.context);
                    self.wrap_connector(https_connector)
                }
            }
        }
    }

    impl Builder<TlsProviderSelected> {
        /// Create an HTTPS client with the selected TLS provider.
        ///
        /// The trusted certificates will be loaded later when this becomes the selected
        /// HTTP client for a Smithy client.
        pub fn build_https(self) -> SharedHttpClient {
            build_with_conn_fn(
                self.client_builder,
                move |client_builder, settings, runtime_components| {
                    let builder = new_conn_builder(client_builder, settings, runtime_components)
                        .tls_provider(self.tls_provider.provider.clone())
                        .tls_context(self.tls_provider.context.clone());
                    builder.build()
                },
            )
        }

        /// Create an HTTPS client using a custom DNS resolver
        ///
        /// `resolver` is cloned each time a connector is built.
        pub fn build_with_resolver(
            self,
            resolver: impl ResolveDns + Clone + 'static,
        ) -> SharedHttpClient {
            build_with_conn_fn(
                self.client_builder,
                move |client_builder, settings, runtime_components| {
                    let builder = new_conn_builder(client_builder, settings, runtime_components)
                        .tls_provider(self.tls_provider.provider.clone())
                        .tls_context(self.tls_provider.context.clone());
                    builder.build_with_resolver(resolver.clone())
                },
            )
        }

        /// Configure the TLS context
        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
            self.tls_provider.context = ctx;
            self
        }
    }
}
663
664impl Builder<TlsUnset> {
665    /// Creates a new builder.
666    pub fn new() -> Self {
667        Self::default()
668    }
669
670    /// Build a new HTTP client without TLS enabled
671    #[doc(hidden)]
672    pub fn build_http(self) -> SharedHttpClient {
673        build_with_conn_fn(
674            self.client_builder,
675            move |client_builder, settings, runtime_components| {
676                let builder = new_conn_builder(client_builder, settings, runtime_components);
677                builder.build_http()
678            },
679        )
680    }
681
682    /// Set the TLS implementation to use
683    pub fn tls_provider(self, provider: tls::Provider) -> Builder<TlsProviderSelected> {
684        Builder {
685            client_builder: self.client_builder,
686            tls_provider: TlsProviderSelected {
687                provider,
688                context: TlsContext::default(),
689            },
690        }
691    }
692}
693
694pub(crate) fn build_with_conn_fn<F>(
695    client_builder: Option<hyper_util::client::legacy::Builder>,
696    connector_fn: F,
697) -> SharedHttpClient
698where
699    F: Fn(
700            hyper_util::client::legacy::Builder,
701            Option<&HttpConnectorSettings>,
702            Option<&RuntimeComponents>,
703        ) -> Connector
704        + Send
705        + Sync
706        + 'static,
707{
708    SharedHttpClient::new(HyperClient {
709        connector_cache: RwLock::new(HashMap::new()),
710        client_builder: client_builder
711            .unwrap_or_else(|| hyper_util::client::legacy::Builder::new(TokioExecutor::new())),
712        connector_fn,
713    })
714}
715
716#[allow(dead_code)]
717pub(crate) fn build_with_tcp_conn_fn<C, F>(
718    client_builder: Option<hyper_util::client::legacy::Builder>,
719    tcp_connector_fn: F,
720) -> SharedHttpClient
721where
722    F: Fn() -> C + Send + Sync + 'static,
723    C: Clone + Send + Sync + 'static,
724    C: tower::Service<Uri>,
725    C::Response: Connection + Read + Write + Send + Sync + Unpin + 'static,
726    C::Future: Unpin + Send + 'static,
727    C::Error: Into<BoxError>,
728    C: Connect,
729{
730    build_with_conn_fn(
731        client_builder,
732        move |client_builder, settings, runtime_components| {
733            let builder = new_conn_builder(client_builder, settings, runtime_components);
734            builder.wrap_connector(tcp_connector_fn())
735        },
736    )
737}
738
739fn new_conn_builder(
740    client_builder: hyper_util::client::legacy::Builder,
741    settings: Option<&HttpConnectorSettings>,
742    runtime_components: Option<&RuntimeComponents>,
743) -> ConnectorBuilder {
744    let mut builder = Connector::builder().hyper_builder(client_builder);
745    builder.set_connector_settings(settings.cloned());
746    if let Some(components) = runtime_components {
747        builder.set_sleep_impl(components.sleep_impl());
748    }
749    builder
750}
751
752#[cfg(test)]
753mod test {
754    use std::io::{Error, ErrorKind};
755    use std::pin::Pin;
756    use std::sync::atomic::{AtomicU32, Ordering};
757    use std::sync::Arc;
758    use std::task::{Context, Poll};
759
760    use aws_smithy_async::assert_elapsed;
761    use aws_smithy_async::rt::sleep::TokioSleep;
762    use aws_smithy_async::time::SystemTimeSource;
763    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
764    use http_1x::Uri;
765    use hyper::rt::ReadBufCursor;
766    use hyper_util::client::legacy::connect::Connected;
767
768    use crate::client::timeout::test::NeverConnects;
769
770    use super::*;
771
772    #[tokio::test]
773    async fn connector_selection() {
774        // Create a client that increments a count every time it creates a new Connector
775        let creation_count = Arc::new(AtomicU32::new(0));
776        let http_client = build_with_tcp_conn_fn(None, {
777            let count = creation_count.clone();
778            move || {
779                count.fetch_add(1, Ordering::Relaxed);
780                NeverConnects
781            }
782        });
783
784        // This configuration should result in 4 separate connectors with different timeout settings
785        let settings = [
786            HttpConnectorSettings::builder()
787                .connect_timeout(Duration::from_secs(3))
788                .build(),
789            HttpConnectorSettings::builder()
790                .read_timeout(Duration::from_secs(3))
791                .build(),
792            HttpConnectorSettings::builder()
793                .connect_timeout(Duration::from_secs(3))
794                .read_timeout(Duration::from_secs(3))
795                .build(),
796            HttpConnectorSettings::builder()
797                .connect_timeout(Duration::from_secs(5))
798                .read_timeout(Duration::from_secs(3))
799                .build(),
800        ];
801
802        // Kick off thousands of parallel tasks that will try to create a connector
803        let components = RuntimeComponentsBuilder::for_tests()
804            .with_time_source(Some(SystemTimeSource::new()))
805            .build()
806            .unwrap();
807        let mut handles = Vec::new();
808        for setting in &settings {
809            for _ in 0..1000 {
810                let client = http_client.clone();
811                handles.push(tokio::spawn({
812                    let setting = setting.clone();
813                    let components = components.clone();
814                    async move {
815                        let _ = client.http_connector(&setting, &components);
816                    }
817                }));
818            }
819        }
820        for handle in handles {
821            handle.await.unwrap();
822        }
823
824        // Verify only 4 connectors were created amidst the chaos
825        assert_eq!(4, creation_count.load(Ordering::Relaxed));
826    }
827
828    #[tokio::test]
829    async fn hyper_io_error() {
830        let connector = TestConnection {
831            inner: HangupStream,
832        };
833        let adapter = Connector::builder().wrap_connector(connector).adapter;
834        let err = adapter
835            .call(HttpRequest::get("https://socket-hangup.com").unwrap())
836            .await
837            .expect_err("socket hangup");
838        assert!(err.is_io(), "unexpected error type: {:?}", err);
839    }
840
    // ---- machinery to make a Hyper connector that responds with an IO Error
    /// Fake transport stream used by `hyper_io_error`: every read fails immediately with
    /// `ConnectionReset`, while writes, flushes and shutdowns stay `Pending` (see the
    /// `Read`/`Write` impls for this type).
    #[derive(Clone)]
    struct HangupStream;
844
    impl Connection for HangupStream {
        /// Reports a freshly constructed `Connected` with no extra metadata attached; the fake
        /// stream has nothing meaningful to describe.
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }
850
851    impl Read for HangupStream {
852        fn poll_read(
853            self: Pin<&mut Self>,
854            _cx: &mut Context<'_>,
855            _buf: ReadBufCursor<'_>,
856        ) -> Poll<std::io::Result<()>> {
857            Poll::Ready(Err(Error::new(
858                ErrorKind::ConnectionReset,
859                "connection reset",
860            )))
861        }
862    }
863
864    impl Write for HangupStream {
865        fn poll_write(
866            self: Pin<&mut Self>,
867            _cx: &mut Context<'_>,
868            _buf: &[u8],
869        ) -> Poll<Result<usize, Error>> {
870            Poll::Pending
871        }
872
873        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
874            Poll::Pending
875        }
876
877        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
878            Poll::Pending
879        }
880    }
881
    /// Connector service for tests: it ignores the requested URI and always yields a clone of
    /// `inner` as the established connection.
    #[derive(Clone)]
    struct TestConnection<T> {
        // The stream handed out (cloned) for every connection attempt.
        inner: T,
    }
886
887    impl<T> tower::Service<Uri> for TestConnection<T>
888    where
889        T: Clone + Connection,
890    {
891        type Response = T;
892        type Error = BoxError;
893        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
894
895        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
896            Poll::Ready(Ok(()))
897        }
898
899        fn call(&mut self, _req: Uri) -> Self::Future {
900            std::future::ready(Ok(self.inner.clone()))
901        }
902    }
903
904    #[tokio::test]
905    async fn http_connect_timeout_works() {
906        let tcp_connector = NeverConnects::default();
907        let connector_settings = HttpConnectorSettings::builder()
908            .connect_timeout(Duration::from_secs(1))
909            .build();
910        let hyper = Connector::builder()
911            .connector_settings(connector_settings)
912            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
913            .wrap_connector(tcp_connector)
914            .adapter;
915        let now = tokio::time::Instant::now();
916        tokio::time::pause();
917        let resp = hyper
918            .call(HttpRequest::get("https://static-uri.com").unwrap())
919            .await
920            .unwrap_err();
921        assert!(
922            resp.is_timeout(),
923            "expected resp.is_timeout() to be true but it was false, resp == {:?}",
924            resp
925        );
926        let message = DisplayErrorContext(&resp).to_string();
927        let expected = "timeout: client error (Connect): HTTP connect timeout occurred after 1s";
928        assert!(
929            message.contains(expected),
930            "expected '{message}' to contain '{expected}'"
931        );
932        assert_elapsed!(now, Duration::from_secs(1));
933    }
934
935    #[tokio::test]
936    async fn http_read_timeout_works() {
937        let tcp_connector = crate::client::timeout::test::NeverReplies;
938        let connector_settings = HttpConnectorSettings::builder()
939            .connect_timeout(Duration::from_secs(1))
940            .read_timeout(Duration::from_secs(2))
941            .build();
942        let hyper = Connector::builder()
943            .connector_settings(connector_settings)
944            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
945            .wrap_connector(tcp_connector)
946            .adapter;
947        let now = tokio::time::Instant::now();
948        tokio::time::pause();
949        let err = hyper
950            .call(HttpRequest::get("https://fake-uri.com").unwrap())
951            .await
952            .unwrap_err();
953        assert!(
954            err.is_timeout(),
955            "expected err.is_timeout() to be true but it was false, err == {err:?}",
956        );
957        let message = format!("{}", DisplayErrorContext(&err));
958        let expected = "timeout: HTTP read timeout occurred after 2s";
959        assert!(
960            message.contains(expected),
961            "expected '{message}' to contain '{expected}'"
962        );
963        assert_elapsed!(now, Duration::from_secs(2));
964    }
965
966    #[cfg(feature = "s2n-tls")]
967    #[tokio::test]
968    async fn s2n_tls_provider() {
969        // Create an HttpConnector with the s2n-tls provider.
970        let client = Builder::new()
971            .tls_provider(tls::Provider::S2nTls)
972            .build_https();
973        let connector_settings = HttpConnectorSettings::builder().build();
974
975        // HyperClient::http_connector invokes TimeSource::now to determine how long it takes to
976        // create new HttpConnectors. As such, a real time source must be provided.
977        let runtime_components = RuntimeComponentsBuilder::for_tests()
978            .with_time_source(Some(SystemTimeSource::new()))
979            .build()
980            .unwrap();
981
982        let connector = client.http_connector(&connector_settings, &runtime_components);
983
984        // Ensure that s2n-tls is used as the underlying TLS provider when selected.
985        //
986        // s2n-tls-hyper will error when given an invalid scheme. Ensure that this error is produced
987        // from s2n-tls-hyper, and not another TLS provider.
988        let error = connector
989            .call(HttpRequest::get("notascheme://amazon.com").unwrap())
990            .await
991            .unwrap_err();
992        let error = error.into_source();
993        let s2n_error = error
994            .source()
995            .unwrap()
996            .downcast_ref::<s2n_tls_hyper::error::Error>()
997            .unwrap();
998        assert!(matches!(
999            s2n_error,
1000            s2n_tls_hyper::error::Error::InvalidScheme
1001        ));
1002    }
1003}