aws_smithy_http_client/test_util/dvr/
replay.rs1use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
7use crate::test_util::replay::DEFAULT_RELAXED_HEADERS;
8use aws_smithy_protocol_test::MediaType;
9use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
10use aws_smithy_runtime_api::client::http::{
11 HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
12};
13use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
14use aws_smithy_runtime_api::client::result::ConnectorError;
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use bytes::{Bytes, BytesMut};
20use std::collections::{HashMap, VecDeque};
21use std::error::Error;
22use std::fmt;
23use std::ops::DerefMut;
24use std::path::Path;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::{Arc, Mutex};
27use tokio::task::JoinHandle;
28
29#[derive(Debug)]
31enum Waitable<T> {
32 Loading(JoinHandle<T>),
33 Value(T),
34}
35
36impl<T> Waitable<T> {
37 async fn take(self) -> T {
39 match self {
40 Waitable::Loading(f) => f.await.expect("join failed"),
41 Waitable::Value(value) => value,
42 }
43 }
44
45 async fn wait(&mut self) {
47 match self {
48 Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
49 Waitable::Value(_) => {}
50 }
51 }
52}
53
54#[derive(Clone)]
56pub struct ReplayingClient {
57 live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
58 verifiable_events: Arc<HashMap<ConnectionId, http_1x::Request<Bytes>>>,
59 num_events: Arc<AtomicUsize>,
60 recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http_1x::Request<Bytes>>>>>,
61}
62
63impl fmt::Debug for ReplayingClient {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 f.write_str("test_util::dvr::ReplayingClient")
69 }
70}
71
/// Selects which headers of a recorded request are compared during validation.
enum HeadersToCheck<'a> {
    /// Compare only the listed header names.
    Include(&'a [&'a str]),
    /// Compare every header except the listed names; `None` means compare every header.
    Exclude(Option<&'a [&'a str]>),
}
76
77impl ReplayingClient {
78 fn next_id(&self) -> ConnectionId {
79 ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
80 }
81
82 pub async fn full_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
84 self.validate_body_and_headers(None, media_type).await
85 }
86
87 pub async fn relaxed_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
94 self.validate_body_and_headers_except(DEFAULT_RELAXED_HEADERS, media_type)
95 .await
96 }
97
98 pub async fn validate(
100 self,
101 checked_headers: &[&str],
102 body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
103 ) -> Result<(), Box<dyn Error>> {
104 self.validate_base(HeadersToCheck::Include(checked_headers), body_comparer)
105 .await
106 }
107
108 pub async fn validate_body_and_headers(
113 self,
114 checked_headers: Option<&[&str]>,
115 media_type: &str,
116 ) -> Result<(), Box<dyn Error>> {
117 let headers_to_check = match checked_headers {
118 Some(headers) => HeadersToCheck::Include(headers),
119 None => HeadersToCheck::Exclude(None),
120 };
121 self.validate_base(headers_to_check, |b1, b2| {
122 aws_smithy_protocol_test::validate_body(
123 b1,
124 std::str::from_utf8(b2).unwrap(),
125 MediaType::from(media_type),
126 )
127 .map_err(|e| Box::new(e) as _)
128 })
129 .await
130 }
131
132 pub async fn validate_body_and_headers_except(
136 self,
137 excluded_headers: &[&str],
138 media_type: &str,
139 ) -> Result<(), Box<dyn Error>> {
140 self.validate_base(HeadersToCheck::Exclude(Some(excluded_headers)), |b1, b2| {
141 aws_smithy_protocol_test::validate_body(
142 b1,
143 std::str::from_utf8(b2).unwrap(),
144 MediaType::from(media_type),
145 )
146 .map_err(|e| Box::new(e) as _)
147 })
148 .await
149 }
150
151 async fn validate_base(
152 self,
153 checked_headers: HeadersToCheck<'_>,
154 body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
155 ) -> Result<(), Box<dyn Error>> {
156 let mut actual_requests =
157 std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
158 for conn_id in 0..self.verifiable_events.len() {
159 let conn_id = ConnectionId(conn_id);
160 let expected = self.verifiable_events.get(&conn_id).unwrap();
161 let actual = actual_requests
162 .remove(&conn_id)
163 .ok_or(format!(
164 "expected connection {:?} but request was never sent",
165 conn_id
166 ))?
167 .take()
168 .await;
169 body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
170 let actual: HttpRequest = actual.map(SdkBody::from).try_into()?;
171 aws_smithy_protocol_test::assert_uris_match(expected.uri().to_string(), actual.uri());
172 let expected_headers = expected
173 .headers()
174 .keys()
175 .map(|k| k.as_str())
176 .filter(|k| match checked_headers {
177 HeadersToCheck::Include(headers) => headers.contains(k),
178 HeadersToCheck::Exclude(excluded) => match excluded {
179 Some(headers) => !headers.contains(k),
180 None => true,
181 },
182 })
183 .flat_map(|key| {
184 let _ = expected.headers().get(key)?;
185 Some((
186 key,
187 expected
188 .headers()
189 .get_all(key)
190 .iter()
191 .map(|h| h.to_str().unwrap())
192 .collect::<Vec<_>>()
193 .join(", "),
194 ))
195 })
196 .collect::<Vec<_>>();
197 aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
198 .map_err(|err| {
199 format!(
200 "event {} validation failed with: {}",
201 conn_id.0,
202 DisplayErrorContext(&err)
203 )
204 })?;
205 }
206 Ok(())
207 }
208
209 #[cfg(feature = "legacy-test-util")]
211 pub async fn take_requests(self) -> Vec<http_02x::Request<Bytes>> {
212 let mut recorded_requests =
213 std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
214 let mut out = Vec::with_capacity(recorded_requests.len());
215 for conn_id in 0..recorded_requests.len() {
216 out.push(
217 recorded_requests
218 .remove(&ConnectionId(conn_id))
219 .expect("should exist")
220 .take()
221 .await,
222 )
223 }
224 out.into_iter()
225 .map(|v1r| {
226 let mut builder = http_02x::Request::builder()
227 .uri(v1r.uri().to_string())
228 .method(v1r.method().as_str());
229 for (k, v) in v1r.headers().iter() {
230 builder = builder.header(k.as_str(), v.as_bytes())
231 }
232 builder.body(v1r.into_body()).expect("valid conversion")
233 })
234 .collect()
235 }
236
237 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
239 let events: NetworkTraffic =
240 serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
241 Ok(Self::new(events.events))
242 }
243
244 pub fn new(events: Vec<Event>) -> Self {
246 let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
247 for event in events {
248 let event_buffer = event_map.entry(event.connection_id).or_default();
249 event_buffer.push_back(event);
250 }
251 let verifiable_events = event_map
252 .iter()
253 .map(|(id, events)| {
254 let mut body = BytesMut::new();
255 for event in events {
256 if let Action::Data {
257 direction: Direction::Request,
258 data,
259 } = &event.action
260 {
261 body.extend_from_slice(&data.copy_to_vec());
262 }
263 }
264 let initial_request = events.iter().next().expect("must have one event");
265 let request = match &initial_request.action {
266 Action::Request { request } => {
267 http_1x::Request::from(request).map(|_| Bytes::from(body))
268 }
269 _ => panic!("invalid first event"),
270 };
271 (*id, request)
272 })
273 .collect();
274 let verifiable_events = Arc::new(verifiable_events);
275
276 ReplayingClient {
277 live_events: Arc::new(Mutex::new(event_map)),
278 num_events: Arc::new(AtomicUsize::new(0)),
279 recorded_requests: Default::default(),
280 verifiable_events,
281 }
282 }
283}
284
285async fn replay_body(events: VecDeque<Event>, mut sender: crate::test_util::body::Sender) {
286 for event in events {
287 match event.action {
288 Action::Request { .. } => panic!(),
289 Action::Response { .. } => panic!(),
290 Action::Data {
291 data,
292 direction: Direction::Response,
293 } => {
294 sender
295 .send_data(Bytes::from(data.into_bytes()))
296 .await
297 .expect("this is in memory traffic that should not fail to send");
298 }
299 Action::Data {
300 data: _data,
301 direction: Direction::Request,
302 } => {}
303 Action::Eof {
304 direction: Direction::Request,
305 ..
306 } => {}
307 Action::Eof {
308 direction: Direction::Response,
309 ok: true,
310 ..
311 } => {
312 drop(sender);
313 break;
314 }
315 Action::Eof {
316 direction: Direction::Response,
317 ok: false,
318 ..
319 } => {
320 sender.abort();
321 break;
322 }
323 }
324 }
325}
326
327impl HttpConnector for ReplayingClient {
328 fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
329 let event_id = self.next_id();
330 tracing::debug!("received event {}: {request:?}", event_id.0);
331 let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
332 Some(traffic) => traffic,
333 None => {
334 return HttpConnectorFuture::ready(Err(ConnectorError::other(
335 format!("no data for event {}. request: {:?}", event_id.0, request).into(),
336 None,
337 )));
338 }
339 };
340
341 let _initial_request = events.pop_front().unwrap();
342 let (sender, body) = crate::test_util::body::channel_body();
343 let recording = self.recorded_requests.clone();
344 let recorded_request = tokio::spawn(async move {
345 let mut data_read = vec![];
346 while let Some(data) = crate::test_util::body::next_data_frame(request.body_mut()).await
347 {
348 data_read
349 .extend_from_slice(data.expect("in memory request should not fail").as_ref())
350 }
351 request
352 .try_into_http1x()
353 .unwrap()
354 .map(|_body| Bytes::from(data_read))
355 });
356 let mut recorded_request = Waitable::Loading(recorded_request);
357 let fut = async move {
358 let resp: Result<_, ConnectorError> = loop {
359 let event = events
360 .pop_front()
361 .expect("no events, needed a response event");
362 match event.action {
363 Action::Eof {
366 direction: Direction::Request,
367 ..
368 } => {
369 recorded_request.wait().await;
370 }
371 Action::Request { .. } => panic!("invalid"),
372 Action::Response {
373 response: Err(error),
374 } => break Err(ConnectorError::other(error.0.into(), None)),
375 Action::Response {
376 response: Ok(response),
377 } => {
378 let mut builder = http_1x::Response::builder().status(response.status);
379 for (name, values) in response.headers {
380 for value in values {
381 builder = builder.header(&name, &value);
382 }
383 }
384 tokio::spawn(async move {
385 replay_body(events, sender).await;
386 });
388 break Ok(HttpResponse::try_from(
389 builder.body(body).expect("valid builder"),
390 )
391 .unwrap());
392 }
393
394 Action::Data {
395 direction: Direction::Request,
396 data: _data,
397 } => {
398 tracing::info!("get request data");
399 }
400 Action::Eof {
401 direction: Direction::Response,
402 ..
403 } => panic!("got eof before response"),
404
405 Action::Data {
406 data: _,
407 direction: Direction::Response,
408 } => panic!("got response data before response"),
409 }
410 };
411 recording.lock().unwrap().insert(event_id, recorded_request);
412 resp
413 };
414 HttpConnectorFuture::new(fut)
415 }
416}
417
418impl HttpClient for ReplayingClient {
419 fn http_connector(
420 &self,
421 _: &HttpConnectorSettings,
422 _: &RuntimeComponents,
423 ) -> SharedHttpConnector {
424 self.clone().into_shared()
425 }
426
427 fn connector_metadata(&self) -> Option<ConnectorMetadata> {
428 Some(ConnectorMetadata::new("replaying-client", None))
429 }
430}