// aws_smithy_http_client/test_util/wire.rs
#![allow(missing_docs)]
41
42use aws_smithy_async::future::never::Never;
43use aws_smithy_async::future::BoxFuture;
44use aws_smithy_runtime_api::client::http::SharedHttpClient;
45use bytes::Bytes;
46use http_body_util::Full;
47use hyper::service::service_fn;
48use hyper_util::client::legacy::connect::dns::Name;
49use hyper_util::rt::{TokioExecutor, TokioIo};
50use hyper_util::server::graceful::{GracefulConnection, GracefulShutdown};
51use std::collections::HashSet;
52use std::convert::Infallible;
53use std::error::Error;
54use std::future::Future;
55use std::iter::Once;
56use std::net::SocketAddr;
57use std::sync::{Arc, Mutex};
58use std::task::{Context, Poll};
59use tokio::net::TcpListener;
60use tokio::sync::oneshot;
61
62#[non_exhaustive]
64#[derive(Debug, Clone)]
65pub enum RecordedEvent {
66 DnsLookup(String),
67 NewConnection,
68 Response(ReplayedEvent),
69}
70
71type Matcher = (
72 Box<dyn Fn(&RecordedEvent) -> Result<(), Box<dyn Error>>>,
73 &'static str,
74);
75
76pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) {
78 let mut events_iter = events.iter();
79 let mut matcher_iter = matchers.iter();
80 let mut idx = -1;
81 loop {
82 idx += 1;
83 let bail = |err: Box<dyn Error>| {
84 panic!(
85 "failed on event {}:\n {}\n actual recorded events: {:?}",
86 idx, err, events
87 )
88 };
89 match (events_iter.next(), matcher_iter.next()) {
90 (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail),
91 (None, None) => return,
92 (Some(event), None) => {
93 bail(format!("got {:?} but no more events were expected", event).into())
94 }
95 (None, Some((_expect, msg))) => {
96 bail(format!("expected {:?} but no more events were expected", msg).into())
97 }
98 }
99 }
100}
101
/// Builds a `(predicate, description)` matcher pair from a pattern over `RecordedEvent`,
/// for use with `check_matches`.
///
/// The predicate returns an error naming the pattern and the actual event when the event
/// does not match; the description is the stringified pattern.
///
/// The pattern is captured as `pat` (not a single token tree), so multi-token patterns
/// such as `matcher!(ev!(dns))` work when calling this macro directly.
#[macro_export]
macro_rules! matcher {
    ($expect:pat) => {
        (
            Box::new(|event: &$crate::test_util::wire::RecordedEvent| {
                if !matches!(event, $expect) {
                    return Err(
                        format!("expected `{}` but got {:?}", stringify!($expect), event).into(),
                    );
                }
                Ok(())
            }),
            stringify!($expect),
        )
    };
}
118
/// Builds a closure that asserts the recorded events it is given match the listed patterns
/// one-to-one and in order (via `check_matches`, which panics on a mismatch).
///
/// Patterns are typically built with `ev!`, e.g.
/// `match_events!(ev!(dns), ev!(connect), ev!(http(200)))`. A trailing comma is accepted.
#[macro_export]
macro_rules! match_events {
    ($( $expect:pat ),* $(,)?) => {
        |events| {
            $crate::test_util::wire::check_matches(events, &[$( $crate::matcher!($expect) ),*]);
        }
    };
}
128
129#[macro_export]
131macro_rules! ev {
132 (http($status:expr)) => {
133 $crate::test_util::wire::RecordedEvent::Response(
134 $crate::test_util::wire::ReplayedEvent::HttpResponse {
135 status: $status,
136 ..
137 },
138 )
139 };
140 (dns) => {
141 $crate::test_util::wire::RecordedEvent::DnsLookup(_)
142 };
143 (connect) => {
144 $crate::test_util::wire::RecordedEvent::NewConnection
145 };
146 (timeout) => {
147 $crate::test_util::wire::RecordedEvent::Response(
148 $crate::test_util::wire::ReplayedEvent::Timeout,
149 )
150 };
151}
152
153pub use {ev, match_events, matcher};
154
155#[non_exhaustive]
156#[derive(Clone, Debug, PartialEq, Eq)]
157pub enum ReplayedEvent {
158 Timeout,
159 HttpResponse { status: u16, body: Bytes },
160}
161
162impl ReplayedEvent {
163 pub fn ok() -> Self {
164 Self::HttpResponse {
165 status: 200,
166 body: Bytes::new(),
167 }
168 }
169
170 pub fn with_body(body: impl AsRef<[u8]>) -> Self {
171 Self::HttpResponse {
172 status: 200,
173 body: Bytes::copy_from_slice(body.as_ref()),
174 }
175 }
176
177 pub fn status(status: u16) -> Self {
178 Self::HttpResponse {
179 status,
180 body: Bytes::new(),
181 }
182 }
183}
184
185#[derive(Debug)]
195pub struct WireMockServer {
196 event_log: Arc<Mutex<Vec<RecordedEvent>>>,
197 bind_addr: SocketAddr,
198 shutdown_hook: oneshot::Sender<()>,
200}
201
202#[derive(Debug, Clone)]
203struct SharedGraceful {
204 graceful: Arc<Mutex<Option<hyper_util::server::graceful::GracefulShutdown>>>,
205}
206
207impl SharedGraceful {
208 fn new() -> Self {
209 Self {
210 graceful: Arc::new(Mutex::new(Some(GracefulShutdown::new()))),
211 }
212 }
213
214 fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
215 let graceful = self.graceful.lock().unwrap();
216 graceful
217 .as_ref()
218 .expect("graceful not shutdown")
219 .watch(conn)
220 }
221
222 async fn shutdown(&self) {
223 let graceful = { self.graceful.lock().unwrap().take() };
224
225 if let Some(graceful) = graceful {
226 graceful.shutdown().await;
227 }
228 }
229}
230
231impl WireMockServer {
232 pub async fn start(mut response_events: Vec<ReplayedEvent>) -> Self {
234 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
235 let (tx, mut rx) = oneshot::channel();
236 let listener_addr = listener.local_addr().unwrap();
237 response_events.reverse();
238 let response_events = Arc::new(Mutex::new(response_events));
239 let handler_events = response_events;
240 let wire_events = Arc::new(Mutex::new(vec![]));
241 let wire_log_for_service = wire_events.clone();
242 let poisoned_conns: Arc<Mutex<HashSet<SocketAddr>>> = Default::default();
243 let graceful = SharedGraceful::new();
244 let conn_builder = Arc::new(hyper_util::server::conn::auto::Builder::new(
245 TokioExecutor::new(),
246 ));
247
248 let server = async move {
249 let poisoned_conns = poisoned_conns.clone();
250 let events = handler_events.clone();
251 let wire_log = wire_log_for_service.clone();
252 loop {
253 tokio::select! {
254 Ok((stream, remote_addr)) = listener.accept() => {
255 tracing::info!("established connection: {:?}", remote_addr);
256 let poisoned_conns = poisoned_conns.clone();
257 let events = events.clone();
258 let wire_log = wire_log.clone();
259 wire_log.lock().unwrap().push(RecordedEvent::NewConnection);
260 let io = TokioIo::new(stream);
261
262 let svc = service_fn(move |_req| {
263 let poisoned_conns = poisoned_conns.clone();
264 let events = events.clone();
265 let wire_log = wire_log.clone();
266 if poisoned_conns.lock().unwrap().contains(&remote_addr) {
267 tracing::error!("poisoned connection {:?} was reused!", &remote_addr);
268 panic!("poisoned connection was reused!");
269 }
270 let next_event = events.clone().lock().unwrap().pop();
271 async move {
272 let next_event = next_event
273 .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log));
274
275 wire_log
276 .lock()
277 .unwrap()
278 .push(RecordedEvent::Response(next_event.clone()));
279
280 if next_event == ReplayedEvent::Timeout {
281 tracing::info!("{} is poisoned", remote_addr);
282 poisoned_conns.lock().unwrap().insert(remote_addr);
283 }
284 tracing::debug!("replying with {:?}", next_event);
285 let event = generate_response_event(next_event).await;
286 dbg!(event)
287 }
288 });
289
290 let conn_builder = conn_builder.clone();
291 let graceful = graceful.clone();
292 tokio::spawn(async move {
293 let conn = conn_builder.serve_connection(io, svc);
294 let fut = graceful.watch(conn);
295 if let Err(e) = fut.await {
296 panic!("Error serving connection: {:?}", e);
297 }
298 });
299 },
300 _ = &mut rx => {
301 tracing::info!("wire server: shutdown signalled");
302 graceful.shutdown().await;
303 tracing::info!("wire server: shutdown complete!");
304 break;
305 }
306 }
307 }
308 };
309
310 tokio::spawn(server);
311 Self {
312 event_log: wire_events,
313 bind_addr: listener_addr,
314 shutdown_hook: tx,
315 }
316 }
317
318 pub fn events(&self) -> Vec<RecordedEvent> {
320 self.event_log.lock().unwrap().clone()
321 }
322
323 fn bind_addr(&self) -> SocketAddr {
324 self.bind_addr
325 }
326
327 pub fn dns_resolver(&self) -> LoggingDnsResolver {
328 let event_log = self.event_log.clone();
329 let bind_addr = self.bind_addr;
330 LoggingDnsResolver(InnerDnsResolver {
331 log: event_log,
332 socket_addr: bind_addr,
333 })
334 }
335
336 pub fn http_client(&self) -> SharedHttpClient {
340 let resolver = self.dns_resolver();
341 crate::client::build_with_tcp_conn_fn(None, move || {
342 hyper_util::client::legacy::connect::HttpConnector::new_with_resolver(
343 resolver.clone().0,
344 )
345 })
346 }
347
348 pub fn endpoint_url(&self) -> String {
352 format!(
353 "http://this-url-is-converted-to-localhost.com:{}",
354 self.bind_addr().port()
355 )
356 }
357
358 pub fn shutdown(self) {
360 let _ = self.shutdown_hook.send(());
361 }
362}
363
364async fn generate_response_event(
365 event: ReplayedEvent,
366) -> Result<http_1x::Response<Full<Bytes>>, Infallible> {
367 let resp = match event {
368 ReplayedEvent::HttpResponse { status, body } => http_1x::Response::builder()
369 .status(status)
370 .body(Full::new(body))
371 .unwrap(),
372 ReplayedEvent::Timeout => {
373 Never::new().await;
374 unreachable!()
375 }
376 };
377 Ok::<_, Infallible>(resp)
378}
379
380#[derive(Clone, Debug)]
384pub struct LoggingDnsResolver(InnerDnsResolver);
385
386#[derive(Clone, Debug)]
388struct InnerDnsResolver {
389 log: Arc<Mutex<Vec<RecordedEvent>>>,
390 socket_addr: SocketAddr,
391}
392
393impl tower::Service<Name> for InnerDnsResolver {
394 type Response = Once<SocketAddr>;
395 type Error = Infallible;
396 type Future = BoxFuture<'static, Self::Response, Self::Error>;
397
398 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
399 Poll::Ready(Ok(()))
400 }
401
402 fn call(&mut self, req: Name) -> Self::Future {
403 let socket_addr = self.socket_addr;
404 let log = self.log.clone();
405 Box::pin(async move {
406 println!("looking up {:?}, replying with {:?}", req, socket_addr);
407 log.lock()
408 .unwrap()
409 .push(RecordedEvent::DnsLookup(req.to_string()));
410 Ok(std::iter::once(socket_addr))
411 })
412 }
413}
414
/// Lets a [`LoggingDnsResolver`] act as a hyper 0.14 DNS resolver by converting each
/// hyper 0.14 `Name` into the `hyper_util` `Name` that the inner resolver accepts.
#[cfg(all(feature = "legacy-test-util", feature = "hyper-014"))]
impl hyper_0_14::service::Service<hyper_0_14::client::connect::dns::Name> for LoggingDnsResolver {
    type Response = Once<SocketAddr>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Self::Response, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <InnerDnsResolver as tower::Service<Name>>::poll_ready(&mut self.0, cx)
    }

    fn call(&mut self, req: hyper_0_14::client::connect::dns::Name) -> Self::Future {
        // Convert between the two `Name` types through their string form; panics if
        // `hyper_util` rejects the name.
        let converted: Name = req.as_str().parse().expect("valid conversion");
        <InnerDnsResolver as tower::Service<Name>>::call(&mut self.0, converted)
    }
}