aws_smithy_http_client/test_util/dvr/
replay.rs1use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
7use crate::test_util::replay::DEFAULT_RELAXED_HEADERS;
8use aws_smithy_protocol_test::MediaType;
9use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
10use aws_smithy_runtime_api::client::http::{
11 HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
12};
13use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
14use aws_smithy_runtime_api::client::result::ConnectorError;
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use bytes::{Bytes, BytesMut};
20use std::collections::{HashMap, VecDeque};
21use std::error::Error;
22use std::fmt;
23use std::ops::DerefMut;
24use std::path::Path;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::{Arc, Mutex};
27use tokio::task::JoinHandle;
28
29#[derive(Debug)]
31enum Waitable<T> {
32 Loading(JoinHandle<T>),
33 Value(T),
34}
35
36impl<T> Waitable<T> {
37 async fn take(self) -> T {
39 match self {
40 Waitable::Loading(f) => f.await.expect("join failed"),
41 Waitable::Value(value) => value,
42 }
43 }
44
45 async fn wait(&mut self) {
47 match self {
48 Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
49 Waitable::Value(_) => {}
50 }
51 }
52}
53
54#[derive(Clone)]
56pub struct ReplayingClient {
57 live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
58 verifiable_events: Arc<HashMap<ConnectionId, http_1x::Request<Bytes>>>,
59 num_events: Arc<AtomicUsize>,
60 recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http_1x::Request<Bytes>>>>>,
61}
62
63impl fmt::Debug for ReplayingClient {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.write_str("test_util::dvr::ReplayingClient")
69 }
70}
71
72enum HeadersToCheck<'a> {
73 Include(&'a [&'a str]),
74 Exclude(Option<&'a [&'a str]>),
75}
76
77impl ReplayingClient {
78 fn next_id(&self) -> ConnectionId {
79 ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
80 }
81
82 pub async fn full_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
84 self.validate_body_and_headers(None, media_type).await
85 }
86
87 pub async fn relaxed_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
94 self.validate_body_and_headers_except(DEFAULT_RELAXED_HEADERS, media_type)
95 .await
96 }
97
98 pub async fn validate(
100 self,
101 checked_headers: &[&str],
102 body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
103 ) -> Result<(), Box<dyn Error>> {
104 self.validate_base(HeadersToCheck::Include(checked_headers), body_comparer)
105 .await
106 }
107
108 pub async fn validate_body_and_headers(
113 self,
114 checked_headers: Option<&[&str]>,
115 media_type: &str,
116 ) -> Result<(), Box<dyn Error>> {
117 let headers_to_check = match checked_headers {
118 Some(headers) => HeadersToCheck::Include(headers),
119 None => HeadersToCheck::Exclude(None),
120 };
121 self.validate_base(headers_to_check, |b1, b2| {
122 aws_smithy_protocol_test::validate_body(
123 b1,
124 std::str::from_utf8(b2).unwrap(),
125 MediaType::from(media_type),
126 )
127 .map_err(|e| Box::new(e) as _)
128 })
129 .await
130 }
131
132 pub async fn validate_body_and_headers_except(
136 self,
137 excluded_headers: &[&str],
138 media_type: &str,
139 ) -> Result<(), Box<dyn Error>> {
140 self.validate_base(HeadersToCheck::Exclude(Some(excluded_headers)), |b1, b2| {
141 aws_smithy_protocol_test::validate_body(
142 b1,
143 std::str::from_utf8(b2).unwrap(),
144 MediaType::from(media_type),
145 )
146 .map_err(|e| Box::new(e) as _)
147 })
148 .await
149 }
150
151 async fn validate_base(
152 self,
153 checked_headers: HeadersToCheck<'_>,
154 body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
155 ) -> Result<(), Box<dyn Error>> {
156 let mut actual_requests =
157 std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
158 for conn_id in 0..self.verifiable_events.len() {
159 let conn_id = ConnectionId(conn_id);
160 let expected = self.verifiable_events.get(&conn_id).unwrap();
161 let actual = actual_requests
162 .remove(&conn_id)
163 .ok_or(format!(
164 "expected connection {conn_id:?} but request was never sent"
165 ))?
166 .take()
167 .await;
168 body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
169 let actual: HttpRequest = actual.map(SdkBody::from).try_into()?;
170 aws_smithy_protocol_test::assert_uris_match(expected.uri().to_string(), actual.uri());
171 let expected_headers = expected
172 .headers()
173 .keys()
174 .map(|k| k.as_str())
175 .filter(|k| match checked_headers {
176 HeadersToCheck::Include(headers) => headers.contains(k),
177 HeadersToCheck::Exclude(excluded) => match excluded {
178 Some(headers) => !headers.contains(k),
179 None => true,
180 },
181 })
182 .flat_map(|key| {
183 let _ = expected.headers().get(key)?;
184 Some((
185 key,
186 expected
187 .headers()
188 .get_all(key)
189 .iter()
190 .map(|h| h.to_str().unwrap())
191 .collect::<Vec<_>>()
192 .join(", "),
193 ))
194 })
195 .collect::<Vec<_>>();
196 aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
197 .map_err(|err| {
198 format!(
199 "event {} validation failed with: {}",
200 conn_id.0,
201 DisplayErrorContext(&err)
202 )
203 })?;
204 }
205 Ok(())
206 }
207
208 #[cfg(feature = "legacy-test-util")]
210 pub async fn take_requests(self) -> Vec<http_02x::Request<Bytes>> {
211 let mut recorded_requests =
212 std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
213 let mut out = Vec::with_capacity(recorded_requests.len());
214 for conn_id in 0..recorded_requests.len() {
215 out.push(
216 recorded_requests
217 .remove(&ConnectionId(conn_id))
218 .expect("should exist")
219 .take()
220 .await,
221 )
222 }
223 out.into_iter()
224 .map(|v1r| {
225 let mut builder = http_02x::Request::builder()
226 .uri(v1r.uri().to_string())
227 .method(v1r.method().as_str());
228 for (k, v) in v1r.headers().iter() {
229 builder = builder.header(k.as_str(), v.as_bytes())
230 }
231 builder.body(v1r.into_body()).expect("valid conversion")
232 })
233 .collect()
234 }
235
236 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
238 let events: NetworkTraffic =
239 serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
240 Ok(Self::new(events.events))
241 }
242
243 pub fn new(events: Vec<Event>) -> Self {
245 let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
246 for event in events {
247 let event_buffer = event_map.entry(event.connection_id).or_default();
248 event_buffer.push_back(event);
249 }
250 let verifiable_events = event_map
251 .iter()
252 .map(|(id, events)| {
253 let mut body = BytesMut::new();
254 for event in events {
255 if let Action::Data {
256 direction: Direction::Request,
257 data,
258 } = &event.action
259 {
260 body.extend_from_slice(&data.copy_to_vec());
261 }
262 }
263 let initial_request = events.iter().next().expect("must have one event");
264 let request = match &initial_request.action {
265 Action::Request { request } => {
266 http_1x::Request::from(request).map(|_| Bytes::from(body))
267 }
268 _ => panic!("invalid first event"),
269 };
270 (*id, request)
271 })
272 .collect();
273 let verifiable_events = Arc::new(verifiable_events);
274
275 ReplayingClient {
276 live_events: Arc::new(Mutex::new(event_map)),
277 num_events: Arc::new(AtomicUsize::new(0)),
278 recorded_requests: Default::default(),
279 verifiable_events,
280 }
281 }
282}
283
284async fn replay_body(events: VecDeque<Event>, mut sender: crate::test_util::body::Sender) {
285 for event in events {
286 match event.action {
287 Action::Request { .. } => panic!(),
288 Action::Response { .. } => panic!(),
289 Action::Data {
290 data,
291 direction: Direction::Response,
292 } => {
293 sender
294 .send_data(Bytes::from(data.into_bytes()))
295 .await
296 .expect("this is in memory traffic that should not fail to send");
297 }
298 Action::Data {
299 data: _data,
300 direction: Direction::Request,
301 } => {}
302 Action::Eof {
303 direction: Direction::Request,
304 ..
305 } => {}
306 Action::Eof {
307 direction: Direction::Response,
308 ok: true,
309 ..
310 } => {
311 drop(sender);
312 break;
313 }
314 Action::Eof {
315 direction: Direction::Response,
316 ok: false,
317 ..
318 } => {
319 sender.abort();
320 break;
321 }
322 }
323 }
324}
325
326impl HttpConnector for ReplayingClient {
327 fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
328 let event_id = self.next_id();
329 tracing::debug!("received event {}: {request:?}", event_id.0);
330 let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
331 Some(traffic) => traffic,
332 None => {
333 return HttpConnectorFuture::ready(Err(ConnectorError::other(
334 format!("no data for event {}. request: {:?}", event_id.0, request).into(),
335 None,
336 )));
337 }
338 };
339
340 let _initial_request = events.pop_front().unwrap();
341 let (sender, body) = crate::test_util::body::channel_body();
342 let recording = self.recorded_requests.clone();
343 let recorded_request = tokio::spawn(async move {
344 let mut data_read = vec![];
345 while let Some(data) = crate::test_util::body::next_data_frame(request.body_mut()).await
346 {
347 data_read
348 .extend_from_slice(data.expect("in memory request should not fail").as_ref())
349 }
350 request
351 .try_into_http1x()
352 .unwrap()
353 .map(|_body| Bytes::from(data_read))
354 });
355 let mut recorded_request = Waitable::Loading(recorded_request);
356 let fut = async move {
357 let resp: Result<_, ConnectorError> = loop {
358 let event = events
359 .pop_front()
360 .expect("no events, needed a response event");
361 match event.action {
362 Action::Eof {
365 direction: Direction::Request,
366 ..
367 } => {
368 recorded_request.wait().await;
369 }
370 Action::Request { .. } => panic!("invalid"),
371 Action::Response {
372 response: Err(error),
373 } => break Err(ConnectorError::other(error.0.into(), None)),
374 Action::Response {
375 response: Ok(response),
376 } => {
377 let mut builder = http_1x::Response::builder().status(response.status);
378 for (name, values) in response.headers {
379 for value in values {
380 builder = builder.header(&name, &value);
381 }
382 }
383 tokio::spawn(async move {
384 replay_body(events, sender).await;
385 });
387 break Ok(HttpResponse::try_from(
388 builder.body(body).expect("valid builder"),
389 )
390 .unwrap());
391 }
392
393 Action::Data {
394 direction: Direction::Request,
395 data: _data,
396 } => {
397 tracing::info!("get request data");
398 }
399 Action::Eof {
400 direction: Direction::Response,
401 ..
402 } => panic!("got eof before response"),
403
404 Action::Data {
405 data: _,
406 direction: Direction::Response,
407 } => panic!("got response data before response"),
408 }
409 };
410 recording.lock().unwrap().insert(event_id, recorded_request);
411 resp
412 };
413 HttpConnectorFuture::new(fut)
414 }
415}
416
417impl HttpClient for ReplayingClient {
418 fn http_connector(
419 &self,
420 _: &HttpConnectorSettings,
421 _: &RuntimeComponents,
422 ) -> SharedHttpConnector {
423 self.clone().into_shared()
424 }
425
426 fn connector_metadata(&self) -> Option<ConnectorMetadata> {
427 Some(ConnectorMetadata::new("replaying-client", None))
428 }
429}