aws_smithy_http_client/test_util/dvr/
replay.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
7use crate::test_util::replay::DEFAULT_RELAXED_HEADERS;
8use aws_smithy_protocol_test::MediaType;
9use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
10use aws_smithy_runtime_api::client::http::{
11    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
12};
13use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
14use aws_smithy_runtime_api::client::result::ConnectorError;
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::error::display::DisplayErrorContext;
19use bytes::{Bytes, BytesMut};
20use std::collections::{HashMap, VecDeque};
21use std::error::Error;
22use std::fmt;
23use std::ops::DerefMut;
24use std::path::Path;
25use std::sync::atomic::{AtomicUsize, Ordering};
26use std::sync::{Arc, Mutex};
27use tokio::task::JoinHandle;
28
29/// Wrapper type to enable optionally waiting for a future to complete
30#[derive(Debug)]
31enum Waitable<T> {
32    Loading(JoinHandle<T>),
33    Value(T),
34}
35
36impl<T> Waitable<T> {
37    /// Consumes the future and returns the value
38    async fn take(self) -> T {
39        match self {
40            Waitable::Loading(f) => f.await.expect("join failed"),
41            Waitable::Value(value) => value,
42        }
43    }
44
45    /// Waits for the future to be ready
46    async fn wait(&mut self) {
47        match self {
48            Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
49            Waitable::Value(_) => {}
50        }
51    }
52}
53
54/// Replay traffic recorded by a [`RecordingClient`](super::RecordingClient)
55#[derive(Clone)]
56pub struct ReplayingClient {
57    live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
58    verifiable_events: Arc<HashMap<ConnectionId, http_1x::Request<Bytes>>>,
59    num_events: Arc<AtomicUsize>,
60    recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http_1x::Request<Bytes>>>>>,
61}
62
63// Ideally, this would just derive Debug, but that makes the tests in aws-config think they found AWS secrets
64// when really it's just the test response data they're seeing from the Debug impl of this client.
65// This is just a quick workaround. A better fix can be considered later.
66impl fmt::Debug for ReplayingClient {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.write_str("test_util::dvr::ReplayingClient")
69    }
70}
71
/// Which request headers to compare when validating replayed traffic.
///
/// All fields are borrowed slices, so the enum is cheap to copy.
#[derive(Debug, Clone, Copy)]
enum HeadersToCheck<'a> {
    /// Compare only the listed headers.
    Include(&'a [&'a str]),
    /// Compare every header except the listed ones; `None` compares every header.
    Exclude(Option<&'a [&'a str]>),
}
76
77impl ReplayingClient {
78    fn next_id(&self) -> ConnectionId {
79        ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
80    }
81
82    /// Validate all headers and bodies
83    pub async fn full_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
84        self.validate_body_and_headers(None, media_type).await
85    }
86
87    /// Convenience method to validate that the bodies match, using a given [`MediaType`] for
88    /// comparison, and that the headers are also match excluding the default relaxed headers
89    ///
90    /// The current default relaxed headers:
91    /// - x-amz-user-agent
92    /// - authorization
93    pub async fn relaxed_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
94        self.validate_body_and_headers_except(DEFAULT_RELAXED_HEADERS, media_type)
95            .await
96    }
97
98    /// Validate actual requests against expected requests
99    pub async fn validate(
100        self,
101        checked_headers: &[&str],
102        body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
103    ) -> Result<(), Box<dyn Error>> {
104        self.validate_base(HeadersToCheck::Include(checked_headers), body_comparer)
105            .await
106    }
107
108    /// Validate that the bodies match, using a given [`MediaType`] for comparison
109    ///
110    /// The specified headers are also validated. If `checked_headers` is a `None`, it means
111    /// checking all headers.
112    pub async fn validate_body_and_headers(
113        self,
114        checked_headers: Option<&[&str]>,
115        media_type: &str,
116    ) -> Result<(), Box<dyn Error>> {
117        let headers_to_check = match checked_headers {
118            Some(headers) => HeadersToCheck::Include(headers),
119            None => HeadersToCheck::Exclude(None),
120        };
121        self.validate_base(headers_to_check, |b1, b2| {
122            aws_smithy_protocol_test::validate_body(
123                b1,
124                std::str::from_utf8(b2).unwrap(),
125                MediaType::from(media_type),
126            )
127            .map_err(|e| Box::new(e) as _)
128        })
129        .await
130    }
131
132    /// Validate that the bodies match, using a given [`MediaType`] for comparison
133    ///
134    /// The headers are also validated unless listed in `excluded_headers`
135    pub async fn validate_body_and_headers_except(
136        self,
137        excluded_headers: &[&str],
138        media_type: &str,
139    ) -> Result<(), Box<dyn Error>> {
140        self.validate_base(HeadersToCheck::Exclude(Some(excluded_headers)), |b1, b2| {
141            aws_smithy_protocol_test::validate_body(
142                b1,
143                std::str::from_utf8(b2).unwrap(),
144                MediaType::from(media_type),
145            )
146            .map_err(|e| Box::new(e) as _)
147        })
148        .await
149    }
150
151    async fn validate_base(
152        self,
153        checked_headers: HeadersToCheck<'_>,
154        body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
155    ) -> Result<(), Box<dyn Error>> {
156        let mut actual_requests =
157            std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
158        for conn_id in 0..self.verifiable_events.len() {
159            let conn_id = ConnectionId(conn_id);
160            let expected = self.verifiable_events.get(&conn_id).unwrap();
161            let actual = actual_requests
162                .remove(&conn_id)
163                .ok_or(format!(
164                    "expected connection {conn_id:?} but request was never sent"
165                ))?
166                .take()
167                .await;
168            body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
169            let actual: HttpRequest = actual.map(SdkBody::from).try_into()?;
170            aws_smithy_protocol_test::assert_uris_match(expected.uri().to_string(), actual.uri());
171            let expected_headers = expected
172                .headers()
173                .keys()
174                .map(|k| k.as_str())
175                .filter(|k| match checked_headers {
176                    HeadersToCheck::Include(headers) => headers.contains(k),
177                    HeadersToCheck::Exclude(excluded) => match excluded {
178                        Some(headers) => !headers.contains(k),
179                        None => true,
180                    },
181                })
182                .flat_map(|key| {
183                    let _ = expected.headers().get(key)?;
184                    Some((
185                        key,
186                        expected
187                            .headers()
188                            .get_all(key)
189                            .iter()
190                            .map(|h| h.to_str().unwrap())
191                            .collect::<Vec<_>>()
192                            .join(", "),
193                    ))
194                })
195                .collect::<Vec<_>>();
196            aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
197                .map_err(|err| {
198                    format!(
199                        "event {} validation failed with: {}",
200                        conn_id.0,
201                        DisplayErrorContext(&err)
202                    )
203                })?;
204        }
205        Ok(())
206    }
207
208    /// Return all the recorded requests for further analysis
209    #[cfg(feature = "legacy-test-util")]
210    pub async fn take_requests(self) -> Vec<http_02x::Request<Bytes>> {
211        let mut recorded_requests =
212            std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
213        let mut out = Vec::with_capacity(recorded_requests.len());
214        for conn_id in 0..recorded_requests.len() {
215            out.push(
216                recorded_requests
217                    .remove(&ConnectionId(conn_id))
218                    .expect("should exist")
219                    .take()
220                    .await,
221            )
222        }
223        out.into_iter()
224            .map(|v1r| {
225                let mut builder = http_02x::Request::builder()
226                    .uri(v1r.uri().to_string())
227                    .method(v1r.method().as_str());
228                for (k, v) in v1r.headers().iter() {
229                    builder = builder.header(k.as_str(), v.as_bytes())
230                }
231                builder.body(v1r.into_body()).expect("valid conversion")
232            })
233            .collect()
234    }
235
236    /// Build a replay connection from a JSON file
237    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
238        let events: NetworkTraffic =
239            serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
240        Ok(Self::new(events.events))
241    }
242
243    /// Build a replay connection from a sequence of events
244    pub fn new(events: Vec<Event>) -> Self {
245        let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
246        for event in events {
247            let event_buffer = event_map.entry(event.connection_id).or_default();
248            event_buffer.push_back(event);
249        }
250        let verifiable_events = event_map
251            .iter()
252            .map(|(id, events)| {
253                let mut body = BytesMut::new();
254                for event in events {
255                    if let Action::Data {
256                        direction: Direction::Request,
257                        data,
258                    } = &event.action
259                    {
260                        body.extend_from_slice(&data.copy_to_vec());
261                    }
262                }
263                let initial_request = events.iter().next().expect("must have one event");
264                let request = match &initial_request.action {
265                    Action::Request { request } => {
266                        http_1x::Request::from(request).map(|_| Bytes::from(body))
267                    }
268                    _ => panic!("invalid first event"),
269                };
270                (*id, request)
271            })
272            .collect();
273        let verifiable_events = Arc::new(verifiable_events);
274
275        ReplayingClient {
276            live_events: Arc::new(Mutex::new(event_map)),
277            num_events: Arc::new(AtomicUsize::new(0)),
278            recorded_requests: Default::default(),
279            verifiable_events,
280        }
281    }
282}
283
284async fn replay_body(events: VecDeque<Event>, mut sender: crate::test_util::body::Sender) {
285    for event in events {
286        match event.action {
287            Action::Request { .. } => panic!(),
288            Action::Response { .. } => panic!(),
289            Action::Data {
290                data,
291                direction: Direction::Response,
292            } => {
293                sender
294                    .send_data(Bytes::from(data.into_bytes()))
295                    .await
296                    .expect("this is in memory traffic that should not fail to send");
297            }
298            Action::Data {
299                data: _data,
300                direction: Direction::Request,
301            } => {}
302            Action::Eof {
303                direction: Direction::Request,
304                ..
305            } => {}
306            Action::Eof {
307                direction: Direction::Response,
308                ok: true,
309                ..
310            } => {
311                drop(sender);
312                break;
313            }
314            Action::Eof {
315                direction: Direction::Response,
316                ok: false,
317                ..
318            } => {
319                sender.abort();
320                break;
321            }
322        }
323    }
324}
325
326impl HttpConnector for ReplayingClient {
327    fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
328        let event_id = self.next_id();
329        tracing::debug!("received event {}: {request:?}", event_id.0);
330        let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
331            Some(traffic) => traffic,
332            None => {
333                return HttpConnectorFuture::ready(Err(ConnectorError::other(
334                    format!("no data for event {}. request: {:?}", event_id.0, request).into(),
335                    None,
336                )));
337            }
338        };
339
340        let _initial_request = events.pop_front().unwrap();
341        let (sender, body) = crate::test_util::body::channel_body();
342        let recording = self.recorded_requests.clone();
343        let recorded_request = tokio::spawn(async move {
344            let mut data_read = vec![];
345            while let Some(data) = crate::test_util::body::next_data_frame(request.body_mut()).await
346            {
347                data_read
348                    .extend_from_slice(data.expect("in memory request should not fail").as_ref())
349            }
350            request
351                .try_into_http1x()
352                .unwrap()
353                .map(|_body| Bytes::from(data_read))
354        });
355        let mut recorded_request = Waitable::Loading(recorded_request);
356        let fut = async move {
357            let resp: Result<_, ConnectorError> = loop {
358                let event = events
359                    .pop_front()
360                    .expect("no events, needed a response event");
361                match event.action {
362                    // to ensure deterministic behavior if the request EOF happens first in the log,
363                    // wait for the request body to be done before returning a response.
364                    Action::Eof {
365                        direction: Direction::Request,
366                        ..
367                    } => {
368                        recorded_request.wait().await;
369                    }
370                    Action::Request { .. } => panic!("invalid"),
371                    Action::Response {
372                        response: Err(error),
373                    } => break Err(ConnectorError::other(error.0.into(), None)),
374                    Action::Response {
375                        response: Ok(response),
376                    } => {
377                        let mut builder = http_1x::Response::builder().status(response.status);
378                        for (name, values) in response.headers {
379                            for value in values {
380                                builder = builder.header(&name, &value);
381                            }
382                        }
383                        tokio::spawn(async move {
384                            replay_body(events, sender).await;
385                            // insert the finalized body into
386                        });
387                        break Ok(HttpResponse::try_from(
388                            builder.body(body).expect("valid builder"),
389                        )
390                        .unwrap());
391                    }
392
393                    Action::Data {
394                        direction: Direction::Request,
395                        data: _data,
396                    } => {
397                        tracing::info!("get request data");
398                    }
399                    Action::Eof {
400                        direction: Direction::Response,
401                        ..
402                    } => panic!("got eof before response"),
403
404                    Action::Data {
405                        data: _,
406                        direction: Direction::Response,
407                    } => panic!("got response data before response"),
408                }
409            };
410            recording.lock().unwrap().insert(event_id, recorded_request);
411            resp
412        };
413        HttpConnectorFuture::new(fut)
414    }
415}
416
417impl HttpClient for ReplayingClient {
418    fn http_connector(
419        &self,
420        _: &HttpConnectorSettings,
421        _: &RuntimeComponents,
422    ) -> SharedHttpConnector {
423        self.clone().into_shared()
424    }
425
426    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
427        Some(ConnectorMetadata::new("replaying-client", None))
428    }
429}