aws_smithy_http_client/test_util/
wire.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Utilities for mocking at the socket level
7//!
8//! Other tools in this module actually operate at the `http::Request` / `http::Response` level. This
9//! is useful, but it shortcuts the HTTP implementation (e.g. Hyper). [`WireMockServer`] binds
10//! to an actual socket on the host.
11//!
12//! # Examples
13//! ```no_run
14//! use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
15//! use aws_smithy_http_client::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
16//! use aws_smithy_http_client::{match_events, ev};
17//! # async fn example() {
18//!
19//! // This connection binds to a local address
20//! let mock = WireMockServer::start(vec![
21//!     ReplayedEvent::status(503),
22//!     ReplayedEvent::status(200)
23//! ]).await;
24//!
25//! # /*
26//! // Create a client using the wire mock
27//! let config = my_generated_client::Config::builder()
28//!     .http_client(mock.http_client())
29//!     .build();
30//! let client = Client::from_conf(config);
31//!
32//! // ... do something with <client>
33//! # */
34//!
35//! // assert that you got the events you expected
36//! match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());
37//! # }
38//! ```
39
40#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62/// An event recorded by [`WireMockServer`].
63#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66    DnsLookup(String),
67    NewConnection,
68    Response(ReplayedEvent),
69}
70
71type Matcher = (
72    Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73    &'static str,
74);
75
76/// This method should only be used by the macro
77pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78    let mut events_iter = events.iter();
79    let mut matcher_iter = matchers.iter();
80    let mut idx = -1;
81    loop {
82        idx += 1;
83        let bail = |err: Box<dyn Error>| {
84            panic!("failed on event {idx}:\n  {err}\n  actual recorded events: {events:?}")
85        };
86        match (events_iter.next(), matcher_iter.next()) {
87            (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
88            (None, None) => return,
89            (Some(event), None) => {
90                bail(format!("got {event:?} but no more events were expected").into())
91            }
92            (None, Some((_expect, msg))) => {
93                bail(format!("expected {msg:?} but no more events were expected").into())
94            }
95        }
96    }
97}
98
/// Builds a `(predicate, pattern text)` pair from a pattern over `RecordedEvent`.
///
/// The predicate succeeds when the event matches the pattern and otherwise returns an error
/// naming both the expected pattern and the event that was actually seen.
#[macro_export]
macro_rules! matcher {
    ($expect:tt) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if matches!(event, $expect) {
                    Ok(())
                } else {
                    Err(format!("expected `{}` but got {:?}", stringify!($expect), event).into())
                }
            }),
            stringify!($expect),
        )
    };
}
115
/// Helper macro to generate a series of test expectations
///
/// Takes a comma-separated list of patterns (a trailing comma is accepted) and expands to a
/// closure. Calling the closure with the recorded events panics unless each event matches the
/// pattern at the same position and the two lists have the same length.
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
125
/// Helper to generate match expressions for events
///
/// Supported forms:
/// - `ev!(dns)`: any DNS lookup
/// - `ev!(connect)`: a new connection
/// - `ev!(http(status))`: a replayed HTTP response with the given status (body is ignored)
/// - `ev!(timeout)`: a replayed timeout
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse {
                status: $status,
                ..
            },
        )
    };
}
149
150pub use {ev, match_events, matcher};
151
152#[non_exhaustive]
153#[derive(Clone, Debug, PartialEq, Eq)]
154pub enum ReplayedEvent {
155    Timeout,
156    HttpResponse { status: u16, body: Bytes },
157}
158
159impl ReplayedEvent {
160    pub fn ok() -> Self {
161        Self::HttpResponse {
162            status: 200,
163            body: Bytes::new(),
164        }
165    }
166
167    pub fn with_body(body: impl AsRef<[u8]>) -> Self {
168        Self::HttpResponse {
169            status: 200,
170            body: Bytes::copy_from_slice(body.as_ref()),
171        }
172    }
173
174    pub fn status(status: u16) -> Self {
175        Self::HttpResponse {
176            status,
177            body: Bytes::new(),
178        }
179    }
180}
181
182/// Test server that binds to 127.0.0.1:0
183///
184/// See the [module docs](crate::test_util::wire) for a usage example.
185///
186/// Usage:
187/// - Call [`WireMockServer::start`] to start the server
188/// - Use [`WireMockServer::http_client`] or [`dns_resolver`](WireMockServer::dns_resolver) to configure your client.
189/// - Make requests to [`endpoint_url`](WireMockServer::endpoint_url).
190/// - Once the test is complete, retrieve a list of events from [`WireMockServer::events`]
191#[derive(Debug)]
192pub struct WireMockServer {
193    event_log: Arc<Mutex<Vec<RecordedEvent>>>,
194    bind_addr: SocketAddr,
195    // when the sender is dropped, that stops the server
196    shutdown_hook: oneshot::Sender<()>,
197}
198
199#[derive(Debug, Clone)]
200struct SharedGraceful {
201    graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
202}
203
204impl SharedGraceful {
205    fn new() -> Self {
206        Self {
207            graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
208        }
209    }
210
211    fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
212        let graceful = self.graceful.lock().unwrap();
213        graceful
214            .as_ref()
215            .expect("graceful not shutdown")
216            .watch(conn)
217    }
218
219    async fn shutdown(&self) {
220        let graceful = { self.graceful.lock().unwrap().take() };
221
222        if let Some(graceful) = graceful {
223            graceful.shutdown().await;
224        }
225    }
226}
227
228impl WireMockServer {
229    /// Start a wire mock server with the given events to replay.
230    pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
231        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
232        let (tx, mut rx) = oneshot::channel();
233        let listener_addr = listener.local_addr().unwrap();
234        response_events.reverse();
235        let response_events = Arc::new(Mutex::new(response_events));
236        let handler_events = response_events;
237        let wire_events = Arc::new(Mutex::new(vec![]));
238        let wire_log_for_service = wire_events.clone();
239        let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
240        let graceful = SharedGraceful::new();
241        let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
242            TokioExecutor::new(),
243        ));
244
245        let server = async move {
246            let poisoned_conns = poisoned_conns.clone();
247            let events = handler_events.clone();
248            let wire_log = wire_log_for_service.clone();
249            loop {
250                tokio::select! {
251                    Ok((stream, remote_addr)) = listener.accept() => {
252                        tracing::info!("established connection: {:?}", remote_addr);
253                        let poisoned_conns = poisoned_conns.clone();
254                        let events = events.clone();
255                        let wire_log = wire_log.clone();
256                        wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
257                        let io = TokioIo::new(stream);
258
259                        let svc = service_fn(move |_req| {
260                            let poisoned_conns = poisoned_conns.clone();
261                            let events = events.clone();
262                            let wire_log = wire_log.clone();
263                            if poisoned_conns.lock().unwrap().contains(&remote_addr) {
264                                tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
265                                panic!("poisoned connection was reused!");
266                            }
267                            let next_event = events.clone().lock().unwrap().pop();
268                            async move {
269                                let next_event = next_event
270                                    .unwrap_or_else(|| panic!("no more events! Log: {wire_log:?}"));
271
272                                wire_log
273                                    .lock()
274                                    .unwrap()
275                                    .push(RecordedEvent::Response(next_event.clone()));
276
277                                if next_event == ReplayedEvent::Timeout {
278                                    tracing::info!("{} is poisoned", remote_addr);
279                                    poisoned_conns.lock().unwrap().insert(remote_addr);
280                                }
281                                tracing::debug!("replying with {:?}", next_event);
282                                let event = generate_response_event(next_event).await;
283                                dbg!(event)
284                            }
285                        });
286
287                        let conn_builder = conn_builder.clone();
288                        let graceful = graceful.clone();
289                        tokio::spawn(async move {
290                            let conn = conn_builder.serve_connection(io, svc);
291                            let fut = graceful.watch(conn);
292                            if let Err(e) = fut.await {
293                                panic!("Error serving connection: {e:?}");
294                            }
295                        });
296                    },
297                    _ = &mut rx => {
298                        tracing::info!("wire server: shutdown signalled");
299                        graceful.shutdown().await;
300                        tracing::info!("wire server: shutdown complete!");
301                        break;
302                    }
303                }
304            }
305        };
306
307        tokio::spawn(server);
308        Self {
309            event_log: wire_events,
310            bind_addr: listener_addr,
311            shutdown_hook: tx,
312        }
313    }
314
315    /// Retrieve the events recorded by this connection
316    pub fn events(&self) -> Vec<RecordedEvent> {
317        self.event_log.lock().unwrap().clone()
318    }
319
320    fn bind_addr(&self) -> SocketAddr {
321        self.bind_addr
322    }
323
324    pub fn dns_resolver(&self) -> LoggingDnsResolver {
325        let event_log = self.event_log.clone();
326        let bind_addr = self.bind_addr;
327        LoggingDnsResolver(InnerDnsResolver {
328            log: event_log,
329            socket_addr: bind_addr,
330        })
331    }
332
333    /// Prebuilt [`HttpClient`](aws_smithy_runtime_api::client::http::HttpClient) with correctly wired DNS resolver.
334    ///
335    /// **Note**: This must be used in tandem with [`Self::dns_resolver`]
336    pub fn http_client(&self) -> SharedHttpClient {
337        let resolver = self.dns_resolver();
338        crate::client::build_with_tcp_conn_fn(None, None, move || {
339            hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
340                resolver.clone().0,
341            )
342        })
343    }
344
345    /// Endpoint to use when connecting
346    ///
347    /// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address
348    pub fn endpoint_url(&self) -> String {
349        format!(
350            "http://this-url-is-converted-to-localhost.com:{}",
351            self.bind_addr().port()
352        )
353    }
354
355    /// Shuts down the mock server.
356    pub fn shutdown(self) {
357        let _ = self.shutdown_hook.send(());
358    }
359}
360
361async fn generate_response_event(
362    event: ReplayedEvent,
363) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
364    let resp = match event {
365        ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
366            .status(status)
367            .body(Full::new(body))
368            .unwrap(),
369        ReplayedEvent::Timeout => {
370            Never::new().await;
371            unreachable!()
372        }
373    };
374    Ok::<_, Infallible>(resp)
375}
376
377/// DNS resolver that keeps a log of all lookups
378///
379/// Regardless of what hostname is requested, it will always return the same socket address.
380#[derive(Clone, Debug)]
381pub struct LoggingDnsResolver(InnerDnsResolver);
382
383// internal implementation so we don't have to expose hyper_util
384#[derive(Clone, Debug)]
385struct InnerDnsResolver {
386    log: Arc<Mutex<Vec<RecordedEvent>>>,
387    socket_addr: SocketAddr,
388}
389
390impl tower::Service<Name> for InnerDnsResolver {
391    type Response = Once<SocketAddr>;
392    type Error = Infallible;
393    type Future = BoxFuture<'static, Self::Response, Self::Error>;
394
395    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
396        Poll::Ready(Ok(()))
397    }
398
399    fn call(&mut self, req: Name) -> Self::Future {
400        let socket_addr = self.socket_addr;
401        let log = self.log.clone();
402        Box::pin(async move {
403            println!("looking up {req:?}, replying with {socket_addr:?}");
404            log.lock()
405                .unwrap()
406                .push(RecordedEvent::DnsLookup(req.to_string()));
407            Ok(std::iter::once(socket_addr))
408        })
409    }
410}
411
/// Adapter so the resolver can also be used with hyper 0.14 clients; it forwards to the inner resolver.
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        // Convert the hyper 0.14 name into the hyper-util name type by re-parsing its text.
        let adapter = req.as_str().parse::<Name>().expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, adapter)
    }
}