aws_smithy_http_client/test_util/dvr/replay.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
7use crate::test_util::replay::DEFAULT_RELAXED_HEADERS;
8use aws_smithy_protocol_test::MediaType;
9use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
10use aws_smithy_runtime_api::client::http::{
11    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
12};
13use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
14use aws_smithy_runtime_api::client::result::ConnectorError;
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use bytes::{Bytes, BytesMut};
20use std::collections::{HashMap, VecDeque};
21use std::error::Error;
22use std::fmt;
23use std::ops::DerefMut;
24use std::path::Path;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::{Arc, Mutex};
27use tokio::task::JoinHandle;
28
29/// Wrapper type to enable optionally waiting for a future to complete
30#[derive(Debug)]
31enum Waitable<T> {
32    Loading(JoinHandle<T>),
33    Value(T),
34}
35
36impl<T> Waitable<T> {
37    /// Consumes the future and returns the value
38    async fn take(self) -> T {
39        match self {
40            Waitable::Loading(f) => f.await.expect("join failed"),
41            Waitable::Value(value) => value,
42        }
43    }
44
45    /// Waits for the future to be ready
46    async fn wait(&mut self) {
47        match self {
48            Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
49            Waitable::Value(_) => {}
50        }
51    }
52}
53
54/// Replay traffic recorded by a [`RecordingClient`](super::RecordingClient)
55#[derive(Clone)]
56pub struct ReplayingClient {
57    live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
58    verifiable_events: Arc<HashMap<ConnectionId, http_1x::Request<Bytes>>>,
59    num_events: Arc<AtomicUsize>,
60    recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http_1x::Request<Bytes>>>>>,
61}
62
63// Ideally, this would just derive Debug, but that makes the tests in aws-config think they found AWS secrets
64// when really it's just the test response data they're seeing from the Debug impl of this client.
65// This is just a quick workaround. A better fix can be considered later.
66impl fmt::Debug for ReplayingClient {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.write_str("test_util::dvr::ReplayingClient")
69    }
70}
71
/// Selects which headers of a recorded request are compared against the actual request.
enum HeadersToCheck<'a> {
    /// Compare every header except the listed ones; `None` compares every header.
    Exclude(Option<&'a [&'a str]>),
    /// Compare only the listed headers.
    Include(&'a [&'a str]),
}
76
77impl ReplayingClient {
78    fn next_id(&self) -> ConnectionId {
79        ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
80    }
81
82    /// Validate all headers and bodies
83    pub async fn full_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
84        self.validate_body_and_headers(None, media_type).await
85    }
86
87    /// Convenience method to validate that the bodies match, using a given [`MediaType`] for
88    /// comparison, and that the headers are also match excluding the default relaxed headers
89    ///
90    /// The current default relaxed headers:
91    /// - x-amz-user-agent
92    /// - authorization
93    pub async fn relaxed_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
94        self.validate_body_and_headers_except(DEFAULT_RELAXED_HEADERS, media_type)
95            .await
96    }
97
98    /// Validate actual requests against expected requests
99    pub async fn validate(
100        self,
101        checked_headers: &[&str],
102        body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
103    ) -> Result<(), Box<dyn Error>> {
104        self.validate_base(HeadersToCheck::Include(checked_headers), body_comparer)
105            .await
106    }
107
108    /// Validate that the bodies match, using a given [`MediaType`] for comparison
109    ///
110    /// The specified headers are also validated. If `checked_headers` is a `None`, it means
111    /// checking all headers.
112    pub async fn validate_body_and_headers(
113        self,
114        checked_headers: Option<&[&str]>,
115        media_type: &str,
116    ) -> Result<(), Box<dyn Error>> {
117        let headers_to_check = match checked_headers {
118            Some(headers) => HeadersToCheck::Include(headers),
119            None => HeadersToCheck::Exclude(None),
120        };
121        self.validate_base(headers_to_check, |b1, b2| {
122            aws_smithy_protocol_test::validate_body(
123                b1,
124                std::str::from_utf8(b2).unwrap(),
125                MediaType::from(media_type),
126            )
127            .map_err(|e| Box::new(e) as _)
128        })
129        .await
130    }
131
132    /// Validate that the bodies match, using a given [`MediaType`] for comparison
133    ///
134    /// The headers are also validated unless listed in `excluded_headers`
135    pub async fn validate_body_and_headers_except(
136        self,
137        excluded_headers: &[&str],
138        media_type: &str,
139    ) -> Result<(), Box<dyn Error>> {
140        self.validate_base(HeadersToCheck::Exclude(Some(excluded_headers)), |b1, b2| {
141            aws_smithy_protocol_test::validate_body(
142                b1,
143                std::str::from_utf8(b2).unwrap(),
144                MediaType::from(media_type),
145            )
146            .map_err(|e| Box::new(e) as _)
147        })
148        .await
149    }
150
151    async fn validate_base(
152        self,
153        checked_headers: HeadersToCheck<'_>,
154        body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
155    ) -> Result<(), Box<dyn Error>> {
156        let mut actual_requests =
157            std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
158        for conn_id in 0..self.verifiable_events.len() {
159            let conn_id = ConnectionId(conn_id);
160            let expected = self.verifiable_events.get(&conn_id).unwrap();
161            let actual = actual_requests
162                .remove(&conn_id)
163                .ok_or(format!(
164                    "expected connection {:?} but request was never sent",
165                    conn_id
166                ))?
167                .take()
168                .await;
169            body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
170            let actual: HttpRequest = actual.map(SdkBody::from).try_into()?;
171            aws_smithy_protocol_test::assert_uris_match(expected.uri().to_string(), actual.uri());
172            let expected_headers = expected
173                .headers()
174                .keys()
175                .map(|k| k.as_str())
176                .filter(|k| match checked_headers {
177                    HeadersToCheck::Include(headers) => headers.contains(k),
178                    HeadersToCheck::Exclude(excluded) => match excluded {
179                        Some(headers) => !headers.contains(k),
180                        None => true,
181                    },
182                })
183                .flat_map(|key| {
184                    let _ = expected.headers().get(key)?;
185                    Some((
186                        key,
187                        expected
188                            .headers()
189                            .get_all(key)
190                            .iter()
191                            .map(|h| h.to_str().unwrap())
192                            .collect::<Vec<_>>()
193                            .join(", "),
194                    ))
195                })
196                .collect::<Vec<_>>();
197            aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
198                .map_err(|err| {
199                    format!(
200                        "event {} validation failed with: {}",
201                        conn_id.0,
202                        DisplayErrorContext(&err)
203                    )
204                })?;
205        }
206        Ok(())
207    }
208
209    /// Return all the recorded requests for further analysis
210    #[cfg(feature = "legacy-test-util")]
211    pub async fn take_requests(self) -> Vec<http_02x::Request<Bytes>> {
212        let mut recorded_requests =
213            std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
214        let mut out = Vec::with_capacity(recorded_requests.len());
215        for conn_id in 0..recorded_requests.len() {
216            out.push(
217                recorded_requests
218                    .remove(&ConnectionId(conn_id))
219                    .expect("should exist")
220                    .take()
221                    .await,
222            )
223        }
224        out.into_iter()
225            .map(|v1r| {
226                let mut builder = http_02x::Request::builder()
227                    .uri(v1r.uri().to_string())
228                    .method(v1r.method().as_str());
229                for (k, v) in v1r.headers().iter() {
230                    builder = builder.header(k.as_str(), v.as_bytes())
231                }
232                builder.body(v1r.into_body()).expect("valid conversion")
233            })
234            .collect()
235    }
236
237    /// Build a replay connection from a JSON file
238    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
239        let events: NetworkTraffic =
240            serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
241        Ok(Self::new(events.events))
242    }
243
244    /// Build a replay connection from a sequence of events
245    pub fn new(events: Vec<Event>) -> Self {
246        let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
247        for event in events {
248            let event_buffer = event_map.entry(event.connection_id).or_default();
249            event_buffer.push_back(event);
250        }
251        let verifiable_events = event_map
252            .iter()
253            .map(|(id, events)| {
254                let mut body = BytesMut::new();
255                for event in events {
256                    if let Action::Data {
257                        direction: Direction::Request,
258                        data,
259                    } = &event.action
260                    {
261                        body.extend_from_slice(&data.copy_to_vec());
262                    }
263                }
264                let initial_request = events.iter().next().expect("must have one event");
265                let request = match &initial_request.action {
266                    Action::Request { request } => {
267                        http_1x::Request::from(request).map(|_| Bytes::from(body))
268                    }
269                    _ => panic!("invalid first event"),
270                };
271                (*id, request)
272            })
273            .collect();
274        let verifiable_events = Arc::new(verifiable_events);
275
276        ReplayingClient {
277            live_events: Arc::new(Mutex::new(event_map)),
278            num_events: Arc::new(AtomicUsize::new(0)),
279            recorded_requests: Default::default(),
280            verifiable_events,
281        }
282    }
283}
284
285async fn replay_body(events: VecDeque<Event>, mut sender: crate::test_util::body::Sender) {
286    for event in events {
287        match event.action {
288            Action::Request { .. } => panic!(),
289            Action::Response { .. } => panic!(),
290            Action::Data {
291                data,
292                direction: Direction::Response,
293            } => {
294                sender
295                    .send_data(Bytes::from(data.into_bytes()))
296                    .await
297                    .expect("this is in memory traffic that should not fail to send");
298            }
299            Action::Data {
300                data: _data,
301                direction: Direction::Request,
302            } => {}
303            Action::Eof {
304                direction: Direction::Request,
305                ..
306            } => {}
307            Action::Eof {
308                direction: Direction::Response,
309                ok: true,
310                ..
311            } => {
312                drop(sender);
313                break;
314            }
315            Action::Eof {
316                direction: Direction::Response,
317                ok: false,
318                ..
319            } => {
320                sender.abort();
321                break;
322            }
323        }
324    }
325}
326
327impl HttpConnector for ReplayingClient {
328    fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
329        let event_id = self.next_id();
330        tracing::debug!("received event {}: {request:?}", event_id.0);
331        let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
332            Some(traffic) => traffic,
333            None => {
334                return HttpConnectorFuture::ready(Err(ConnectorError::other(
335                    format!("no data for event {}. request: {:?}", event_id.0, request).into(),
336                    None,
337                )));
338            }
339        };
340
341        let _initial_request = events.pop_front().unwrap();
342        let (sender, body) = crate::test_util::body::channel_body();
343        let recording = self.recorded_requests.clone();
344        let recorded_request = tokio::spawn(async move {
345            let mut data_read = vec![];
346            while let Some(data) = crate::test_util::body::next_data_frame(request.body_mut()).await
347            {
348                data_read
349                    .extend_from_slice(data.expect("in memory request should not fail").as_ref())
350            }
351            request
352                .try_into_http1x()
353                .unwrap()
354                .map(|_body| Bytes::from(data_read))
355        });
356        let mut recorded_request = Waitable::Loading(recorded_request);
357        let fut = async move {
358            let resp: Result<_, ConnectorError> = loop {
359                let event = events
360                    .pop_front()
361                    .expect("no events, needed a response event");
362                match event.action {
363                    // to ensure deterministic behavior if the request EOF happens first in the log,
364                    // wait for the request body to be done before returning a response.
365                    Action::Eof {
366                        direction: Direction::Request,
367                        ..
368                    } => {
369                        recorded_request.wait().await;
370                    }
371                    Action::Request { .. } => panic!("invalid"),
372                    Action::Response {
373                        response: Err(error),
374                    } => break Err(ConnectorError::other(error.0.into(), None)),
375                    Action::Response {
376                        response: Ok(response),
377                    } => {
378                        let mut builder = http_1x::Response::builder().status(response.status);
379                        for (name, values) in response.headers {
380                            for value in values {
381                                builder = builder.header(&name, &value);
382                            }
383                        }
384                        tokio::spawn(async move {
385                            replay_body(events, sender).await;
386                            // insert the finalized body into
387                        });
388                        break Ok(HttpResponse::try_from(
389                            builder.body(body).expect("valid builder"),
390                        )
391                        .unwrap());
392                    }
393
394                    Action::Data {
395                        direction: Direction::Request,
396                        data: _data,
397                    } => {
398                        tracing::info!("get request data");
399                    }
400                    Action::Eof {
401                        direction: Direction::Response,
402                        ..
403                    } => panic!("got eof before response"),
404
405                    Action::Data {
406                        data: _,
407                        direction: Direction::Response,
408                    } => panic!("got response data before response"),
409                }
410            };
411            recording.lock().unwrap().insert(event_id, recorded_request);
412            resp
413        };
414        HttpConnectorFuture::new(fut)
415    }
416}
417
418impl HttpClient for ReplayingClient {
419    fn http_connector(
420        &self,
421        _: &HttpConnectorSettings,
422        _: &RuntimeComponents,
423    ) -> SharedHttpConnector {
424        self.clone().into_shared()
425    }
426
427    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
428        Some(ConnectorMetadata::new("replaying-client", None))
429    }
430}