1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | mod timeout;
|
7 + | /// TLS connector(s)
|
8 + | pub mod tls;
|
9 + |
|
10 + | use crate::cfg::cfg_tls;
|
11 + | use crate::tls::TlsContext;
|
12 + | use aws_smithy_async::future::timeout::TimedOutError;
|
13 + | use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep, SharedAsyncSleep};
|
14 + | use aws_smithy_runtime_api::box_error::BoxError;
|
15 + | use aws_smithy_runtime_api::client::connection::CaptureSmithyConnection;
|
16 + | use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
|
17 + | use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
|
18 + | use aws_smithy_runtime_api::client::http::{
|
19 + | HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpClient,
|
20 + | SharedHttpConnector,
|
21 + | };
|
22 + | use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
|
23 + | use aws_smithy_runtime_api::client::result::ConnectorError;
|
24 + | use aws_smithy_runtime_api::client::runtime_components::{
|
25 + | RuntimeComponents, RuntimeComponentsBuilder,
|
26 + | };
|
27 + | use aws_smithy_runtime_api::shared::IntoShared;
|
28 + | use aws_smithy_types::body::SdkBody;
|
29 + | use aws_smithy_types::config_bag::ConfigBag;
|
30 + | use aws_smithy_types::error::display::DisplayErrorContext;
|
31 + | use aws_smithy_types::retry::ErrorKind;
|
32 + | use client::connect::Connection;
|
33 + | use h2::Reason;
|
34 + | use http_1x::{Extensions, Uri};
|
35 + | use hyper::rt::{Read, Write};
|
36 + | use hyper_util::client::legacy as client;
|
37 + | use hyper_util::client::legacy::connect::dns::GaiResolver;
|
38 + | use hyper_util::client::legacy::connect::{
|
39 + | capture_connection, CaptureConnection, Connect, HttpConnector as HyperHttpConnector, HttpInfo,
|
40 + | };
|
41 + | use hyper_util::rt::TokioExecutor;
|
42 + | use std::borrow::Cow;
|
43 + | use std::collections::HashMap;
|
44 + | use std::error::Error;
|
45 + | use std::fmt;
|
46 + | use std::sync::RwLock;
|
47 + | use std::time::Duration;
|
48 + |
|
49 + | /// Given `HttpConnectorSettings` and an `SharedAsyncSleep`, create a `SharedHttpConnector` from defaults depending on what cargo features are activated.
|
50 + | pub fn default_connector(
|
51 + | settings: &HttpConnectorSettings,
|
52 + | sleep: Option<SharedAsyncSleep>,
|
53 + | ) -> Option<SharedHttpConnector> {
|
54 + | #[cfg(feature = "rustls-aws-lc")]
|
55 + | {
|
56 + | tracing::trace!(settings = ?settings, sleep = ?sleep, "creating a new default connector");
|
57 + | let mut conn_builder = Connector::builder().connector_settings(settings.clone());
|
58 + |
|
59 + | if let Some(sleep) = sleep {
|
60 + | conn_builder = conn_builder.sleep_impl(sleep);
|
61 + | }
|
62 + |
|
63 + | let conn = conn_builder
|
64 + | .tls_provider(tls::Provider::Rustls(
|
65 + | tls::rustls_provider::CryptoMode::AwsLc,
|
66 + | ))
|
67 + | .build();
|
68 + | Some(SharedHttpConnector::new(conn))
|
69 + | }
|
70 + | #[cfg(not(feature = "rustls-aws-lc"))]
|
71 + | {
|
72 + | tracing::trace!(settings = ?settings, sleep = ?sleep, "no default connector available");
|
73 + | None
|
74 + | }
|
75 + | }
|
76 + |
|
77 + | /// [`HttpConnector`] used to make HTTP requests.
|
78 + | ///
|
79 + | /// This connector also implements socket connect and read timeouts.
|
80 + | ///
|
81 + | /// This shouldn't be used directly in most cases.
|
82 + | /// See the docs on [`Builder`] for examples of how to customize the HTTP client.
|
83 + | #[derive(Debug)]
|
84 + | pub struct Connector {
|
85 + | adapter: Box<dyn HttpConnector>,
|
86 + | }
|
87 + |
|
88 + | impl Connector {
|
89 + | /// Builder for an HTTP connector.
|
90 + | pub fn builder() -> ConnectorBuilder {
|
91 + | ConnectorBuilder {
|
92 + | enable_tcp_nodelay: true,
|
93 + | ..Default::default()
|
94 + | }
|
95 + | }
|
96 + | }
|
97 + |
|
98 + | impl HttpConnector for Connector {
|
99 + | fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
|
100 + | self.adapter.call(request)
|
101 + | }
|
102 + | }
|
103 + |
|
104 + | /// Builder for [`Connector`].
|
105 + | #[derive(Default, Debug)]
|
106 + | pub struct ConnectorBuilder<Tls = TlsUnset> {
|
107 + | connector_settings: Option<HttpConnectorSettings>,
|
108 + | sleep_impl: Option<SharedAsyncSleep>,
|
109 + | client_builder: Option<hyper_util::client::legacy::Builder>,
|
110 + | enable_tcp_nodelay: bool,
|
111 + | interface: Option<String>,
|
112 + | #[allow(unused)]
|
113 + | tls: Tls,
|
114 + | }
|
115 + |
|
/// Initial builder state: a TLS provider must be chosen (via `tls_provider`) before an HTTPS
/// connector or client can be built.
///
/// `Debug` and `Clone` are derived so that builders parameterized by this state actually satisfy
/// the `Debug`/`Clone` derives declared on those generic builder types (derived impls require the
/// type parameter itself to implement the trait).
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TlsUnset {}
|
120 + |
|
121 + | /// TLS implementation selected
|
122 + | pub struct TlsProviderSelected {
|
123 + | #[allow(unused)]
|
124 + | provider: tls::Provider,
|
125 + | #[allow(unused)]
|
126 + | context: TlsContext,
|
127 + | }
|
128 + |
|
129 + | impl ConnectorBuilder<TlsUnset> {
|
130 + | /// Set the TLS implementation to use for this connector
|
131 + | pub fn tls_provider(self, provider: tls::Provider) -> ConnectorBuilder<TlsProviderSelected> {
|
132 + | ConnectorBuilder {
|
133 + | connector_settings: self.connector_settings,
|
134 + | sleep_impl: self.sleep_impl,
|
135 + | client_builder: self.client_builder,
|
136 + | enable_tcp_nodelay: self.enable_tcp_nodelay,
|
137 + | interface: self.interface,
|
138 + | tls: TlsProviderSelected {
|
139 + | provider,
|
140 + | context: TlsContext::default(),
|
141 + | },
|
142 + | }
|
143 + | }
|
144 + |
|
145 + | /// Build an HTTP connector sans TLS
|
146 + | #[doc(hidden)]
|
147 + | pub fn build_http(self) -> Connector {
|
148 + | let base = self.base_connector();
|
149 + | self.wrap_connector(base)
|
150 + | }
|
151 + | }
|
152 + |
|
153 + | impl<Any> ConnectorBuilder<Any> {
|
154 + | /// Create a [`Connector`] from this builder and a given connector.
|
155 + | pub(crate) fn wrap_connector<C>(self, tcp_connector: C) -> Connector
|
156 + | where
|
157 + | C: Send + Sync + 'static,
|
158 + | C: Clone,
|
159 + | C: tower::Service<Uri>,
|
160 + | C::Response: Read + Write + Connection + Send + Sync + Unpin,
|
161 + | C: Connect,
|
162 + | C::Future: Unpin + Send + 'static,
|
163 + | C::Error: Into<BoxError>,
|
164 + | {
|
165 + | let client_builder =
|
166 + | self.client_builder
|
167 + | .unwrap_or(hyper_util::client::legacy::Builder::new(
|
168 + | TokioExecutor::new(),
|
169 + | ));
|
170 + | let sleep_impl = self.sleep_impl.or_else(default_async_sleep);
|
171 + | let (connect_timeout, read_timeout) = self
|
172 + | .connector_settings
|
173 + | .map(|c| (c.connect_timeout(), c.read_timeout()))
|
174 + | .unwrap_or((None, None));
|
175 + |
|
176 + | let connector = match connect_timeout {
|
177 + | Some(duration) => timeout::ConnectTimeout::new(
|
178 + | tcp_connector,
|
179 + | sleep_impl
|
180 + | .clone()
|
181 + | .expect("a sleep impl must be provided in order to have a connect timeout"),
|
182 + | duration,
|
183 + | ),
|
184 + | None => timeout::ConnectTimeout::no_timeout(tcp_connector),
|
185 + | };
|
186 + | let base = client_builder.build(connector);
|
187 + | let read_timeout = match read_timeout {
|
188 + | Some(duration) => timeout::HttpReadTimeout::new(
|
189 + | base,
|
190 + | sleep_impl.expect("a sleep impl must be provided in order to have a read timeout"),
|
191 + | duration,
|
192 + | ),
|
193 + | None => timeout::HttpReadTimeout::no_timeout(base),
|
194 + | };
|
195 + | Connector {
|
196 + | adapter: Box::new(Adapter {
|
197 + | client: read_timeout,
|
198 + | }),
|
199 + | }
|
200 + | }
|
201 + |
|
202 + | /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
|
203 + | /// (which is a base TCP connector with no TLS or any wrapping)
|
204 + | fn base_connector(&self) -> HyperHttpConnector {
|
205 + | self.base_connector_with_resolver(GaiResolver::new())
|
206 + | }
|
207 + |
|
208 + | /// Get the base TCP connector by mapping our config to the underlying `HttpConnector` from hyper
|
209 + | /// using the given resolver `R`
|
210 + | fn base_connector_with_resolver<R>(&self, resolver: R) -> HyperHttpConnector<R> {
|
211 + | let mut conn = HyperHttpConnector::new_with_resolver(resolver);
|
212 + | conn.set_nodelay(self.enable_tcp_nodelay);
|
213 + | #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
214 + | if let Some(interface) = &self.interface {
|
215 + | conn.set_interface(interface);
|
216 + | }
|
217 + | conn
|
218 + | }
|
219 + |
|
220 + | /// Set the async sleep implementation used for timeouts
|
221 + | ///
|
222 + | /// Calling this is only necessary for testing or to use something other than
|
223 + | /// [`default_async_sleep`].
|
224 + | pub fn sleep_impl(mut self, sleep_impl: impl AsyncSleep + 'static) -> Self {
|
225 + | self.sleep_impl = Some(sleep_impl.into_shared());
|
226 + | self
|
227 + | }
|
228 + |
|
229 + | /// Set the async sleep implementation used for timeouts
|
230 + | ///
|
231 + | /// Calling this is only necessary for testing or to use something other than
|
232 + | /// [`default_async_sleep`].
|
233 + | pub fn set_sleep_impl(&mut self, sleep_impl: Option<SharedAsyncSleep>) -> &mut Self {
|
234 + | self.sleep_impl = sleep_impl;
|
235 + | self
|
236 + | }
|
237 + |
|
238 + | /// Configure the HTTP settings for the `HyperAdapter`
|
239 + | pub fn connector_settings(mut self, connector_settings: HttpConnectorSettings) -> Self {
|
240 + | self.connector_settings = Some(connector_settings);
|
241 + | self
|
242 + | }
|
243 + |
|
244 + | /// Configure the HTTP settings for the `HyperAdapter`
|
245 + | pub fn set_connector_settings(
|
246 + | &mut self,
|
247 + | connector_settings: Option<HttpConnectorSettings>,
|
248 + | ) -> &mut Self {
|
249 + | self.connector_settings = connector_settings;
|
250 + | self
|
251 + | }
|
252 + |
|
253 + | /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
|
254 + | pub fn enable_tcp_nodelay(mut self, nodelay: bool) -> Self {
|
255 + | self.enable_tcp_nodelay = nodelay;
|
256 + | self
|
257 + | }
|
258 + |
|
259 + | /// Configure `SO_NODELAY` for all sockets to the supplied value `nodelay`
|
260 + | pub fn set_enable_tcp_nodelay(&mut self, nodelay: bool) -> &mut Self {
|
261 + | self.enable_tcp_nodelay = nodelay;
|
262 + | self
|
263 + | }
|
264 + |
|
265 + | /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
|
266 + | ///
|
267 + | /// If a socket is bound to an interface, only packets received from that particular
|
268 + | /// interface are processed by the socket. Note that this only works for some socket
|
269 + | /// types (e.g. `AF_INET` sockets).
|
270 + | ///
|
271 + | /// On Linux it can be used to specify a [VRF], but the binary needs to either have
|
272 + | /// `CAP_NET_RAW` capability set or be run as root.
|
273 + | ///
|
274 + | /// This function is only available on Android, Fuchsia, and Linux.
|
275 + | ///
|
276 + | /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
|
277 + | #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
|
278 + | pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
|
279 + | self.interface = Some(interface.into());
|
280 + | self
|
281 + | }
|
282 + |
|
283 + | /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
|
284 + | ///
|
285 + | /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
|
286 + | pub(crate) fn hyper_builder(
|
287 + | mut self,
|
288 + | hyper_builder: hyper_util::client::legacy::Builder,
|
289 + | ) -> Self {
|
290 + | self.set_hyper_builder(Some(hyper_builder));
|
291 + | self
|
292 + | }
|
293 + |
|
294 + | /// Override the Hyper client [`Builder`](hyper_util::client::legacy::Builder) used to construct this client.
|
295 + | ///
|
296 + | /// This enables changing settings like forcing HTTP2 and modifying other default client behavior.
|
297 + | pub(crate) fn set_hyper_builder(
|
298 + | &mut self,
|
299 + | hyper_builder: Option<hyper_util::client::legacy::Builder>,
|
300 + | ) -> &mut Self {
|
301 + | self.client_builder = hyper_builder;
|
302 + | self
|
303 + | }
|
304 + | }
|
305 + |
|
306 + | /// Adapter to use a Hyper 1.0-based Client as an `HttpConnector`
|
307 + | ///
|
308 + | /// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`Connector::builder`].
|
309 + | struct Adapter<C> {
|
310 + | client: timeout::HttpReadTimeout<
|
311 + | hyper_util::client::legacy::Client<timeout::ConnectTimeout<C>, SdkBody>,
|
312 + | >,
|
313 + | }
|
314 + |
|
315 + | impl<C> fmt::Debug for Adapter<C> {
|
316 + | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
317 + | f.debug_struct("Adapter")
|
318 + | .field("client", &"** hyper client **")
|
319 + | .finish()
|
320 + | }
|
321 + | }
|
322 + |
|
323 + | /// Extract a smithy connection from a hyper CaptureConnection
|
324 + | fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option<ConnectionMetadata> {
|
325 + | let capture_conn = capture_conn.clone();
|
326 + | if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() {
|
327 + | let mut extensions = Extensions::new();
|
328 + | conn.get_extras(&mut extensions);
|
329 + | let http_info = extensions.get::<HttpInfo>();
|
330 + | let mut builder = ConnectionMetadata::builder()
|
331 + | .proxied(conn.is_proxied())
|
332 + | .poison_fn(move || match capture_conn.connection_metadata().as_ref() {
|
333 + | Some(conn) => conn.poison(),
|
334 + | None => tracing::trace!("no connection existed to poison"),
|
335 + | });
|
336 + |
|
337 + | builder
|
338 + | .set_local_addr(http_info.map(|info| info.local_addr()))
|
339 + | .set_remote_addr(http_info.map(|info| info.remote_addr()));
|
340 + |
|
341 + | let smithy_connection = builder.build();
|
342 + |
|
343 + | Some(smithy_connection)
|
344 + | } else {
|
345 + | None
|
346 + | }
|
347 + | }
|
348 + |
|
349 + | impl<C> HttpConnector for Adapter<C>
|
350 + | where
|
351 + | C: Clone + Send + Sync + 'static,
|
352 + | C: tower::Service<Uri>,
|
353 + | C::Response: Connection + Read + Write + Unpin + 'static,
|
354 + | timeout::ConnectTimeout<C>: Connect,
|
355 + | C::Future: Unpin + Send + 'static,
|
356 + | C::Error: Into<BoxError>,
|
357 + | {
|
358 + | fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
|
359 + | let mut request = match request.try_into_http1x() {
|
360 + | Ok(request) => request,
|
361 + | Err(err) => {
|
362 + | return HttpConnectorFuture::ready(Err(ConnectorError::user(err.into())));
|
363 + | }
|
364 + | };
|
365 + | let capture_connection = capture_connection(&mut request);
|
366 + | if let Some(capture_smithy_connection) =
|
367 + | request.extensions().get::<CaptureSmithyConnection>()
|
368 + | {
|
369 + | capture_smithy_connection
|
370 + | .set_connection_retriever(move || extract_smithy_connection(&capture_connection));
|
371 + | }
|
372 + | let mut client = self.client.clone();
|
373 + | use tower::Service;
|
374 + | let fut = client.call(request);
|
375 + | HttpConnectorFuture::new(async move {
|
376 + | let response = fut
|
377 + | .await
|
378 + | .map_err(downcast_error)?
|
379 + | .map(SdkBody::from_body_1_x);
|
380 + | match HttpResponse::try_from(response) {
|
381 + | Ok(response) => Ok(response),
|
382 + | Err(err) => Err(ConnectorError::other(err.into(), None)),
|
383 + | }
|
384 + | })
|
385 + | }
|
386 + | }
|
387 + |
|
388 + | /// Downcast errors coming out of hyper into an appropriate `ConnectorError`
|
389 + | fn downcast_error(err: BoxError) -> ConnectorError {
|
390 + | // is a `TimedOutError` (from aws_smithy_async::timeout) in the chain? if it is, this is a timeout
|
391 + | if find_source::<TimedOutError>(err.as_ref()).is_some() {
|
392 + | return ConnectorError::timeout(err);
|
393 + | }
|
394 + | // is the top of chain error actually already a `ConnectorError`? return that directly
|
395 + | let err = match err.downcast::<ConnectorError>() {
|
396 + | Ok(connector_error) => return *connector_error,
|
397 + | Err(box_error) => box_error,
|
398 + | };
|
399 + | // generally, the top of chain will probably be a hyper error. Go through a set of hyper specific
|
400 + | // error classifications
|
401 + | let err = match find_source::<hyper::Error>(err.as_ref()) {
|
402 + | Some(hyper_error) => return to_connector_error(hyper_error)(err),
|
403 + | None => err,
|
404 + | };
|
405 + |
|
406 + | // otherwise, we have no idea!
|
407 + | ConnectorError::other(err, None)
|
408 + | }
|
409 + |
|
410 + | /// Convert a [`hyper::Error`] into a [`ConnectorError`]
|
411 + | fn to_connector_error(err: &hyper::Error) -> fn(BoxError) -> ConnectorError {
|
412 + | if err.is_timeout() || find_source::<timeout::HttpTimeoutError>(err).is_some() {
|
413 + | return ConnectorError::timeout;
|
414 + | }
|
415 + | if err.is_user() {
|
416 + | return ConnectorError::user;
|
417 + | }
|
418 + | if err.is_closed() || err.is_canceled() || find_source::<std::io::Error>(err).is_some() {
|
419 + | return ConnectorError::io;
|
420 + | }
|
421 + | // We sometimes receive this from S3: hyper::Error(IncompleteMessage)
|
422 + | if err.is_incomplete_message() {
|
423 + | return |err: BoxError| ConnectorError::other(err, Some(ErrorKind::TransientError));
|
424 + | }
|
425 + |
|
426 + | if let Some(h2_err) = find_source::<h2::Error>(err) {
|
427 + | if h2_err.is_go_away()
|
428 + | || (h2_err.is_reset() && h2_err.reason() == Some(Reason::REFUSED_STREAM))
|
429 + | {
|
430 + | return ConnectorError::io;
|
431 + | }
|
432 + | }
|
433 + |
|
434 + | tracing::warn!(err = %DisplayErrorContext(&err), "unrecognized error from Hyper. If this error should be retried, please file an issue.");
|
435 + | |err: BoxError| ConnectorError::other(err, None)
|
436 + | }
|
437 + |
|
438 + | fn find_source<'a, E: Error + 'static>(err: &'a (dyn Error + 'static)) -> Option<&'a E> {
|
439 + | let mut next = Some(err);
|
440 + | while let Some(err) = next {
|
441 + | if let Some(matching_err) = err.downcast_ref::<E>() {
|
442 + | return Some(matching_err);
|
443 + | }
|
444 + | next = err.source();
|
445 + | }
|
446 + | None
|
447 + | }
|
448 + |
|
// TODO(https://github.com/awslabs/aws-sdk-rust/issues/1090): CacheKey must also include ptr equality to any
// runtime components that are used—sleep_impl as a base (unless we prohibit overriding sleep impl)
// If we decide to put a DnsResolver in RuntimeComponents, then we'll need to handle that as well.
/// Key for the per-client connector cache: a connector is reused only for identical
/// connect/read timeout settings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
}
|
457 + |
|
458 + | impl From<&HttpConnectorSettings> for CacheKey {
|
459 + | fn from(value: &HttpConnectorSettings) -> Self {
|
460 + | Self {
|
461 + | connect_timeout: value.connect_timeout(),
|
462 + | read_timeout: value.read_timeout(),
|
463 + | }
|
464 + | }
|
465 + | }
|
466 + |
|
467 + | struct HyperClient<F> {
|
468 + | connector_cache: RwLock<HashMap<CacheKey, SharedHttpConnector>>,
|
469 + | client_builder: hyper_util::client::legacy::Builder,
|
470 + | connector_fn: F,
|
471 + | }
|
472 + |
|
473 + | impl<F> fmt::Debug for HyperClient<F> {
|
474 + | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
475 + | f.debug_struct("HyperClient")
|
476 + | .field("connector_cache", &self.connector_cache)
|
477 + | .field("client_builder", &self.client_builder)
|
478 + | .finish()
|
479 + | }
|
480 + | }
|
481 + |
|
482 + | impl<F> HttpClient for HyperClient<F>
|
483 + | where
|
484 + | F: Fn(
|
485 + | hyper_util::client::legacy::Builder,
|
486 + | Option<&HttpConnectorSettings>,
|
487 + | Option<&RuntimeComponents>,
|
488 + | ) -> Connector
|
489 + | + Send
|
490 + | + Sync
|
491 + | + 'static,
|
492 + | {
|
493 + | fn http_connector(
|
494 + | &self,
|
495 + | settings: &HttpConnectorSettings,
|
496 + | components: &RuntimeComponents,
|
497 + | ) -> SharedHttpConnector {
|
498 + | let key = CacheKey::from(settings);
|
499 + | let mut connector = self.connector_cache.read().unwrap().get(&key).cloned();
|
500 + | if connector.is_none() {
|
501 + | let mut cache = self.connector_cache.write().unwrap();
|
502 + | // Short-circuit if another thread already wrote a connector to the cache for this key
|
503 + | if !cache.contains_key(&key) {
|
504 + | let start = components.time_source().map(|ts| ts.now());
|
505 + | let connector = (self.connector_fn)(
|
506 + | self.client_builder.clone(),
|
507 + | Some(settings),
|
508 + | Some(components),
|
509 + | );
|
510 + | let end = components.time_source().map(|ts| ts.now());
|
511 + | if let (Some(start), Some(end)) = (start, end) {
|
512 + | if let Ok(elapsed) = end.duration_since(start) {
|
513 + | tracing::debug!("new connector created in {:?}", elapsed);
|
514 + | }
|
515 + | }
|
516 + | let connector = SharedHttpConnector::new(connector);
|
517 + | cache.insert(key.clone(), connector);
|
518 + | }
|
519 + | connector = cache.get(&key).cloned();
|
520 + | }
|
521 + |
|
522 + | connector.expect("cache populated above")
|
523 + | }
|
524 + |
|
525 + | fn validate_base_client_config(
|
526 + | &self,
|
527 + | _: &RuntimeComponentsBuilder,
|
528 + | _: &ConfigBag,
|
529 + | ) -> Result<(), BoxError> {
|
530 + | // Initialize the TCP connector at this point so that native certs load
|
531 + | // at client initialization time instead of upon first request. We do it
|
532 + | // here rather than at construction so that it won't run if this is not
|
533 + | // the selected HTTP client for the base config (for example, if this was
|
534 + | // the default HTTP client, and it was overridden by a later plugin).
|
535 + | let _ = (self.connector_fn)(self.client_builder.clone(), None, None);
|
536 + | Ok(())
|
537 + | }
|
538 + |
|
539 + | fn connector_metadata(&self) -> Option<ConnectorMetadata> {
|
540 + | Some(ConnectorMetadata::new("hyper", Some(Cow::Borrowed("1.x"))))
|
541 + | }
|
542 + | }
|
543 + |
|
544 + | /// Builder for a hyper-backed [`HttpClient`] implementation.
|
545 + | ///
|
546 + | /// This builder can be used to customize the underlying TCP connector used, as well as
|
547 + | /// hyper client configuration.
|
548 + | ///
|
549 + | /// # Examples
|
550 + | ///
|
551 + | /// Construct a Hyper client with the RusTLS TLS implementation.
|
552 + | /// This can be useful when you want to share a Hyper connector between multiple
|
553 + | /// generated Smithy clients.
|
554 + | #[derive(Clone, Default, Debug)]
|
555 + | pub struct Builder<Tls = TlsUnset> {
|
556 + | client_builder: Option<hyper_util::client::legacy::Builder>,
|
557 + | #[allow(unused)]
|
558 + | tls_provider: Tls,
|
559 + | }
|
560 + |
|
cfg_tls! {
    // Provides `HyperUtilResolver`, which adapts a `ResolveDns` implementation for hyper
    // (used by `build_with_resolver` below).
    mod dns;
    use aws_smithy_runtime_api::client::dns::ResolveDns;

    impl ConnectorBuilder<TlsProviderSelected> {
        /// Build a [`Connector`] that will use the default DNS resolver implementation
        /// (hyper's `GaiResolver`).
        pub fn build(self) -> Connector {
            let http_connector = self.base_connector();
            self.build_https(http_connector)
        }

        /// Configure the [`TlsContext`] passed to the selected TLS provider when the connector
        /// is built.
        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
            self.tls.context = ctx;
            self
        }

        /// Configure the [`TlsContext`] passed to the selected TLS provider when the connector
        /// is built.
        pub fn set_tls_context(&mut self, ctx: TlsContext) -> &mut Self {
            self.tls.context = ctx;
            self
        }

        /// Build a [`Connector`] that will use the given DNS resolver implementation.
        pub fn build_with_resolver<R: ResolveDns + Clone + 'static>(self, resolver: R) -> Connector {
            use crate::client::dns::HyperUtilResolver;
            // Wrap the smithy resolver so hyper's TCP connector can drive it.
            let http_connector = self.base_connector_with_resolver(HyperUtilResolver { resolver });
            self.build_https(http_connector)
        }

        /// Layer the selected TLS provider on top of `http_connector`, then apply the configured
        /// timeouts via `wrap_connector`.
        fn build_https<R>(self, http_connector: HyperHttpConnector<R>) -> Connector
        where
            R: Clone + Send + Sync + 'static,
            R: tower::Service<hyper_util::client::legacy::connect::dns::Name>,
            R::Response: Iterator<Item = std::net::SocketAddr>,
            R::Future: Send,
            R::Error: Into<Box<dyn Error + Send + Sync>>,
        {
            // NOTE(review): every arm below is cfg-gated; presumably `tls::Provider`'s variants
            // are gated by the same features so this match stays exhaustive — confirm in tls.rs.
            match &self.tls.provider {
                // TODO(hyper1) - fix cfg_rustls! to allow matching on patterns so we can re-use it and not duplicate these cfg matches everywhere
                #[cfg(any(
                    feature = "rustls-aws-lc",
                    feature = "rustls-aws-lc-fips",
                    feature = "rustls-ring"
                ))]
                tls::Provider::Rustls(crypto_mode) => {
                    let https_connector = tls::rustls_provider::build_connector::wrap_connector(
                        http_connector,
                        crypto_mode.clone(),
                        &self.tls.context,
                    );
                    self.wrap_connector(https_connector)
                },
                #[cfg(feature = "s2n-tls")]
                tls::Provider::S2nTls => {
                    let https_connector = tls::s2n_tls_provider::build_connector::wrap_connector(http_connector, &self.tls.context);
                    self.wrap_connector(https_connector)
                }
            }
        }
    }

    impl Builder<TlsProviderSelected> {
        /// Create an HTTPS client with the selected TLS provider.
        ///
        /// The trusted certificates will be loaded later when this becomes the selected
        /// HTTP client for a Smithy client.
        pub fn build_https(self) -> SharedHttpClient {
            build_with_conn_fn(
                self.client_builder,
                // Invoked once per distinct connector configuration; the provider and context
                // are cloned so the closure can be called repeatedly.
                move |client_builder, settings, runtime_components| {
                    let builder = new_conn_builder(client_builder, settings, runtime_components)
                        .tls_provider(self.tls_provider.provider.clone())
                        .tls_context(self.tls_provider.context.clone());
                    builder.build()
                },
            )
        }

        /// Create an HTTPS client using a custom DNS resolver
        pub fn build_with_resolver(
            self,
            resolver: impl ResolveDns + Clone + 'static,
        ) -> SharedHttpClient {
            build_with_conn_fn(
                self.client_builder,
                move |client_builder, settings, runtime_components| {
                    let builder = new_conn_builder(client_builder, settings, runtime_components)
                        .tls_provider(self.tls_provider.provider.clone())
                        .tls_context(self.tls_provider.context.clone());
                    builder.build_with_resolver(resolver.clone())
                },
            )
        }

        /// Configure the [`TlsContext`] used by every connector this client builds.
        pub fn tls_context(mut self, ctx: TlsContext) -> Self {
            self.tls_provider.context = ctx;
            self
        }
    }
}
|
663 + |
|
664 + | impl Builder<TlsUnset> {
|
665 + | /// Creates a new builder.
|
666 + | pub fn new() -> Self {
|
667 + | Self::default()
|
668 + | }
|
669 + |
|
670 + | /// Build a new HTTP client without TLS enabled
|
671 + | #[doc(hidden)]
|
672 + | pub fn build_http(self) -> SharedHttpClient {
|
673 + | build_with_conn_fn(
|
674 + | self.client_builder,
|
675 + | move |client_builder, settings, runtime_components| {
|
676 + | let builder = new_conn_builder(client_builder, settings, runtime_components);
|
677 + | builder.build_http()
|
678 + | },
|
679 + | )
|
680 + | }
|
681 + |
|
682 + | /// Set the TLS implementation to use
|
683 + | pub fn tls_provider(self, provider: tls::Provider) -> Builder<TlsProviderSelected> {
|
684 + | Builder {
|
685 + | client_builder: self.client_builder,
|
686 + | tls_provider: TlsProviderSelected {
|
687 + | provider,
|
688 + | context: TlsContext::default(),
|
689 + | },
|
690 + | }
|
691 + | }
|
692 + | }
|
693 + |
|
694 + | pub(crate) fn build_with_conn_fn<F>(
|
695 + | client_builder: Option<hyper_util::client::legacy::Builder>,
|
696 + | connector_fn: F,
|
697 + | ) -> SharedHttpClient
|
698 + | where
|
699 + | F: Fn(
|
700 + | hyper_util::client::legacy::Builder,
|
701 + | Option<&HttpConnectorSettings>,
|
702 + | Option<&RuntimeComponents>,
|
703 + | ) -> Connector
|
704 + | + Send
|
705 + | + Sync
|
706 + | + 'static,
|
707 + | {
|
708 + | SharedHttpClient::new(HyperClient {
|
709 + | connector_cache: RwLock::new(HashMap::new()),
|
710 + | client_builder: client_builder
|
711 + | .unwrap_or_else(|| hyper_util::client::legacy::Builder::new(TokioExecutor::new())),
|
712 + | connector_fn,
|
713 + | })
|
714 + | }
|
715 + |
|
716 + | #[allow(dead_code)]
|
717 + | pub(crate) fn build_with_tcp_conn_fn<C, F>(
|
718 + | client_builder: Option<hyper_util::client::legacy::Builder>,
|
719 + | tcp_connector_fn: F,
|
720 + | ) -> SharedHttpClient
|
721 + | where
|
722 + | F: Fn() -> C + Send + Sync + 'static,
|
723 + | C: Clone + Send + Sync + 'static,
|
724 + | C: tower::Service<Uri>,
|
725 + | C::Response: Connection + Read + Write + Send + Sync + Unpin + 'static,
|
726 + | C::Future: Unpin + Send + 'static,
|
727 + | C::Error: Into<BoxError>,
|
728 + | C: Connect,
|
729 + | {
|
730 + | build_with_conn_fn(
|
731 + | client_builder,
|
732 + | move |client_builder, settings, runtime_components| {
|
733 + | let builder = new_conn_builder(client_builder, settings, runtime_components);
|
734 + | builder.wrap_connector(tcp_connector_fn())
|
735 + | },
|
736 + | )
|
737 + | }
|
738 + |
|
739 + | fn new_conn_builder(
|
740 + | client_builder: hyper_util::client::legacy::Builder,
|
741 + | settings: Option<&HttpConnectorSettings>,
|
742 + | runtime_components: Option<&RuntimeComponents>,
|
743 + | ) -> ConnectorBuilder {
|
744 + | let mut builder = Connector::builder().hyper_builder(client_builder);
|
745 + | builder.set_connector_settings(settings.cloned());
|
746 + | if let Some(components) = runtime_components {
|
747 + | builder.set_sleep_impl(components.sleep_impl());
|
748 + | }
|
749 + | builder
|
750 + | }
|
751 + |
|
#[cfg(test)]
mod test {
    use super::*;

    use crate::client::timeout::test::{NeverConnects, NeverReplies};
    use aws_smithy_async::assert_elapsed;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::time::SystemTimeSource;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use http_1x::Uri;
    use hyper::rt::ReadBufCursor;
    use hyper_util::client::legacy::connect::Connected;
    use std::io::{Error, ErrorKind};
    use std::pin::Pin;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll};

    /// Hammers `http_connector` from many concurrent tasks and checks that exactly one
    /// connector is built per distinct settings value.
    #[tokio::test]
    async fn connector_selection() {
        // The factory below runs once per connector the client builds, so counting its
        // invocations counts connector creations.
        let creation_count = Arc::new(AtomicU32::new(0));
        let factory_count = Arc::clone(&creation_count);
        let http_client = build_with_tcp_conn_fn(None, move || {
            factory_count.fetch_add(1, Ordering::Relaxed);
            NeverConnects
        });

        // Four mutually distinct timeout configurations -> four expected connectors.
        let settings = [
            HttpConnectorSettings::builder()
                .connect_timeout(Duration::from_secs(3))
                .build(),
            HttpConnectorSettings::builder()
                .read_timeout(Duration::from_secs(3))
                .build(),
            HttpConnectorSettings::builder()
                .connect_timeout(Duration::from_secs(3))
                .read_timeout(Duration::from_secs(3))
                .build(),
            HttpConnectorSettings::builder()
                .connect_timeout(Duration::from_secs(5))
                .read_timeout(Duration::from_secs(3))
                .build(),
        ];

        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();

        // Spawn 1000 tasks per settings value (4000 total), all racing to obtain a connector.
        let handles: Vec<_> = settings
            .iter()
            .flat_map(|setting| std::iter::repeat(setting).take(1000))
            .map(|setting| {
                let client = http_client.clone();
                let setting = setting.clone();
                let components = components.clone();
                tokio::spawn(async move {
                    let _ = client.http_connector(&setting, &components);
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        // Despite the contention, only one connector per configuration may have been made.
        assert_eq!(4, creation_count.load(Ordering::Relaxed));
    }

    /// A transport that resets on read must surface as an IO connector error.
    #[tokio::test]
    async fn hyper_io_error() {
        let adapter = Connector::builder()
            .wrap_connector(TestConnection {
                inner: HangupStream,
            })
            .adapter;
        let request = HttpRequest::get("https://socket-hangup.com").unwrap();
        let err = adapter.call(request).await.expect_err("socket hangup");
        assert!(err.is_io(), "unexpected error type: {:?}", err);
    }

    // ---- machinery to make a Hyper connector that responds with an IO Error

    /// Fake stream: every read fails immediately with `ConnectionReset`, while write, flush
    /// and shutdown return `Poll::Pending` (without touching the waker).
    #[derive(Clone)]
    struct HangupStream;

    impl Connection for HangupStream {
        fn connected(&self) -> Connected {
            Connected::new()
        }
    }

    impl Read for HangupStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: ReadBufCursor<'_>,
        ) -> Poll<std::io::Result<()>> {
            let reset = Error::new(ErrorKind::ConnectionReset, "connection reset");
            Poll::Ready(Err(reset))
        }
    }

    impl Write for HangupStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Minimal `tower::Service<Uri>` connector that is always ready and answers every URI
    /// with a clone of `inner`.
    #[derive(Clone)]
    struct TestConnection<T> {
        inner: T,
    }

    impl<T: Clone + Connection> tower::Service<Uri> for TestConnection<T> {
        type Response = T;
        type Error = BoxError;
        type Future = std::future::Ready<Result<T, BoxError>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Uri) -> Self::Future {
            let stream = self.inner.clone();
            std::future::ready(Ok(stream))
        }
    }

    /// A TCP connector that never connects must trip the configured 1s connect timeout.
    #[tokio::test]
    async fn http_connect_timeout_works() {
        let never_connects = NeverConnects::default();
        let adapter = Connector::builder()
            .connector_settings(
                HttpConnectorSettings::builder()
                    .connect_timeout(Duration::from_secs(1))
                    .build(),
            )
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .wrap_connector(never_connects)
            .adapter;

        // Capture the start instant, then freeze tokio's clock so the sleep auto-advances.
        let start = tokio::time::Instant::now();
        tokio::time::pause();
        let err = adapter
            .call(HttpRequest::get("https://static-uri.com").unwrap())
            .await
            .unwrap_err();
        assert!(
            err.is_timeout(),
            "expected resp.is_timeout() to be true but it was false, resp == {:?}",
            err
        );

        let expected = "timeout: client error (Connect): HTTP connect timeout occurred after 1s";
        let message = DisplayErrorContext(&err).to_string();
        assert!(
            message.contains(expected),
            "expected '{message}' to contain '{expected}'"
        );
        assert_elapsed!(start, Duration::from_secs(1));
    }

    /// A connection that never replies must trip the configured 2s read timeout.
    #[tokio::test]
    async fn http_read_timeout_works() {
        let never_replies = NeverReplies;
        let adapter = Connector::builder()
            .connector_settings(
                HttpConnectorSettings::builder()
                    .connect_timeout(Duration::from_secs(1))
                    .read_timeout(Duration::from_secs(2))
                    .build(),
            )
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .wrap_connector(never_replies)
            .adapter;

        // Capture the start instant, then freeze tokio's clock so the sleep auto-advances.
        let start = tokio::time::Instant::now();
        tokio::time::pause();
        let err = adapter
            .call(HttpRequest::get("https://fake-uri.com").unwrap())
            .await
            .unwrap_err();
        assert!(
            err.is_timeout(),
            "expected err.is_timeout() to be true but it was false, err == {err:?}",
        );

        let expected = "timeout: HTTP read timeout occurred after 2s";
        let message = DisplayErrorContext(&err).to_string();
        assert!(
            message.contains(expected),
            "expected '{message}' to contain '{expected}'"
        );
        assert_elapsed!(start, Duration::from_secs(2));
    }

    /// Selecting `tls::Provider::S2nTls` must route requests through s2n-tls-hyper.
    #[cfg(feature = "s2n-tls")]
    #[tokio::test]
    async fn s2n_tls_provider() {
        let client = Builder::new()
            .tls_provider(tls::Provider::S2nTls)
            .build_https();
        let connector_settings = HttpConnectorSettings::builder().build();

        // `http_connector` consults the time source to measure how long connector creation
        // takes, so a real clock has to be supplied.
        let runtime_components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(SystemTimeSource::new()))
            .build()
            .unwrap();
        let connector = client.http_connector(&connector_settings, &runtime_components);

        // s2n-tls-hyper rejects an unknown scheme with its own error type; finding that
        // error in the source chain proves s2n-tls (and no other provider) handled the call.
        let connector_error = connector
            .call(HttpRequest::get("notascheme://amazon.com").unwrap())
            .await
            .unwrap_err();
        let boxed = connector_error.into_source();
        let inner = boxed.source().unwrap();
        let s2n_error = inner
            .downcast_ref::<s2n_tls_hyper::error::Error>()
            .unwrap();
        assert!(matches!(
            s2n_error,
            s2n_tls_hyper::error::Error::InvalidScheme
        ));
    }
}
|