5 5 |
|
6 6 | //! Utilities for mocking at the socket level
|
7 7 | //!
|
8 8 | //! Other tools in this module actually operate at the `http::Request` / `http::Response` level. This
|
9 9 | //! is useful, but it shortcuts the HTTP implementation (e.g. Hyper). [`WireMockServer`] binds
|
10 10 | //! to an actual socket on the host.
|
11 11 | //!
|
12 12 | //! # Examples
|
13 13 | //! ```no_run
|
14 14 | //! use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
|
15 - | //! use aws_smithy_runtime::client::http::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
|
16 - | //! use aws_smithy_runtime::{match_events, ev};
|
15 + | //! use aws_smithy_http_client::test_util::wire::{check_matches, ReplayedEvent, WireMockServer};
|
16 + | //! use aws_smithy_http_client::{match_events, ev};
|
17 17 | //! # async fn example() {
|
18 18 | //!
|
19 19 | //! // This connection binds to a local address
|
20 20 | //! let mock = WireMockServer::start(vec![
|
21 21 | //! ReplayedEvent::status(503),
|
22 22 | //! ReplayedEvent::status(200)
|
23 23 | //! ]).await;
|
24 24 | //!
|
25 25 | //! # /*
|
26 26 | //! // Create a client using the wire mock
|
27 27 | //! let config = my_generated_client::Config::builder()
|
28 28 | //! .http_client(mock.http_client())
|
29 29 | //! .build();
|
30 30 | //! let client = Client::from_conf(config);
|
31 31 | //!
|
32 32 | //! // ... do something with <client>
|
33 33 | //! # */
|
34 34 | //!
|
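//! // assert that you got the events you expected. Every replayed response is recorded, so the
//! // 503 shows up before the 200. This sequence assumes the client retries the 503 on a new
//! // connection; if the connection is reused instead, drop the second `dns`/`connect`.
//! match_events!(ev!(dns), ev!(connect), ev!(http(503)), ev!(dns), ev!(connect), ev!(http(200)))(&mock.events());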
|
37 37 | //! # }
|
38 38 | //! ```
|
39 39 |
|
40 40 | #![allow(missing_docs)]
|
41 41 |
|
42 - | use crate::client::http::hyper_014::HyperClientBuilder;
|
43 42 | use aws_smithy_async::future::never::Never;
|
44 43 | use aws_smithy_async::future::BoxFuture;
|
45 44 | use aws_smithy_runtime_api::client::http::SharedHttpClient;
|
46 - | use aws_smithy_runtime_api::shared::IntoShared;
|
47 45 | use bytes::Bytes;
|
48 - | use hyper_0_14::client::connect::dns::Name;
|
49 - | use hyper_0_14::server::conn::AddrStream;
|
50 - | use hyper_0_14::service::{make_service_fn, service_fn, Service};
|
46 + | use http_body_util::Full;
|
47 + | use hyper::service::service_fn;
|
48 + | use hyper_util::client::legacy::connect::dns::Name;
|
49 + | use hyper_util::rt::{TokioExecutor, TokioIo};
|
50 + | use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
|
51 51 | use std::collections::HashSet;
|
52 52 | use std::convert::Infallible;
|
53 53 | use std::error::Error;
|
54 + | use std::future::Future;
|
54 55 | use std::iter::Once;
|
55 - | use std::net::{SocketAddr, TcpListener};
|
56 + | use std::net::SocketAddr;
|
56 57 | use std::sync::{Arc, Mutex};
|
57 58 | use std::task::{Context, Poll};
|
58 - | use tokio::spawn;
|
59 + | use tokio::net::TcpListener;
|
59 60 | use tokio::sync::oneshot;
|
60 61 |
|
61 62 | /// An event recorded by [`WireMockServer`].
|
62 63 | #[non_exhaustive]
|
63 64 | #[derive(Debug, Clone)]
|
64 65 | pub enum RecordedEvent {
|
65 66 | DnsLookup(String),
|
66 67 | NewConnection,
|
67 68 | Response(ReplayedEvent),
|
68 69 | }
|
69 70 |
|
70 71 | type Matcher = (
|
71 72 | Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
|
72 73 | &'static str,
|
73 74 | );
|
74 75 |
|
75 76 | /// This method should only be used by the macro
|
76 77 | pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
|
77 78 | let mut events_iter = events.iter();
|
78 79 | let mut matcher_iter = matchers.iter();
|
79 80 | let mut idx = -1;
|
80 81 | loop {
|
81 82 | idx += 1;
|
82 83 | let bail = |err: Box<dyn Error>| panic!("failed on event {}:\n {}", idx, err);
|
83 84 | match (events_iter.next(), matcher_iter.next()) {
|
84 85 | (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
|
85 86 | (None, None) => return,
|
86 87 | (Some(event), None) => {
|
87 88 | bail(format!("got {:?} but no more events were expected", event).into())
|
88 89 | }
|
89 90 | (None, Some((_expect, msg))) => {
|
90 91 | bail(format!("expected {:?} but no more events were expected", msg).into())
|
91 92 | }
|
92 93 | }
|
93 94 | }
|
94 95 | }
|
95 96 |
|
/// Builds one `(predicate, description)` matcher for `check_matches` from a pattern.
///
/// The predicate succeeds when the recorded event matches `$expect`; otherwise it returns an
/// error naming the pattern and the event that was actually seen.
#[macro_export]
macro_rules! matcher {
    ($expect:tt) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if matches!(event, $expect) {
                    Ok(())
                } else {
                    Err(format!("expected `{}` but got {:?}", stringify!($expect), event).into())
                }
            }),
            stringify!($expect),
        )
    };
}
|
117 113 |
|
/// Helper macro to generate a series of test expectations
///
/// Expands to a closure that takes the recorded events (e.g. `&mock.events()`) and hands them
/// to `check_matches`, with one matcher per pattern, in the order written. A trailing comma
/// after the last pattern is accepted.
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
|
127 123 |
|
/// Helper to generate match expressions for events
///
/// - `ev!(dns)` matches any `RecordedEvent::DnsLookup`
/// - `ev!(connect)` matches `RecordedEvent::NewConnection`
/// - `ev!(timeout)` matches a `Timeout` response
/// - `ev!(http(status))` matches an HTTP response with the given status, ignoring the body
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse { status: $status, .. },
        )
    };
}
|
151 147 |
|
152 148 | pub use {ev, match_events, matcher};
|
153 149 |
|
154 150 | #[non_exhaustive]
|
155 151 | #[derive(Clone, Debug, PartialEq, Eq)]
|
156 152 | pub enum ReplayedEvent {
|
157 153 | Timeout,
|
158 154 | HttpResponse { status: u16, body: Bytes },
|
159 155 | }
|
160 156 |
|
161 157 | impl ReplayedEvent {
|
162 158 | pub fn ok() -> Self {
|
163 159 | Self::HttpResponse {
|
164 160 | status: 200,
|
165 161 | body: Bytes::new(),
|
166 162 | }
|
167 163 | }
|
168 164 |
|
169 165 | pub fn with_body(body: impl AsRef<[u8]>) -> Self {
|
170 166 | Self::HttpResponse {
|
171 167 | status: 200,
|
172 168 | body: Bytes::copy_from_slice(body.as_ref()),
|
173 169 | }
|
174 170 | }
|
175 171 |
|
176 172 | pub fn status(status: u16) -> Self {
|
177 173 | Self::HttpResponse {
|
178 174 | status,
|
179 175 | body: Bytes::new(),
|
180 176 | }
|
181 177 | }
|
182 178 | }
|
183 179 |
|
184 180 | /// Test server that binds to 127.0.0.1:0
|
185 181 | ///
|
186 - | /// See the [module docs](crate::client::http::test_util::wire) for a usage example.
|
182 + | /// See the [module docs](crate::test_util::wire) for a usage example.
|
187 183 | ///
|
188 184 | /// Usage:
|
189 185 | /// - Call [`WireMockServer::start`] to start the server
|
190 186 | /// - Use [`WireMockServer::http_client`] or [`dns_resolver`](WireMockServer::dns_resolver) to configure your client.
|
191 187 | /// - Make requests to [`endpoint_url`](WireMockServer::endpoint_url).
|
192 188 | /// - Once the test is complete, retrieve a list of events from [`WireMockServer::events`]
|
193 189 | #[derive(Debug)]
|
194 190 | pub struct WireMockServer {
|
195 191 | event_log: Arc<Mutex<Vec<RecordedEvent>>>,
|
196 192 | bind_addr: SocketAddr,
|
197 193 | // when the sender is dropped, that stops the server
|
198 194 | shutdown_hook: oneshot::Sender<()>,
|
199 195 | }
|
200 196 |
|
197 + | #[derive(Debug, Clone)]
|
198 + | struct SharedGraceful {
|
199 + | graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
|
200 + | }
|
201 + |
|
202 + | impl SharedGraceful {
|
203 + | fn new() -> Self {
|
204 + | Self {
|
205 + | graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
|
206 + | }
|
207 + | }
|
208 + |
|
209 + | fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
|
210 + | let graceful = self.graceful.lock().unwrap();
|
211 + | graceful
|
212 + | .as_ref()
|
213 + | .expect("graceful not shutdown")
|
214 + | .watch(conn)
|
215 + | }
|
216 + |
|
217 + | async fn shutdown(&self) {
|
218 + | let graceful = { self.graceful.lock().unwrap().take() };
|
219 + |
|
220 + | if let Some(graceful) = graceful {
|
221 + | graceful.shutdown().await;
|
222 + | }
|
223 + | }
|
224 + | }
|
225 + |
|
201 226 | impl WireMockServer {
|
202 227 | /// Start a wire mock server with the given events to replay.
|
203 228 | pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
|
204 - | let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
205 - | let (tx, rx) = oneshot::channel();
|
229 + | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
230 + | let (tx, mut rx) = oneshot::channel();
|
206 231 | let listener_addr = listener.local_addr().unwrap();
|
207 232 | response_events.reverse();
|
208 233 | let response_events = Arc::new(Mutex::new(response_events));
|
209 234 | let handler_events = response_events;
|
210 235 | let wire_events = Arc::new(Mutex::new(vec![]));
|
211 236 | let wire_log_for_service = wire_events.clone();
|
212 237 | let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
|
213 - | let make_service = make_service_fn(move |connection: &AddrStream| {
|
238 + | let graceful = SharedGraceful::new();
|
239 + | let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
|
240 + | TokioExecutor::new(),
|
241 + | ));
|
242 + |
|
243 + | let server = async move {
|
214 244 | let poisoned_conns = poisoned_conns.clone();
|
215 245 | let events = handler_events.clone();
|
216 246 | let wire_log = wire_log_for_service.clone();
|
217 - | let remote_addr = connection.remote_addr();
|
218 - | tracing::info!("established connection: {:?}", connection);
|
219 - | wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
|
220 - | async move {
|
221 - | Ok::<_, Infallible>(service_fn(move |_: http_02x::Request<hyper_0_14::Body>| {
|
222 - | if poisoned_conns.lock().unwrap().contains(&remote_addr) {
|
223 - | tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
|
224 - | panic!("poisoned connection was reused!");
|
225 - | }
|
226 - | let next_event = events.clone().lock().unwrap().pop();
|
227 - | let wire_log = wire_log.clone();
|
228 - | let poisoned_conns = poisoned_conns.clone();
|
229 - | async move {
|
230 - | let next_event = next_event
|
231 - | .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
|
232 - | wire_log
|
233 - | .lock()
|
234 - | .unwrap()
|
235 - | .push(RecordedEvent::Response(next_event.clone()));
|
236 - | if next_event == ReplayedEvent::Timeout {
|
237 - | tracing::info!("{} is poisoned", remote_addr);
|
238 - | poisoned_conns.lock().unwrap().insert(remote_addr);
|
239 - | }
|
240 - | tracing::debug!("replying with {:?}", next_event);
|
241 - | let event = generate_response_event(next_event).await;
|
242 - | dbg!(event)
|
247 + | loop {
|
248 + | tokio::select! {
|
249 + | Ok((stream, remote_addr)) = listener.accept() => {
|
250 + | tracing::info!("established connection: {:?}", remote_addr);
|
251 + | let poisoned_conns = poisoned_conns.clone();
|
252 + | let events = events.clone();
|
253 + | let wire_log = wire_log.clone();
|
254 + | wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
|
255 + | let io = TokioIo::new(stream);
|
256 + |
|
257 + | let svc = service_fn(move |_req| {
|
258 + | let poisoned_conns = poisoned_conns.clone();
|
259 + | let events = events.clone();
|
260 + | let wire_log = wire_log.clone();
|
261 + | if poisoned_conns.lock().unwrap().contains(&remote_addr) {
|
262 + | tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
|
263 + | panic!("poisoned connection was reused!");
|
264 + | }
|
265 + | let next_event = events.clone().lock().unwrap().pop();
|
266 + | async move {
|
267 + | let next_event = next_event
|
268 + | .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
|
269 + |
|
270 + | wire_log
|
271 + | .lock()
|
272 + | .unwrap()
|
273 + | .push(RecordedEvent::Response(next_event.clone()));
|
274 + |
|
275 + | if next_event == ReplayedEvent::Timeout {
|
276 + | tracing::info!("{} is poisoned", remote_addr);
|
277 + | poisoned_conns.lock().unwrap().insert(remote_addr);
|
278 + | }
|
279 + | tracing::debug!("replying with {:?}", next_event);
|
280 + | let event = generate_response_event(next_event).await;
|
281 + | dbg!(event)
|
282 + | }
|
283 + | });
|
284 + |
|
285 + | let conn_builder = conn_builder.clone();
|
286 + | let graceful = graceful.clone();
|
287 + | tokio::spawn(async move {
|
288 + | let conn = conn_builder.serve_connection(io, svc);
|
289 + | let fut = graceful.watch(conn);
|
290 + | if let Err(e) = fut.await {
|
291 + | panic!("Error serving connection: {:?}", e);
|
292 + | }
|
293 + | });
|
294 + | },
|
295 + | _ = &mut rx => {
|
296 + | tracing::info!("wire server: shutdown signalled");
|
297 + | graceful.shutdown().await;
|
298 + | tracing::info!("wire server: shutdown complete!");
|
299 + | break;
|
243 300 | }
|
244 - | }))
|
301 + | }
|
245 302 | }
|
246 - | });
|
247 - | let server = hyper_0_14::Server::from_tcp(listener)
|
248 - | .unwrap()
|
249 - | .serve(make_service)
|
250 - | .with_graceful_shutdown(async {
|
251 - | rx.await.ok();
|
252 - | tracing::info!("server shutdown!");
|
253 - | });
|
254 - | spawn(server);
|
303 + | };
|
304 + |
|
305 + | tokio::spawn(server);
|
255 306 | Self {
|
256 307 | event_log: wire_events,
|
257 308 | bind_addr: listener_addr,
|
258 309 | shutdown_hook: tx,
|
259 310 | }
|
260 311 | }
|
261 312 |
|
262 313 | /// Retrieve the events recorded by this connection
|
263 314 | pub fn events(&self) -> Vec<RecordedEvent> {
|
264 315 | self.event_log.lock().unwrap().clone()
|
265 316 | }
|
266 317 |
|
267 318 | fn bind_addr(&self) -> SocketAddr {
|
268 319 | self.bind_addr
|
269 320 | }
|
270 321 |
|
271 322 | pub fn dns_resolver(&self) -> LoggingDnsResolver {
|
272 323 | let event_log = self.event_log.clone();
|
273 324 | let bind_addr = self.bind_addr;
|
274 - | LoggingDnsResolver {
|
325 + | LoggingDnsResolver(InnerDnsResolver {
|
275 326 | log: event_log,
|
276 327 | socket_addr: bind_addr,
|
277 - | }
|
328 + | })
|
278 329 | }
|
279 330 |
|
280 331 | /// Prebuilt [`HttpClient`](aws_smithy_runtime_api::client::http::HttpClient) with correctly wired DNS resolver.
|
281 332 | ///
|
282 333 | /// **Note**: This must be used in tandem with [`Self::dns_resolver`]
|
283 334 | pub fn http_client(&self) -> SharedHttpClient {
|
284 - | HyperClientBuilder::new()
|
285 - | .build(hyper_0_14::client::HttpConnector::new_with_resolver(
|
286 - | self.dns_resolver(),
|
287 - | ))
|
288 - | .into_shared()
|
335 + | let resolver = self.dns_resolver();
|
336 + | crate::client::build_with_tcp_conn_fn(None, move || {
|
337 + | hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
|
338 + | resolver.clone().0,
|
339 + | )
|
340 + | })
|
289 341 | }
|
290 342 |
|
291 343 | /// Endpoint to use when connecting
|
292 344 | ///
|
293 345 | /// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address
|
294 346 | pub fn endpoint_url(&self) -> String {
|
295 347 | format!(
|
296 348 | "http://this-url-is-converted-to-localhost.com:{}",
|
297 349 | self.bind_addr().port()
|
298 350 | )
|
299 351 | }
|
300 352 |
|
301 353 | /// Shuts down the mock server.
|
302 354 | pub fn shutdown(self) {
|
303 355 | let _ = self.shutdown_hook.send(());
|
304 356 | }
|
305 357 | }
|
306 358 |
|
307 359 | async fn generate_response_event(
|
308 360 | event: ReplayedEvent,
|
309 - | ) -> Result<http_02x::Response<hyper_0_14::Body>, Infallible> {
|
361 + | ) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
|
310 362 | let resp = match event {
|
311 - | ReplayedEvent::HttpResponse { status, body } => http_02x::Response::builder()
|
363 + | ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
|
312 364 | .status(status)
|
313 - | .body(hyper_0_14::Body::from(body))
|
365 + | .body(Full::new(body))
|
314 366 | .unwrap(),
|
315 367 | ReplayedEvent::Timeout => {
|
316 368 | Never::new().await;
|
317 369 | unreachable!()
|
318 370 | }
|
319 371 | };
|
320 372 | Ok::<_, Infallible>(resp)
|
321 373 | }
|
322 374 |
|
323 375 | /// DNS resolver that keeps a log of all lookups
|
324 376 | ///
|
325 377 | /// Regardless of what hostname is requested, it will always return the same socket address.
|
326 378 | #[derive(Clone, Debug)]
|
327 - | pub struct LoggingDnsResolver {
|
379 + | pub struct LoggingDnsResolver(InnerDnsResolver);
|
380 + |
|
381 + | // internal implementation so we don't have to expose hyper_util
|
382 + | #[derive(Clone, Debug)]
|
383 + | struct InnerDnsResolver {
|
328 384 | log: Arc<Mutex<Vec<RecordedEvent>>>,
|
329 385 | socket_addr: SocketAddr,
|
330 386 | }
|
331 387 |
|
332 - | impl Service<Name> for LoggingDnsResolver {
|
388 + | impl tower::Service<Name> for InnerDnsResolver {
|
333 389 | type Response = Once<SocketAddr>;
|
334 390 | type Error = Infallible;
|
335 391 | type Future = BoxFuture<'static, Self::Response, Self::Error>;
|
336 392 |
|
337 393 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
338 394 | Poll::Ready(Ok(()))
|
339 395 | }
|
340 396 |
|
341 397 | fn call(&mut self, req: Name) -> Self::Future {
|
342 398 | let socket_addr = self.socket_addr;
|
343 399 | let log = self.log.clone();
|
344 400 | Box::pin(async move {
|
345 401 | println!("looking up {:?}, replying with {:?}", req, socket_addr);
|
346 402 | log.lock()
|
347 403 | .unwrap()
|
348 404 | .push(RecordedEvent::DnsLookup(req.to_string()));
|
349 405 | Ok(std::iter::once(socket_addr))
|
350 406 | })
|
351 407 | }
|
352 408 | }
|
409 + |
|
// Lets `LoggingDnsResolver` plug into hyper 0.14's connector by converting its `Name` into the
// hyper-util `Name` (via the string form) and delegating to the inner resolver.
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = &mut self.0;
        inner.poll_ready(cx)
    }

    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        let converted: Name = req.as_str().parse().expect("valid conversion");
        self.0.call(converted)
    }
}
|