aws_smithy_http_client/test_util/
wire.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Utilities for mocking at the socket level
7//!
//! Other tools in [`test_util`](crate::test_util) operate at the `http::Request` /
//! `http::Response` level. That is useful, but it bypasses the HTTP implementation (e.g. Hyper).
//! [`WireMockServer`] instead binds to an actual socket on the host, so requests go through the
//! real HTTP stack.
11//!
12//! # Examples
13//! ```no_run
14//! use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
15//! use aws_smithy_http_client::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
16//! use aws_smithy_http_client::{match_events, ev};
17//! # async fn example() {
18//!
19//! // This connection binds to a local address
20//! let mock = WireMockServer::start(vec![
21//!     ReplayedEvent::status(503),
22//!     ReplayedEvent::status(200)
23//! ]).await;
24//!
25//! # /*
26//! // Create a client using the wire mock
27//! let config = my_generated_client::Config::builder()
28//!     .http_client(mock.http_client())
29//!     .build();
30//! let client = Client::from_conf(config);
31//!
32//! // ... do something with <client>
33//! # */
34//!
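//! // Assert that you got the events you expected. The exact sequence depends on the client's
//! // retry/reconnect behavior; this one assumes the retry after the 503 opened a new connection.
//! match_events!(
//!     ev!(dns),
//!     ev!(connect),
//!     ev!(http(503)),
//!     ev!(dns),
//!     ev!(connect),
//!     ev!(http(200))
//! )(&mock.events());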
37//! # }
38//! ```
39
40#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62/// An event recorded by [`WireMockServer`].
63#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66    DnsLookup(String),
67    NewConnection,
68    Response(ReplayedEvent),
69}
70
71type Matcher = (
72    Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73    &'static str,
74);
75
76/// This method should only be used by the macro
77pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78    let mut events_iter = events.iter();
79    let mut matcher_iter = matchers.iter();
80    let mut idx = -1;
81    loop {
82        idx += 1;
83        let bail = |err: Box<dyn Error>| {
84            panic!(
85                "failed on event {}:\n  {}\n  actual recorded events: {:?}",
86                idx, err, events
87            )
88        };
89        match (events_iter.next(), matcher_iter.next()) {
90            (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
91            (None, None) => return,
92            (Some(event), None) => {
93                bail(format!("got {:?} but no more events were expected", event).into())
94            }
95            (None, Some((_expect, msg))) => {
96                bail(format!("expected {:?} but no more events were expected", msg).into())
97            }
98        }
99    }
100}
101
/// Builds one `(predicate, description)` entry for `check_matches` from a single pattern
/// token tree.
///
/// The predicate accepts a `RecordedEvent` that matches `$expect` and otherwise returns an
/// error naming both the pattern and the actual event. Intended to be invoked by
/// `match_events!`, not directly.
#[macro_export]
macro_rules! matcher {
    ($expect:tt) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if matches!(event, $expect) {
                    Ok(())
                } else {
                    Err(format!("expected `{}` but got {:?}", stringify!($expect), event).into())
                }
            }),
            stringify!($expect),
        )
    };
}
118
/// Helper macro to generate a series of test expectations.
///
/// Expands to a closure taking a slice of recorded events. Each argument is a pattern
/// (typically produced by `ev!`) that must match the event at the same position, and the
/// number of events must equal the number of patterns. A trailing comma is accepted.
///
/// ```ignore
/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&server.events());
/// ```
///
/// # Panics
///
/// The returned closure panics (via `check_matches`) on the first mismatch.
#[macro_export]
macro_rules! match_events {
    ($($expect:pat),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$($crate::matcher!($expect)),*]);
        }
    };
}
128
/// Helper to generate match patterns for recorded events, for use with `match_events!`.
///
/// Supported forms:
/// - `ev!(dns)`: any DNS lookup, whatever the hostname
/// - `ev!(connect)`: a newly accepted connection
/// - `ev!(http(STATUS))`: an HTTP response with status `STATUS` (the body is ignored)
/// - `ev!(timeout)`: a replayed timeout
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse { status: $status, .. },
        )
    };
}
152
// `#[macro_export]` places these macros at the crate root; re-exporting them here also makes
// them importable from this module's path.
pub use {ev, match_events, matcher};
154
155#[non_exhaustive]
156#[derive(Clone, Debug, PartialEq, Eq)]
157pub enum ReplayedEvent {
158    Timeout,
159    HttpResponse { status: u16, body: Bytes },
160}
161
162impl ReplayedEvent {
163    pub fn ok() -> Self {
164        Self::HttpResponse {
165            status: 200,
166            body: Bytes::new(),
167        }
168    }
169
170    pub fn with_body(body: impl AsRef<[u8]>) -> Self {
171        Self::HttpResponse {
172            status: 200,
173            body: Bytes::copy_from_slice(body.as_ref()),
174        }
175    }
176
177    pub fn status(status: u16) -> Self {
178        Self::HttpResponse {
179            status,
180            body: Bytes::new(),
181        }
182    }
183}
184
/// Test server that binds to `127.0.0.1:0` (an OS-assigned port on the IPv4 loopback
/// interface) and replays a scripted list of [`ReplayedEvent`]s.
///
/// See the [module docs](crate::test_util::wire) for a usage example.
///
/// Usage:
/// - Call [`WireMockServer::start`] to start the server
/// - Use [`WireMockServer::http_client`] or [`dns_resolver`](WireMockServer::dns_resolver) to configure your client.
/// - Make requests to [`endpoint_url`](WireMockServer::endpoint_url).
/// - Once the test is complete, retrieve a list of events from [`WireMockServer::events`]
///
/// The server runs on a spawned task until [`WireMockServer::shutdown`] is called or this value
/// is dropped.
#[derive(Debug)]
pub struct WireMockServer {
    // Shared with the server task and with every resolver from `dns_resolver`; all recorded
    // events are appended here.
    event_log: Arc<Mutex<Vec<RecordedEvent>>>,
    // Address the listener actually bound to, including the OS-assigned port.
    bind_addr: SocketAddr,
    // when the sender is dropped, that stops the server
    shutdown_hook: oneshot::Sender<()>,
}
201
202#[derive(Debug, Clone)]
203struct SharedGraceful {
204    graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
205}
206
207impl SharedGraceful {
208    fn new() -> Self {
209        Self {
210            graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
211        }
212    }
213
214    fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
215        let graceful = self.graceful.lock().unwrap();
216        graceful
217            .as_ref()
218            .expect("graceful not shutdown")
219            .watch(conn)
220    }
221
222    async fn shutdown(&self) {
223        let graceful = { self.graceful.lock().unwrap().take() };
224
225        if let Some(graceful) = graceful {
226            graceful.shutdown().await;
227        }
228    }
229}
230
231impl WireMockServer {
232    /// Start a wire mock server with the given events to replay.
233    pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
234        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
235        let (tx, mut rx) = oneshot::channel();
236        let listener_addr = listener.local_addr().unwrap();
237        response_events.reverse();
238        let response_events = Arc::new(Mutex::new(response_events));
239        let handler_events = response_events;
240        let wire_events = Arc::new(Mutex::new(vec![]));
241        let wire_log_for_service = wire_events.clone();
242        let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
243        let graceful = SharedGraceful::new();
244        let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
245            TokioExecutor::new(),
246        ));
247
248        let server = async move {
249            let poisoned_conns = poisoned_conns.clone();
250            let events = handler_events.clone();
251            let wire_log = wire_log_for_service.clone();
252            loop {
253                tokio::select! {
254                    Ok((stream, remote_addr)) = listener.accept() => {
255                        tracing::info!("established connection: {:?}", remote_addr);
256                        let poisoned_conns = poisoned_conns.clone();
257                        let events = events.clone();
258                        let wire_log = wire_log.clone();
259                        wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
260                        let io = TokioIo::new(stream);
261
262                        let svc = service_fn(move |_req| {
263                            let poisoned_conns = poisoned_conns.clone();
264                            let events = events.clone();
265                            let wire_log = wire_log.clone();
266                            if poisoned_conns.lock().unwrap().contains(&remote_addr) {
267                                tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
268                                panic!("poisoned connection was reused!");
269                            }
270                            let next_event = events.clone().lock().unwrap().pop();
271                            async move {
272                                let next_event = next_event
273                                    .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
274
275                                wire_log
276                                    .lock()
277                                    .unwrap()
278                                    .push(RecordedEvent::Response(next_event.clone()));
279
280                                if next_event == ReplayedEvent::Timeout {
281                                    tracing::info!("{} is poisoned", remote_addr);
282                                    poisoned_conns.lock().unwrap().insert(remote_addr);
283                                }
284                                tracing::debug!("replying with {:?}", next_event);
285                                let event = generate_response_event(next_event).await;
286                                dbg!(event)
287                            }
288                        });
289
290                        let conn_builder = conn_builder.clone();
291                        let graceful = graceful.clone();
292                        tokio::spawn(async move {
293                            let conn = conn_builder.serve_connection(io, svc);
294                            let fut = graceful.watch(conn);
295                            if let Err(e) = fut.await {
296                                panic!("Error serving connection: {:?}", e);
297                            }
298                        });
299                    },
300                    _ = &mut rx => {
301                        tracing::info!("wire server: shutdown signalled");
302                        graceful.shutdown().await;
303                        tracing::info!("wire server: shutdown complete!");
304                        break;
305                    }
306                }
307            }
308        };
309
310        tokio::spawn(server);
311        Self {
312            event_log: wire_events,
313            bind_addr: listener_addr,
314            shutdown_hook: tx,
315        }
316    }
317
318    /// Retrieve the events recorded by this connection
319    pub fn events(&self) -> Vec<RecordedEvent> {
320        self.event_log.lock().unwrap().clone()
321    }
322
323    fn bind_addr(&self) -> SocketAddr {
324        self.bind_addr
325    }
326
327    pub fn dns_resolver(&self) -> LoggingDnsResolver {
328        let event_log = self.event_log.clone();
329        let bind_addr = self.bind_addr;
330        LoggingDnsResolver(InnerDnsResolver {
331            log: event_log,
332            socket_addr: bind_addr,
333        })
334    }
335
336    /// Prebuilt [`HttpClient`](aws_smithy_runtime_api::client::http::HttpClient) with correctly wired DNS resolver.
337    ///
338    /// **Note**: This must be used in tandem with [`Self::dns_resolver`]
339    pub fn http_client(&self) -> SharedHttpClient {
340        let resolver = self.dns_resolver();
341        crate::client::build_with_tcp_conn_fn(None, move || {
342            hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
343                resolver.clone().0,
344            )
345        })
346    }
347
348    /// Endpoint to use when connecting
349    ///
350    /// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address
351    pub fn endpoint_url(&self) -> String {
352        format!(
353            "http://this-url-is-converted-to-localhost.com:{}",
354            self.bind_addr().port()
355        )
356    }
357
358    /// Shuts down the mock server.
359    pub fn shutdown(self) {
360        let _ = self.shutdown_hook.send(());
361    }
362}
363
364async fn generate_response_event(
365    event: ReplayedEvent,
366) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
367    let resp = match event {
368        ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
369            .status(status)
370            .body(Full::new(body))
371            .unwrap(),
372        ReplayedEvent::Timeout => {
373            Never::new().await;
374            unreachable!()
375        }
376    };
377    Ok::<_, Infallible>(resp)
378}
379
/// DNS resolver that keeps a log of all lookups
///
/// Regardless of what hostname is requested, it will always return the same socket address.
///
/// Obtained from [`WireMockServer::dns_resolver`]. Each lookup is appended to that server's
/// event log as a [`RecordedEvent::DnsLookup`], and the returned address is the server's bind
/// address.
#[derive(Clone, Debug)]
pub struct LoggingDnsResolver(InnerDnsResolver);
385
// internal implementation so we don't have to expose hyper_util
//
// Its `tower::Service<Name>` impl answers every lookup with `socket_addr` and appends a
// `RecordedEvent::DnsLookup` with the requested name to `log`.
#[derive(Clone, Debug)]
struct InnerDnsResolver {
    // Event log shared with the `WireMockServer` that created this resolver.
    log: Arc<Mutex<Vec<RecordedEvent>>>,
    // Address returned for every lookup.
    socket_addr: SocketAddr,
}
392
393impl tower::Service<Name> for InnerDnsResolver {
394    type Response = Once<SocketAddr>;
395    type Error = Infallible;
396    type Future = BoxFuture<'static, Self::Response, Self::Error>;
397
398    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
399        Poll::Ready(Ok(()))
400    }
401
402    fn call(&mut self, req: Name) -> Self::Future {
403        let socket_addr = self.socket_addr;
404        let log = self.log.clone();
405        Box::pin(async move {
406            println!("looking up {:?}, replying with {:?}", req, socket_addr);
407            log.lock()
408                .unwrap()
409                .push(RecordedEvent::DnsLookup(req.to_string()));
410            Ok(std::iter::once(socket_addr))
411        })
412    }
413}
414
/// Lets [`LoggingDnsResolver`] act as a hyper 0.14 DNS resolver by delegating to the
/// hyper-util based inner resolver.
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Fully-qualified call: no `Service` trait is imported into this module, so
        // method-call syntax on the inner resolver would not resolve.
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        use std::str::FromStr;
        // Re-parse hyper 0.14's `Name` into the hyper-util `Name` the inner resolver expects.
        let adapter = Name::from_str(req.as_str()).expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, adapter)
    }
}