aws_smithy_http_client/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6mod dns;
7mod timeout;
8/// TLS connector(s)
9pub mod tls;
10
11use crate::cfg::cfg_tls;
12use crate::tls::TlsContext;
13use aws_smithy_async::future::timeout::TimedOutError;
14use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep, SharedAsyncSleep};
15use aws_smithy_runtime_api::box_error::BoxError;
16use aws_smithy_runtime_api::client::connection::CaptureSmithyConnection;
17use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
18use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
19use aws_smithy_runtime_api::client::http::{
20    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpClient,
21    SharedHttpConnector,
22};
23use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
24use aws_smithy_runtime_api::client::result::ConnectorError;
25use aws_smithy_runtime_api::client::runtime_components::{
26    RuntimeComponents, RuntimeComponentsBuilder,
27};
28use aws_smithy_runtime_api::shared::IntoShared;
29use aws_smithy_types::body::SdkBody;
30use aws_smithy_types::config_bag::ConfigBag;
31use aws_smithy_types::error::display::DisplayErrorContext;
32use aws_smithy_types::retry::ErrorKind;
33use client::connect::Connection;
34use h2::Reason;
35use http_1x::{Extensions, Uri};
36use hyper::rt::{Read, Write};
37use hyper_util::client::legacy as client;
38use hyper_util::client::legacy::connect::dns::GaiResolver;
39use hyper_util::client::legacy::connect::{
40    capture_connection, CaptureConnection, Connect, HttpConnector as HyperHttpConnector, HttpInfo,
41};
42use hyper_util::rt::TokioExecutor;
43use std::borrow::Cow;
44use std::collections::HashMap;
45use std::error::Error;
46use std::fmt;
47use std::sync::RwLock;
48use std::time::Duration;
49
50/// Given `HttpConnectorSettings` and an `SharedAsyncSleep`, create a `SharedHttpConnector` from defaults depending on what cargo features are activated.
51pub fn default_connector(
52    settings: &HttpConnectorSettings,
53    sleep: Option<SharedAsyncSleep>,
54) -> Option<SharedHttpConnector> {
55    #[cfg(feature = "rustls-aws-lc")]
56    {
57        tracing::trace!(settings = ?settings, sleep = ?sleep, "creating a new default connector");
58        let mut conn_builder = Connector::builder().connector_settings(settings.clone());
59
60        if let Some(sleep) = sleep {
61            conn_builder = conn_builder.sleep_impl(sleep);
62        }
63
64        let conn = conn_builder
65            .tls_provider(tls::Provider::Rustls(
66                tls::rustls_provider::CryptoMode::AwsLc,
67            ))
68            .build();
69        Some(SharedHttpConnector::new(conn))
70    }
71    #[cfg(not(feature = "rustls-aws-lc"))]
72    {
73        tracing::trace!(settings = ?settings, sleep = ?sleep, "no default connector available");
74        None
75    }
76}
77
/// [`HttpConnector`] used to make HTTP requests.
///
/// This connector also implements socket connect and read timeouts.
///
/// This shouldn't be used directly in most cases. Create one with [`Connector::builder`],
/// or see [`Builder`] for how to customize the HTTP client.
#[derive(Debug)]
pub struct Connector {
    // Type-erased connector that every `call` is forwarded to; `wrap_connector` stores an
    // `Adapter` here.
    adapter: Box<dyn HttpConnector>,
}
88
89impl Connector {
90    /// Builder for an HTTP connector.
91    pub fn builder() -> ConnectorBuilder {
92        ConnectorBuilder {
93            enable_tcp_nodelay: true,
94            ..Default::default()
95        }
96    }
97}
98
99impl HttpConnector for Connector {
100    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
101        self.adapter.call(request)
102    }
103}
104
/// Builder for [`Connector`].
///
/// The `Tls` type parameter records whether a TLS implementation has been chosen: it starts
/// as [`TlsUnset`] and becomes [`TlsProviderSelected`] once `tls_provider` is called.
#[derive(Default, Debug)]
pub struct ConnectorBuilder<Tls = TlsUnset> {
    // Source of the connect/read timeouts; `None` means neither timeout is applied.
    connector_settings: Option<HttpConnectorSettings>,
    // Sleep implementation that drives timeouts; `wrap_connector` falls back to
    // `default_async_sleep` when this is `None`.
    sleep_impl: Option<SharedAsyncSleep>,
    // Custom hyper-util client builder; `wrap_connector` uses a default
    // `TokioExecutor`-backed one when this is `None`.
    client_builder: Option<hyper_util::client::legacy::Builder>,
    // Value passed to `set_nodelay` on the TCP connector (`Connector::builder` sets it to `true`).
    enable_tcp_nodelay: bool,
    // Interface to bind sockets to (`SO_BINDTODEVICE`); only applied on Android, Fuchsia and Linux.
    interface: Option<String>,
    // TLS type-state marker (see the `Tls` type parameter).
    #[allow(unused)]
    tls: Tls,
}
116
/// Initial builder state, `TlsProvider` choice required
///
/// Carries no data; it only marks that no TLS provider has been selected yet. It derives
/// `Clone` and `Debug` so that the derived `Clone`/`Debug` impls on the builders that
/// default to this state (`Builder<TlsUnset>`, `ConnectorBuilder<TlsUnset>`) actually apply.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TlsUnset {}
121
122/// TLS implementation selected
123pub struct TlsProviderSelected {
124    #[allow(unused)]
125    provider: tls::Provider,
126    #[allow(unused)]
127    context: TlsContext,
128}
129
130impl ConnectorBuilder<TlsUnset> {
131    /// Set the TLS implementation to use for this connector
132    pub fn tls_provider(self, provider: tls::Provider) -> ConnectorBuilder<TlsProviderSelected> {
133        ConnectorBuilder {
134            connector_settings: self.connector_settings,
135            sleep_impl: self.sleep_impl,
136            client_builder: self.client_builder,
137            enable_tcp_nodelay: self.enable_tcp_nodelay,
138            interface: self.interface,
139            tls: TlsProviderSelected {
140                provider,
141                context: TlsContext::default(),
142            },
143        }
144    }
145
146    /// Build an HTTP connector sans TLS
147    #[doc(hidden)]
148    pub fn build_http(self) -> Connector {
149        let base = self.base_connector();
150        self.wrap_connector(base)
151    }
152}
153
154impl<Any> ConnectorBuilder<Any> {
155    /// Create a [`Connector`] from this builder and a given connector.
156    pub(crate) fn wrap_connector<C>(self, tcp_connector: C) -> Connector
157    where
158        C: Send + Sync + 'static,
159        C: Clone,
160        C: tower::Service<Uri>,
161        C::Response: Read + Write + Connection + Send + Sync + Unpin,
162        C: Connect,
163        C::Future: Unpin + Send + 'static,
164        C::Error: Into<BoxError>,
165    {
166        let client_builder =
167            self.client_builder
168                .unwrap_or(hyper_util::client::legacy::Builder::new(
169                    TokioExecutor::new(),
170                ));
171        let sleep_impl = self.sleep_impl.or_else(default_async_sleep);
172        let (connect_timeout, read_timeout) = self
173            .connector_settings
174            .map(|c| (c.connect_timeout(), c.read_timeout()))
175            .unwrap_or((None, None));
176
177        let connector = match connect_timeout {
178            Some(duration) => timeout::ConnectTimeout::new(
179                tcp_connector,
180                sleep_impl
181                    .clone()
182                    .expect("a sleep impl must be provided in order to have a connect timeout"),
183                duration,
184            ),
185            None => timeout::ConnectTimeout::no_timeout(tcp_connector),
186        };
187        let base = client_builder.build(connector);
188        let read_timeout = match read_timeout {
189            Some(duration) => timeout::HttpReadTimeout::new(
190                base,
191                sleep_impl.expect("a sleep impl must be provided in order to have a read timeout"),
192                duration,
193            ),
194            None => timeout::HttpReadTimeout::no_timeout(base),
195        };
196        Connector {
197            adapter: Box::new(Adapter {
198                client: read_timeout,
199            }),
200        }
201    }
202
203    /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
204    /// (which is a base TCP connector with no TLS or any wrapping)
205    fn base_connector(&self) -> HyperHttpConnector {
206        self.base_connector_with_resolver(GaiResolver::new())
207    }
208
209    /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
210    /// using the given resolver `R`
211    fn base_connector_with_resolver<R>(&self, resolver: R) -> HyperHttpConnector<R> {
212        let mut conn = HyperHttpConnector::new_with_resolver(resolver);
213        conn.set_nodelay(self.enable_tcp_nodelay);
214        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
215        if let Some(interface) = &self.interface {
216            conn.set_interface(interface);
217        }
218        conn
219    }
220
221    /// Set the async sleep implementation used for timeouts
222    ///
223    /// Calling this is only necessary for testing or to use something other than
224    /// [`default_async_sleep`].
225    pub fn sleep_impl(mut self, sleep_impl: impl AsyncSleep + 'static) -> Self {
226        self.sleep_impl = Some(sleep_impl.into_shared());
227        self
228    }
229
230    /// Set the async sleep implementation used for timeouts
231    ///
232    /// Calling this is only necessary for testing or to use something other than
233    /// [`default_async_sleep`].
234    pub fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) -> &mut Self {
235        self.sleep_impl = sleep_impl;
236        self
237    }
238
239    /// Configure the HTTP settings for the `HyperAdapter`
240    pub fn connector_settings(mut self, connector_settings: HttpConnectorSettings) -> Self {
241        self.connector_settings = Some(connector_settings);
242        self
243    }
244
245    /// Configure the HTTP settings for the `HyperAdapter`
246    pub fn set_connector_settings(
247        &mut self,
248        connector_settings: Option<HttpConnectorSettings>,
249    ) -> &mut Self {
250        self.connector_settings = connector_settings;
251        self
252    }
253
254    /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
255    pub fn enable_tcp_nodelay(mut self, nodelay: bool) -> Self {
256        self.enable_tcp_nodelay = nodelay;
257        self
258    }
259
260    /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
261    pub fn set_enable_tcp_nodelay(&mut self, nodelay: bool) -> &mut Self {
262        self.enable_tcp_nodelay = nodelay;
263        self
264    }
265
266    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
267    ///
268    /// If a socket is bound to an interface, only packets received from that particular
269    /// interface are processed by the socket. Note that this only works for some socket
270    /// types (e.g. `AF_INET` sockets).
271    ///
272    /// On Linux it can be used to specify a [VRF], but the binary needs to either have
273    /// `CAP_NET_RAW` capability set or be run as root.
274    ///
275    /// This function is only available on Android, Fuchsia, and Linux.
276    ///
277    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
278    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
279    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
280        self.interface = Some(interface.into());
281        self
282    }
283
284    /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
285    ///
286    /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
287    pub(crate) fn hyper_builder(
288        mut self,
289        hyper_builder: hyper_util::client::legacy::Builder,
290    ) -> Self {
291        self.set_hyper_builder(Some(hyper_builder));
292        self
293    }
294
295    /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
296    ///
297    /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
298    pub(crate) fn set_hyper_builder(
299        &mut self,
300        hyper_builder: Option<hyper_util::client::legacy::Builder>,
301    ) -> &mut Self {
302        self.client_builder = hyper_builder;
303        self
304    }
305}
306
/// Adapter to use a Hyper 1.0-based Client as an `HttpConnector`
///
/// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`Connector::builder`].
struct Adapter<C> {
    // hyper-util client whose TCP connector `C` is wrapped in a connect timeout, with the
    // whole client wrapped in a read timeout (either may be the `no_timeout` variant).
    client: timeout::HttpReadTimeout<
        hyper_util::client::legacy::Client<timeout::ConnectTimeout<C>, SdkBody>,
    >,
}
315
316impl<C> fmt::Debug for Adapter<C> {
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        f.debug_struct("Adapter")
319            .field("client", &"** hyper client **")
320            .finish()
321    }
322}
323
324/// Extract a smithy connection from a hyper CaptureConnection
325fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
326    let capture_conn = capture_conn.clone();
327    if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() {
328        let mut extensions = Extensions::new();
329        conn.get_extras(&mut extensions);
330        let http_info = extensions.get::<HttpInfo>();
331        let mut builder = ConnectionMetadata::builder()
332            .proxied(conn.is_proxied())
333            .poison_fn(move || match capture_conn.connection_metadata().as_ref() {
334                Some(conn) => conn.poison(),
335                None => tracing::trace!("no connection existed to poison"),
336            });
337
338        builder
339            .set_local_addr(http_info.map(|info| info.local_addr()))
340            .set_remote_addr(http_info.map(|info| info.remote_addr()));
341
342        let smithy_connection = builder.build();
343
344        Some(smithy_connection)
345    } else {
346        None
347    }
348}
349
350impl<C> HttpConnector for Adapter<C>
351where
352    C: Clone + Send + Sync + 'static,
353    C: tower::Service<Uri>,
354    C::Response: Connection + Read + Write + Unpin + 'static,
355    timeout::ConnectTimeout<C>: Connect,
356    C::Future: Unpin + Send + 'static,
357    C::Error: Into<BoxError>,
358{
359    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
360        let mut request = match request.try_into_http1x() {
361            Ok(request) => request,
362            Err(err) => {
363                return HttpConnectorFuture::ready(Err(ConnectorError::user(err.into())));
364            }
365        };
366        let capture_connection = capture_connection(&mut request);
367        if let Some(capture_smithy_connection) =
368            request.extensions().get::<CaptureSmithyConnection>()
369        {
370            capture_smithy_connection
371                .set_connection_retriever(move || extract_smithy_connection(&capture_connection));
372        }
373        let mut client = self.client.clone();
374        use tower::Service;
375        let fut = client.call(request);
376        HttpConnectorFuture::new(async move {
377            let response = fut
378                .await
379                .map_err(downcast_error)?
380                .map(SdkBody::from_body_1_x);
381            match HttpResponse::try_from(response) {
382                Ok(response) => Ok(response),
383                Err(err) => Err(ConnectorError::other(err.into(), None)),
384            }
385        })
386    }
387}
388
389/// Downcast errors coming out of hyper into an appropriate `ConnectorError`
390fn downcast_error(err: BoxError) -> ConnectorError {
391    // is a `TimedOutError` (from aws_smithy_async::timeout) in the chain? if it is, this is a timeout
392    if find_source::<TimedOutError>(err.as_ref()).is_some() {
393        return ConnectorError::timeout(err);
394    }
395    // is the top of chain error actually already a `ConnectorError`? return that directly
396    let err = match err.downcast::<ConnectorError>() {
397        Ok(connector_error) => return *connector_error,
398        Err(box_error) => box_error,
399    };
400    // generally, the top of chain will probably be a hyper error. Go through a set of hyper specific
401    // error classifications
402    let err = match find_source::<hyper::Error>(err.as_ref()) {
403        Some(hyper_error) => return to_connector_error(hyper_error)(err),
404        None => match find_source::<hyper_util::client::legacy::Error>(err.as_ref()) {
405            Some(hyper_util_err) => {
406                if hyper_util_err.is_connect()
407                    || find_source::<std::io::Error>(hyper_util_err).is_some()
408                {
409                    return ConnectorError::io(err);
410                }
411                err
412            }
413            None => err,
414        },
415    };
416
417    // otherwise, we have no idea!
418    ConnectorError::other(err, None)
419}
420
421/// Convert a [`hyper::Error`] into a [`ConnectorError`]
422fn to_connector_error(err: &hyper::Error) -> fn(BoxError) -> ConnectorError {
423    if err.is_timeout() || find_source::<timeout::HttpTimeoutError>(err).is_some() {
424        return ConnectorError::timeout;
425    }
426    if err.is_user() {
427        return ConnectorError::user;
428    }
429    if err.is_closed() || err.is_canceled() || find_source::<std::io::Error>(err).is_some() {
430        return ConnectorError::io;
431    }
432    // We sometimes receive this from S3: hyper::Error(IncompleteMessage)
433    if err.is_incomplete_message() {
434        return |err: BoxError| ConnectorError::other(err, Some(ErrorKind::TransientError));
435    }
436
437    if let Some(h2_err) = find_source::<h2::Error>(err) {
438        if h2_err.is_go_away()
439            || (h2_err.is_reset() && h2_err.reason() == Some(Reason::REFUSED_STREAM))
440        {
441            return ConnectorError::io;
442        }
443    }
444
445    tracing::warn!(err = %DisplayErrorContext(&err), "unrecognized error from Hyper. If this error should be retried, please file an issue.");
446    |err: BoxError| ConnectorError::other(err, None)
447}
448
/// Walk `err` and its chain of [`Error::source`]s, returning the first error that is an `E`.
///
/// `err` itself is checked before any of its sources. Returns `None` if no error in the
/// chain is an `E`.
fn find_source<'a, E: Error + 'static>(err: &'a (dyn Error + 'static)) -> Option<&'a E> {
    // Destructure with `&current` so `source()` resolves on `dyn Error` itself and keeps the
    // `'a` lifetime (rather than going through the `&dyn Error` impl and a shorter borrow).
    std::iter::successors(Some(err), |&current| current.source())
        .find_map(|candidate| candidate.downcast_ref::<E>())
}
459
// TODO(https://github.com/awslabs/aws-sdk-rust/issues/1090): CacheKey must also include ptr equality to any
// runtime components that are used—sleep_impl as a base (unless we prohibit overriding sleep impl)
// If we decide to put a DnsResolver in RuntimeComponents, then we'll need to handle that as well.
/// Key into `HyperClient`'s connector cache. Calls whose connector settings have the same
/// connect and read timeouts share one connector.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct CacheKey {
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
}
468
469impl From<&HttpConnectorSettings> for CacheKey {
470    fn from(value: &HttpConnectorSettings) -> Self {
471        Self {
472            connect_timeout: value.connect_timeout(),
473            read_timeout: value.read_timeout(),
474        }
475    }
476}
477
/// [`HttpClient`] implementation that creates connectors on demand via `connector_fn` and
/// caches one per distinct set of connect/read timeouts.
struct HyperClient<F> {
    // Connectors built so far, keyed by their timeout settings.
    connector_cache: RwLock<HashMap<CacheKey, SharedHttpConnector>>,
    // Base hyper-util client builder; a clone is handed to `connector_fn` for each new connector.
    client_builder: hyper_util::client::legacy::Builder,
    // Factory that builds a `Connector` from the client builder plus optional settings and
    // runtime components.
    connector_fn: F,
}
483
484impl<F> fmt::Debug for HyperClient<F> {
485    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
486        f.debug_struct("HyperClient")
487            .field("connector_cache", &self.connector_cache)
488            .field("client_builder", &self.client_builder)
489            .finish()
490    }
491}
492
493impl<F> HttpClient for HyperClient<F>
494where
495    F: Fn(
496            hyper_util::client::legacy::Builder,
497            Option<&HttpConnectorSettings>,
498            Option<&RuntimeComponents>,
499        ) -> Connector
500        + Send
501        + Sync
502        + 'static,
503{
504    fn http_connector(
505        &self,
506        settings: &HttpConnectorSettings,
507        components: &RuntimeComponents,
508    ) -> SharedHttpConnector {
509        let key = CacheKey::from(settings);
510        let mut connector = self.connector_cache.read().unwrap().get(&key).cloned();
511        if connector.is_none() {
512            let mut cache = self.connector_cache.write().unwrap();
513            // Short-circuit if another thread already wrote a connector to the cache for this key
514            if !cache.contains_key(&key) {
515                let start = components.time_source().map(|ts| ts.now());
516                let connector = (self.connector_fn)(
517                    self.client_builder.clone(),
518                    Some(settings),
519                    Some(components),
520                );
521                let end = components.time_source().map(|ts| ts.now());
522                if let (Some(start), Some(end)) = (start, end) {
523                    if let Ok(elapsed) = end.duration_since(start) {
524                        tracing::debug!("new connector created in {:?}", elapsed);
525                    }
526                }
527                let connector = SharedHttpConnector::new(connector);
528                cache.insert(key.clone(), connector);
529            }
530            connector = cache.get(&key).cloned();
531        }
532
533        connector.expect("cache populated above")
534    }
535
536    fn validate_base_client_config(
537        &self,
538        _: &RuntimeComponentsBuilder,
539        _: &ConfigBag,
540    ) -> Result<(), BoxError> {
541        // Initialize the TCP connector at this point so that native certs load
542        // at client initialization time instead of upon first request. We do it
543        // here rather than at construction so that it won't run if this is not
544        // the selected HTTP client for the base config (for example, if this was
545        // the default HTTP client, and it was overridden by a later plugin).
546        let _ = (self.connector_fn)(self.client_builder.clone(), None, None);
547        Ok(())
548    }
549
550    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
551        Some(ConnectorMetadata::new("hyper", Some(Cow::Borrowed("1.x"))))
552    }
553}
554
/// Builder for a hyper-backed [`HttpClient`] implementation.
///
/// This builder can be used to customize the underlying TCP connector used, as well as
/// hyper client configuration.
///
/// Choose a TLS implementation with `tls_provider`, then build an HTTPS client with
/// `build_https` (or `build_with_resolver` for a custom DNS resolver). This can be useful
/// when you want to share a Hyper connector between multiple generated Smithy clients.
#[derive(Clone, Default, Debug)]
pub struct Builder<Tls = TlsUnset> {
    // Custom hyper-util client builder; `build_with_conn_fn` falls back to a default
    // `TokioExecutor`-backed builder when this is `None`.
    client_builder: Option<hyper_util::client::legacy::Builder>,
    // TLS type-state marker: `TlsUnset` until `tls_provider` is called.
    #[allow(unused)]
    tls_provider: Tls,
}
571
572cfg_tls! {
573    use aws_smithy_runtime_api::client::dns::ResolveDns;
574
575    impl ConnectorBuilder<TlsProviderSelected> {
576        /// Build a [`Connector`] that will use the default DNS resolver implementation.
577        pub fn build(self) -> Connector {
578            let http_connector = self.base_connector();
579            self.build_https(http_connector)
580        }
581
582        /// Configure the TLS context
583        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
584            self.tls.context = ctx;
585            self
586        }
587
588        /// Configure the TLS context
589        pub fn set_tls_context(&mut self, ctx: TlsContext) -> &mut Self {
590            self.tls.context = ctx;
591            self
592        }
593
594        /// Build a [`Connector`] that will use the given DNS resolver implementation.
595        pub fn build_with_resolver<R: ResolveDns + Clone + 'static>(self, resolver: R) -> Connector {
596            use crate::client::dns::HyperUtilResolver;
597            let http_connector = self.base_connector_with_resolver(HyperUtilResolver { resolver });
598            self.build_https(http_connector)
599        }
600
601        fn build_https<R>(self, http_connector: HyperHttpConnector<R>) -> Connector
602        where
603            R: Clone + Send + Sync + 'static,
604            R: tower::Service<hyper_util::client::legacy::connect::dns::Name>,
605            R::Response: Iterator<Item = std::net::SocketAddr>,
606            R::Future: Send,
607            R::Error: Into<Box<dyn Error + Send + Sync>>,
608        {
609            match &self.tls.provider {
610                // TODO(hyper1) - fix cfg_rustls! to allow matching on patterns so we can re-use it and not duplicate these cfg matches everywhere
611                #[cfg(any(
612                    feature = "rustls-aws-lc",
613                    feature = "rustls-aws-lc-fips",
614                    feature = "rustls-ring"
615                ))]
616                tls::Provider::Rustls(crypto_mode) => {
617                    let https_connector = tls::rustls_provider::build_connector::wrap_connector(
618                        http_connector,
619                        crypto_mode.clone(),
620                        &self.tls.context,
621                    );
622                    self.wrap_connector(https_connector)
623                },
624                #[cfg(feature = "s2n-tls")]
625                tls::Provider::S2nTls  => {
626                    let https_connector = tls::s2n_tls_provider::build_connector::wrap_connector(http_connector, &self.tls.context);
627                    self.wrap_connector(https_connector)
628                }
629            }
630        }
631    }
632
633    impl Builder<TlsProviderSelected> {
634        /// Create an HTTPS client with the selected TLS provider.
635        ///
636        /// The trusted certificates will be loaded later when this becomes the selected
637        /// HTTP client for a Smithy client.
638        pub fn build_https(self) -> SharedHttpClient {
639            build_with_conn_fn(
640                self.client_builder,
641                move |client_builder, settings, runtime_components| {
642                    let builder = new_conn_builder(client_builder, settings, runtime_components)
643                        .tls_provider(self.tls_provider.provider.clone())
644                        .tls_context(self.tls_provider.context.clone());
645                    builder.build()
646                },
647            )
648        }
649
650        /// Create an HTTPS client using a custom DNS resolver
651        pub fn build_with_resolver(
652            self,
653            resolver: impl ResolveDns + Clone + 'static,
654        ) -> SharedHttpClient {
655            build_with_conn_fn(
656                self.client_builder,
657                move |client_builder, settings, runtime_components| {
658                    let builder = new_conn_builder(client_builder, settings, runtime_components)
659                        .tls_provider(self.tls_provider.provider.clone())
660                        .tls_context(self.tls_provider.context.clone());
661                    builder.build_with_resolver(resolver.clone())
662                },
663            )
664        }
665
666        /// Configure the TLS context
667        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
668            self.tls_provider.context = ctx;
669            self
670        }
671    }
672}
673
674impl Builder<TlsUnset> {
675    /// Creates a new builder.
676    pub fn new() -> Self {
677        Self::default()
678    }
679
680    /// Build a new HTTP client without TLS enabled
681    #[doc(hidden)]
682    pub fn build_http(self) -> SharedHttpClient {
683        build_with_conn_fn(
684            self.client_builder,
685            move |client_builder, settings, runtime_components| {
686                let builder = new_conn_builder(client_builder, settings, runtime_components);
687                builder.build_http()
688            },
689        )
690    }
691
692    /// Set the TLS implementation to use
693    pub fn tls_provider(self, provider: tls::Provider) -> Builder<TlsProviderSelected> {
694        Builder {
695            client_builder: self.client_builder,
696            tls_provider: TlsProviderSelected {
697                provider,
698                context: TlsContext::default(),
699            },
700        }
701    }
702}
703
704pub(crate) fn build_with_conn_fn<F>(
705    client_builder: Option<hyper_util::client::legacy::Builder>,
706    connector_fn: F,
707) -> SharedHttpClient
708where
709    F: Fn(
710            hyper_util::client::legacy::Builder,
711            Option<&HttpConnectorSettings>,
712            Option<&RuntimeComponents>,
713        ) -> Connector
714        + Send
715        + Sync
716        + 'static,
717{
718    SharedHttpClient::new(HyperClient {
719        connector_cache: RwLock::new(HashMap::new()),
720        client_builder: client_builder
721            .unwrap_or_else(|| hyper_util::client::legacy::Builder::new(TokioExecutor::new())),
722        connector_fn,
723    })
724}
725
726#[allow(dead_code)]
727pub(crate) fn build_with_tcp_conn_fn<C, F>(
728    client_builder: Option<hyper_util::client::legacy::Builder>,
729    tcp_connector_fn: F,
730) -> SharedHttpClient
731where
732    F: Fn() -> C + Send + Sync + 'static,
733    C: Clone + Send + Sync + 'static,
734    C: tower::Service<Uri>,
735    C::Response: Connection + Read + Write + Send + Sync + Unpin + 'static,
736    C::Future: Unpin + Send + 'static,
737    C::Error: Into<BoxError>,
738    C: Connect,
739{
740    build_with_conn_fn(
741        client_builder,
742        move |client_builder, settings, runtime_components| {
743            let builder = new_conn_builder(client_builder, settings, runtime_components);
744            builder.wrap_connector(tcp_connector_fn())
745        },
746    )
747}
748
749fn new_conn_builder(
750    client_builder: hyper_util::client::legacy::Builder,
751    settings: Option<&HttpConnectorSettings>,
752    runtime_components: Option<&RuntimeComponents>,
753) -> ConnectorBuilder {
754    let mut builder = Connector::builder().hyper_builder(client_builder);
755    builder.set_connector_settings(settings.cloned());
756    if let Some(components) = runtime_components {
757        builder.set_sleep_impl(components.sleep_impl());
758    }
759    builder
760}
761
762#[cfg(test)]
763mod test {
764    use std::io::{Error, ErrorKind};
765    use std::pin::Pin;
766    use std::sync::atomic::{AtomicU32, Ordering};
767    use std::sync::Arc;
768    use std::task::{Context, Poll};
769
770    use crate::client::timeout::test::NeverConnects;
771    use aws_smithy_async::assert_elapsed;
772    use aws_smithy_async::rt::sleep::TokioSleep;
773    use aws_smithy_async::time::SystemTimeSource;
774    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
775    use http_1x::Uri;
776    use hyper::rt::ReadBufCursor;
777    use hyper_util::client::legacy::connect::Connected;
778
779    use super::*;
780
781    #[tokio::test]
782    async fn connector_selection() {
783        // Create a client that increments a count every time it creates a new Connector
784        let creation_count = Arc::new(AtomicU32::new(0));
785        let http_client = build_with_tcp_conn_fn(None, {
786            let count = creation_count.clone();
787            move || {
788                count.fetch_add(1, Ordering::Relaxed);
789                NeverConnects
790            }
791        });
792
793        // This configuration should result in 4 separate connectors with different timeout settings
794        let settings = [
795            HttpConnectorSettings::builder()
796                .connect_timeout(Duration::from_secs(3))
797                .build(),
798            HttpConnectorSettings::builder()
799                .read_timeout(Duration::from_secs(3))
800                .build(),
801            HttpConnectorSettings::builder()
802                .connect_timeout(Duration::from_secs(3))
803                .read_timeout(Duration::from_secs(3))
804                .build(),
805            HttpConnectorSettings::builder()
806                .connect_timeout(Duration::from_secs(5))
807                .read_timeout(Duration::from_secs(3))
808                .build(),
809        ];
810
811        // Kick off thousands of parallel tasks that will try to create a connector
812        let components = RuntimeComponentsBuilder::for_tests()
813            .with_time_source(Some(SystemTimeSource::new()))
814            .build()
815            .unwrap();
816        let mut handles = Vec::new();
817        for setting in &settings {
818            for _ in 0..1000 {
819                let client = http_client.clone();
820                handles.push(tokio::spawn({
821                    let setting = setting.clone();
822                    let components = components.clone();
823                    async move {
824                        let _ = client.http_connector(&setting, &components);
825                    }
826                }));
827            }
828        }
829        for handle in handles {
830            handle.await.unwrap();
831        }
832
833        // Verify only 4 connectors were created amidst the chaos
834        assert_eq!(4, creation_count.load(Ordering::Relaxed));
835    }
836
837    #[tokio::test]
838    async fn hyper_io_error() {
839        let connector = TestConnection {
840            inner: HangupStream,
841        };
842        let adapter = Connector::builder().wrap_connector(connector).adapter;
843        let err = adapter
844            .call(HttpRequest::get("https://socket-hangup.com").unwrap())
845            .await
846            .expect_err("socket hangup");
847        assert!(err.is_io(), "unexpected error type: {:?}", err);
848    }
849
850    // ---- machinery to make a Hyper connector that responds with an IO Error
851    #[derive(Clone)]
852    struct HangupStream;
853
854    impl Connection for HangupStream {
855        fn connected(&self) -> Connected {
856            Connected::new()
857        }
858    }
859
860    impl Read for HangupStream {
861        fn poll_read(
862            self: Pin<&mut Self>,
863            _cx: &mut Context<'_>,
864            _buf: ReadBufCursor<'_>,
865        ) -> Poll<std::io::Result<()>> {
866            Poll::Ready(Err(Error::new(
867                ErrorKind::ConnectionReset,
868                "connection reset",
869            )))
870        }
871    }
872
873    impl Write for HangupStream {
874        fn poll_write(
875            self: Pin<&mut Self>,
876            _cx: &mut Context<'_>,
877            _buf: &[u8],
878        ) -> Poll<Result<usize, Error>> {
879            Poll::Pending
880        }
881
882        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
883            Poll::Pending
884        }
885
886        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
887            Poll::Pending
888        }
889    }
890
891    #[derive(Clone)]
892    struct TestConnection<T> {
893        inner: T,
894    }
895
896    impl<T> tower::Service<Uri> for TestConnection<T>
897    where
898        T: Clone + Connection,
899    {
900        type Response = T;
901        type Error = BoxError;
902        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
903
904        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
905            Poll::Ready(Ok(()))
906        }
907
908        fn call(&mut self, _req: Uri) -> Self::Future {
909            std::future::ready(Ok(self.inner.clone()))
910        }
911    }
912
913    #[tokio::test]
914    async fn http_connect_timeout_works() {
915        let tcp_connector = NeverConnects::default();
916        let connector_settings = HttpConnectorSettings::builder()
917            .connect_timeout(Duration::from_secs(1))
918            .build();
919        let hyper = Connector::builder()
920            .connector_settings(connector_settings)
921            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
922            .wrap_connector(tcp_connector)
923            .adapter;
924        let now = tokio::time::Instant::now();
925        tokio::time::pause();
926        let resp = hyper
927            .call(HttpRequest::get("https://static-uri.com").unwrap())
928            .await
929            .unwrap_err();
930        assert!(
931            resp.is_timeout(),
932            "expected resp.is_timeout() to be true but it was false, resp == {:?}",
933            resp
934        );
935        let message = DisplayErrorContext(&resp).to_string();
936        let expected = "timeout: client error (Connect): HTTP connect timeout occurred after 1s";
937        assert!(
938            message.contains(expected),
939            "expected '{message}' to contain '{expected}'"
940        );
941        assert_elapsed!(now, Duration::from_secs(1));
942    }
943
944    #[tokio::test]
945    async fn http_read_timeout_works() {
946        let tcp_connector = crate::client::timeout::test::NeverReplies;
947        let connector_settings = HttpConnectorSettings::builder()
948            .connect_timeout(Duration::from_secs(1))
949            .read_timeout(Duration::from_secs(2))
950            .build();
951        let hyper = Connector::builder()
952            .connector_settings(connector_settings)
953            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
954            .wrap_connector(tcp_connector)
955            .adapter;
956        let now = tokio::time::Instant::now();
957        tokio::time::pause();
958        let err = hyper
959            .call(HttpRequest::get("https://fake-uri.com").unwrap())
960            .await
961            .unwrap_err();
962        assert!(
963            err.is_timeout(),
964            "expected err.is_timeout() to be true but it was false, err == {err:?}",
965        );
966        let message = format!("{}", DisplayErrorContext(&err));
967        let expected = "timeout: HTTP read timeout occurred after 2s";
968        assert!(
969            message.contains(expected),
970            "expected '{message}' to contain '{expected}'"
971        );
972        assert_elapsed!(now, Duration::from_secs(2));
973    }
974
975    #[cfg(not(windows))]
976    #[tokio::test]
977    async fn connection_refused_works() {
978        use crate::client::dns::HyperUtilResolver;
979        use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns};
980        use std::net::{IpAddr, Ipv4Addr};
981
982        #[derive(Debug, Clone, Default)]
983        struct TestResolver;
984        impl ResolveDns for TestResolver {
985            fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
986                let localhost_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
987                DnsFuture::ready(Ok(vec![localhost_v4]))
988            }
989        }
990
991        let connector_settings = HttpConnectorSettings::builder()
992            .connect_timeout(Duration::from_secs(20))
993            .build();
994
995        let resolver = HyperUtilResolver {
996            resolver: TestResolver::default(),
997        };
998        let connector = Connector::builder().base_connector_with_resolver(resolver);
999
1000        let hyper = Connector::builder()
1001            .connector_settings(connector_settings)
1002            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
1003            .wrap_connector(connector)
1004            .adapter;
1005
1006        let resp = hyper
1007            .call(HttpRequest::get("http://static-uri:50227.com").unwrap())
1008            .await
1009            .unwrap_err();
1010        assert!(
1011            resp.is_io(),
1012            "expected resp.is_io() to be true but it was false, resp == {:?}",
1013            resp
1014        );
1015        let message = DisplayErrorContext(&resp).to_string();
1016        let expected = "Connection refused";
1017        assert!(
1018            message.contains(expected),
1019            "expected '{message}' to contain '{expected}'"
1020        );
1021    }
1022
1023    #[cfg(feature = "s2n-tls")]
1024    #[tokio::test]
1025    async fn s2n_tls_provider() {
1026        // Create an HttpConnector with the s2n-tls provider.
1027        let client = Builder::new()
1028            .tls_provider(tls::Provider::S2nTls)
1029            .build_https();
1030        let connector_settings = HttpConnectorSettings::builder().build();
1031
1032        // HyperClient::http_connector invokes TimeSource::now to determine how long it takes to
1033        // create new HttpConnectors. As such, a real time source must be provided.
1034        let runtime_components = RuntimeComponentsBuilder::for_tests()
1035            .with_time_source(Some(SystemTimeSource::new()))
1036            .build()
1037            .unwrap();
1038
1039        let connector = client.http_connector(&connector_settings, &runtime_components);
1040
1041        // Ensure that s2n-tls is used as the underlying TLS provider when selected.
1042        //
1043        // s2n-tls-hyper will error when given an invalid scheme. Ensure that this error is produced
1044        // from s2n-tls-hyper, and not another TLS provider.
1045        let error = connector
1046            .call(HttpRequest::get("notascheme://amazon.com").unwrap())
1047            .await
1048            .unwrap_err();
1049        let error = error.into_source();
1050        let s2n_error = error
1051            .source()
1052            .unwrap()
1053            .downcast_ref::<s2n_tls_hyper::error::Error>()
1054            .unwrap();
1055        assert!(matches!(
1056            s2n_error,
1057            s2n_tls_hyper::error::Error::InvalidScheme
1058        ));
1059    }
1060}