aws_smithy_http_client/test_util/
wire.rs1#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66 DnsLookup(String),
67 NewConnection,
68 Response(ReplayedEvent),
69}
70
71type Matcher = (
72 Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73 &'static str,
74);
75
76pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78 let mut events_iter = events.iter();
79 let mut matcher_iter = matchers.iter();
80 let mut idx = -1;
81 loop {
82 idx += 1;
83 let bail = |err: Box<dyn Error>| {
84 panic!("failed on event {idx}:\n {err}\n actual recorded events: {events:?}")
85 };
86 match (events_iter.next(), matcher_iter.next()) {
87 (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
88 (None, None) => return,
89 (Some(event), None) => {
90 bail(format!("got {event:?} but no more events were expected").into())
91 }
92 (None, Some((_expect, msg))) => {
93 bail(format!("expected {msg:?} but no more events were expected").into())
94 }
95 }
96 }
97}
98
/// Builds a `(predicate, description)` pair for use with `check_matches`.
///
/// The predicate succeeds when the event matches the pattern `$expect` and
/// otherwise returns an error naming the expected pattern and the actual
/// event. The description is the stringified pattern.
///
/// The argument is parsed as a pattern, so multi-token patterns (for example a
/// path to a variant, or an `ev!(..)` invocation) can be passed directly.
#[macro_export]
macro_rules! matcher {
    ($expect:pat) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if !matches!(event, $expect) {
                    return Err(
                        format!("expected `{}` but got {:?}", stringify!($expect), event).into(),
                    );
                }
                Ok(())
            }),
            stringify!($expect),
        )
    };
}
115
/// Expands to a closure that checks a slice of recorded events against the
/// given patterns, in order, by calling `check_matches`.
///
/// Patterns are comma separated; a trailing comma is accepted.
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
125
/// Shorthand patterns for `RecordedEvent` values.
///
/// - `ev!(dns)`: any DNS lookup
/// - `ev!(connect)`: a new connection
/// - `ev!(timeout)`: a replayed timeout
/// - `ev!(http(status))`: a replayed HTTP response with the given status (the body is ignored)
#[macro_export]
macro_rules! ev {
    (dns) => {
        $crate::test_util::wire::RecordedEvent::DnsLookup(_)
    };
    (connect) => {
        $crate::test_util::wire::RecordedEvent::NewConnection
    };
    (timeout) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::Timeout,
        )
    };
    (http($status:expr)) => {
        $crate::test_util::wire::RecordedEvent::Response(
            $crate::test_util::wire::ReplayedEvent::HttpResponse {
                status: $status,
                ..
            },
        )
    };
}
149
150pub use {ev, match_events, matcher};
151
152#[non_exhaustive]
153#[derive(Clone, Debug, PartialEq, Eq)]
154pub enum ReplayedEvent {
155 Timeout,
156 HttpResponse { status: u16, body: Bytes },
157}
158
159impl ReplayedEvent {
160 pub fn ok() -> Self {
161 Self::HttpResponse {
162 status: 200,
163 body: Bytes::new(),
164 }
165 }
166
167 pub fn with_body(body: impl AsRef<[u8]>) -> Self {
168 Self::HttpResponse {
169 status: 200,
170 body: Bytes::copy_from_slice(body.as_ref()),
171 }
172 }
173
174 pub fn status(status: u16) -> Self {
175 Self::HttpResponse {
176 status,
177 body: Bytes::new(),
178 }
179 }
180}
181
182#[derive(Debug)]
192pub struct WireMockServer {
193 event_log: Arc<Mutex<Vec<RecordedEvent>>>,
194 bind_addr: SocketAddr,
195 shutdown_hook: oneshot::Sender<()>,
197}
198
199#[derive(Debug, Clone)]
200struct SharedGraceful {
201 graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
202}
203
204impl SharedGraceful {
205 fn new() -> Self {
206 Self {
207 graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
208 }
209 }
210
211 fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
212 let graceful = self.graceful.lock().unwrap();
213 graceful
214 .as_ref()
215 .expect("graceful not shutdown")
216 .watch(conn)
217 }
218
219 async fn shutdown(&self) {
220 let graceful = { self.graceful.lock().unwrap().take() };
221
222 if let Some(graceful) = graceful {
223 graceful.shutdown().await;
224 }
225 }
226}
227
228impl WireMockServer {
229 pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
231 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
232 let (tx, mut rx) = oneshot::channel();
233 let listener_addr = listener.local_addr().unwrap();
234 response_events.reverse();
235 let response_events = Arc::new(Mutex::new(response_events));
236 let handler_events = response_events;
237 let wire_events = Arc::new(Mutex::new(vec![]));
238 let wire_log_for_service = wire_events.clone();
239 let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
240 let graceful = SharedGraceful::new();
241 let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
242 TokioExecutor::new(),
243 ));
244
245 let server = async move {
246 let poisoned_conns = poisoned_conns.clone();
247 let events = handler_events.clone();
248 let wire_log = wire_log_for_service.clone();
249 loop {
250 tokio::select! {
251 Ok((stream, remote_addr)) = listener.accept() => {
252 tracing::info!("established connection: {:?}", remote_addr);
253 let poisoned_conns = poisoned_conns.clone();
254 let events = events.clone();
255 let wire_log = wire_log.clone();
256 wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
257 let io = TokioIo::new(stream);
258
259 let svc = service_fn(move |_req| {
260 let poisoned_conns = poisoned_conns.clone();
261 let events = events.clone();
262 let wire_log = wire_log.clone();
263 if poisoned_conns.lock().unwrap().contains(&remote_addr) {
264 tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
265 panic!("poisoned connection was reused!");
266 }
267 let next_event = events.clone().lock().unwrap().pop();
268 async move {
269 let next_event = next_event
270 .unwrap_or_else(|| panic!("no more events! Log: {wire_log:?}"));
271
272 wire_log
273 .lock()
274 .unwrap()
275 .push(RecordedEvent::Response(next_event.clone()));
276
277 if next_event == ReplayedEvent::Timeout {
278 tracing::info!("{} is poisoned", remote_addr);
279 poisoned_conns.lock().unwrap().insert(remote_addr);
280 }
281 tracing::debug!("replying with {:?}", next_event);
282 let event = generate_response_event(next_event).await;
283 dbg!(event)
284 }
285 });
286
287 let conn_builder = conn_builder.clone();
288 let graceful = graceful.clone();
289 tokio::spawn(async move {
290 let conn = conn_builder.serve_connection(io, svc);
291 let fut = graceful.watch(conn);
292 if let Err(e) = fut.await {
293 panic!("Error serving connection: {e:?}");
294 }
295 });
296 },
297 _ = &mut rx => {
298 tracing::info!("wire server: shutdown signalled");
299 graceful.shutdown().await;
300 tracing::info!("wire server: shutdown complete!");
301 break;
302 }
303 }
304 }
305 };
306
307 tokio::spawn(server);
308 Self {
309 event_log: wire_events,
310 bind_addr: listener_addr,
311 shutdown_hook: tx,
312 }
313 }
314
315 pub fn events(&self) -> Vec<RecordedEvent> {
317 self.event_log.lock().unwrap().clone()
318 }
319
320 fn bind_addr(&self) -> SocketAddr {
321 self.bind_addr
322 }
323
324 pub fn dns_resolver(&self) -> LoggingDnsResolver {
325 let event_log = self.event_log.clone();
326 let bind_addr = self.bind_addr;
327 LoggingDnsResolver(InnerDnsResolver {
328 log: event_log,
329 socket_addr: bind_addr,
330 })
331 }
332
333 pub fn http_client(&self) -> SharedHttpClient {
337 let resolver = self.dns_resolver();
338 crate::client::build_with_tcp_conn_fn(None, None, move || {
339 hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
340 resolver.clone().0,
341 )
342 })
343 }
344
345 pub fn endpoint_url(&self) -> String {
349 format!(
350 "http://this-url-is-converted-to-localhost.com:{}",
351 self.bind_addr().port()
352 )
353 }
354
355 pub fn shutdown(self) {
357 let _ = self.shutdown_hook.send(());
358 }
359}
360
361async fn generate_response_event(
362 event: ReplayedEvent,
363) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
364 let resp = match event {
365 ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
366 .status(status)
367 .body(Full::new(body))
368 .unwrap(),
369 ReplayedEvent::Timeout => {
370 Never::new().await;
371 unreachable!()
372 }
373 };
374 Ok::<_, Infallible>(resp)
375}
376
377#[derive(Clone, Debug)]
381pub struct LoggingDnsResolver(InnerDnsResolver);
382
383#[derive(Clone, Debug)]
385struct InnerDnsResolver {
386 log: Arc<Mutex<Vec<RecordedEvent>>>,
387 socket_addr: SocketAddr,
388}
389
390impl tower::Service<Name> for InnerDnsResolver {
391 type Response = Once<SocketAddr>;
392 type Error = Infallible;
393 type Future = BoxFuture<'static, Self::Response, Self::Error>;
394
395 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
396 Poll::Ready(Ok(()))
397 }
398
399 fn call(&mut self, req: Name) -> Self::Future {
400 let socket_addr = self.socket_addr;
401 let log = self.log.clone();
402 Box::pin(async move {
403 println!("looking up {req:?}, replying with {socket_addr:?}");
404 log.lock()
405 .unwrap()
406 .push(RecordedEvent::DnsLookup(req.to_string()));
407 Ok(std::iter::once(socket_addr))
408 })
409 }
410}
411
/// Adapter that lets the resolver be used with the hyper 0.14 `Name` type by
/// converting the name and delegating to the inner resolver.
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    /// Delegates readiness to the inner resolver.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    /// Re-parses the hyper 0.14 name as a `Name` and forwards the lookup.
    ///
    /// # Panics
    ///
    /// Panics if the name cannot be parsed into a `Name`.
    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        let adapter =
            <Name as std::str::FromStr>::from_str(req.as_str()).expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, adapter)
    }
}
427}