aws_smithy_http_client/test_util/dvr/
record.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use super::{
7    Action, BodyData, ConnectionId, Direction, Error, Event, NetworkTraffic, Request, Response,
8    Version,
9};
10use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
11use aws_smithy_runtime_api::client::http::{
12    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
13};
14use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::body::SdkBody;
18use std::path::Path;
19use std::sync::atomic::{AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex, MutexGuard};
21use std::{fs, io};
22use tokio::task::JoinHandle;
23
24/// Recording client
25///
26/// `RecordingClient` wraps an inner connection and records all traffic, enabling traffic replay.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use aws_smithy_async::rt::sleep::default_async_sleep;
32/// use aws_smithy_runtime::client::http::hyper_014::default_connector;
33/// use aws_smithy_http_client::test_util::dvr::RecordingClient;
34/// use aws_smithy_runtime_api::client::http::HttpConnectorSettingsBuilder;
35/// use aws_sdk_s3::{Client, Config};
36///
37/// #[tokio::test]
38/// async fn test_content_length_enforcement_is_not_applied_to_head_request() {
39///     let settings = HttpConnectorSettingsBuilder::default().build();
40///     let http_client = default_connector(&settings, default_async_sleep()).unwrap();
41///     let http_client = RecordingClient::new(http_client);
42///
43///     // Since we need to send a real request for this,
44///     // you'll need to use your real credentials.
45///     let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
46///     let config = Config::from(&config).to_builder()
47///         .http_client(http_client.clone())
48///         .region(Region::new("us-east-1"))
49///         .build();
50///
51///     let client = Client::from_conf(config);
52///     let _resp = client
53///         .head_object()
54///         .key("some-test-file.txt")
55///         .bucket("your-test-bucket")
56///         .send()
57///         .await
58///         .unwrap();
59///
60///     // If the request you want to record has a body, don't forget to poll
61///     // the body to completion BEFORE calling `dump_to_file`. Otherwise, your
62///     // test json won't include the body.
63///     // let _body = _resp.body.collect().await.unwrap();
64///
65///     // This path is relative to your project or workspace `Cargo.toml` file.
66///     http_client.dump_to_file("tests/data/content-length-enforcement/head-object.json").unwrap();
67/// }
68/// ```
69#[derive(Clone, Debug)]
70pub struct RecordingClient {
71    pub(crate) data: Arc<Mutex<Vec<Event>>>,
72    pub(crate) num_events: Arc<AtomicUsize>,
73    pub(crate) inner: SharedHttpConnector,
74}
75
#[cfg(feature = "legacy-rustls-ring")]
impl RecordingClient {
    /// Builds a recording client around a default HTTPS connector.
    ///
    /// The wrapped connector is built without any timeouts.
    // The hyper 0.14 `HyperConnector` is deprecated. The allow keeps this legacy
    // constructor free of warnings for both the import and the builder call.
    #[allow(deprecated)]
    pub fn https() -> Self {
        use crate::hyper_014::HyperConnector;

        let https_connector = HyperConnector::builder().build_https();
        Self::new(https_connector)
    }
}
90
91impl RecordingClient {
92    /// Create a new recording connection from a connection
93    pub fn new(underlying_connector: impl HttpConnector + 'static) -> Self {
94        Self {
95            data: Default::default(),
96            num_events: Arc::new(AtomicUsize::new(0)),
97            inner: underlying_connector.into_shared(),
98        }
99    }
100
101    /// Return the traffic recorded by this connection
102    pub fn events(&self) -> MutexGuard<'_, Vec<Event>> {
103        self.data.lock().unwrap()
104    }
105
106    /// NetworkTraffic struct suitable for serialization
107    pub fn network_traffic(&self) -> NetworkTraffic {
108        NetworkTraffic {
109            events: self.events().clone(),
110            docs: Some("todo docs".into()),
111            version: Version::V0,
112        }
113    }
114
115    /// Dump the network traffic to a file
116    pub fn dump_to_file(&self, path: impl AsRef<Path>) -> Result<(), io::Error> {
117        fs::write(
118            path,
119            serde_json::to_string(&self.network_traffic()).unwrap(),
120        )
121    }
122
123    fn next_id(&self) -> ConnectionId {
124        ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
125    }
126}
127
128fn record_body(
129    body: &mut SdkBody,
130    event_id: ConnectionId,
131    direction: Direction,
132    event_bus: Arc<Mutex<Vec<Event>>>,
133) -> JoinHandle<()> {
134    let (sender, output_body) = crate::test_util::body::channel_body();
135    let real_body = std::mem::replace(body, output_body);
136    tokio::spawn(async move {
137        let mut real_body = real_body;
138        let mut sender = sender;
139        loop {
140            let data = crate::test_util::body::next_data_frame(&mut real_body).await;
141            match data {
142                Some(Ok(data)) => {
143                    event_bus.lock().unwrap().push(Event {
144                        connection_id: event_id,
145                        action: Action::Data {
146                            data: BodyData::from(data.clone()),
147                            direction,
148                        },
149                    });
150                    // This happens if the real connection is closed during recording.
151                    // Need to think more carefully if this is the correct thing to log in this
152                    // case.
153                    if sender.send_data(data).await.is_err() {
154                        event_bus.lock().unwrap().push(Event {
155                            connection_id: event_id,
156                            action: Action::Eof {
157                                direction: direction.opposite(),
158                                ok: false,
159                            },
160                        })
161                    };
162                }
163                None => {
164                    event_bus.lock().unwrap().push(Event {
165                        connection_id: event_id,
166                        action: Action::Eof {
167                            ok: true,
168                            direction,
169                        },
170                    });
171                    drop(sender);
172                    break;
173                }
174                Some(Err(_err)) => {
175                    event_bus.lock().unwrap().push(Event {
176                        connection_id: event_id,
177                        action: Action::Eof {
178                            ok: false,
179                            direction,
180                        },
181                    });
182                    sender.abort();
183                    break;
184                }
185            }
186        }
187    })
188}
189
190impl HttpConnector for RecordingClient {
191    fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
192        let event_id = self.next_id();
193        // A request has three phases:
194        // 1. A "Request" phase. This is initial HTTP request, headers, & URI
195        // 2. A body phase. This may contain multiple data segments.
196        // 3. A finalization phase. An EOF of some sort is sent on the body to indicate that
197        // the channel should be closed.
198
199        // Phase 1: the initial http request
200        self.data.lock().unwrap().push(Event {
201            connection_id: event_id,
202            action: Action::Request {
203                request: Request::from(&request),
204            },
205        });
206
207        // Phase 2: Swap out the real request body for one that will log all traffic that passes
208        // through it
209        // This will also handle phase three when the request body runs out of data.
210        record_body(
211            request.body_mut(),
212            event_id,
213            Direction::Request,
214            self.data.clone(),
215        );
216        let events = self.data.clone();
217        // create a channel we'll use to stream the data while reading it
218        let resp_fut = self.inner.call(request);
219        let fut = async move {
220            let resp = resp_fut.await;
221            match resp {
222                Ok(mut resp) => {
223                    // push the initial response event
224                    events.lock().unwrap().push(Event {
225                        connection_id: event_id,
226                        action: Action::Response {
227                            response: Ok(Response::from(&resp)),
228                        },
229                    });
230
231                    // instrument the body and record traffic
232                    record_body(resp.body_mut(), event_id, Direction::Response, events);
233                    Ok(resp)
234                }
235                Err(e) => {
236                    events.lock().unwrap().push(Event {
237                        connection_id: event_id,
238                        action: Action::Response {
239                            response: Err(Error(format!("{}", &e))),
240                        },
241                    });
242                    Err(e)
243                }
244            }
245        };
246        HttpConnectorFuture::new(fut)
247    }
248}
249
250impl HttpClient for RecordingClient {
251    fn http_connector(
252        &self,
253        _: &HttpConnectorSettings,
254        _: &RuntimeComponents,
255    ) -> SharedHttpConnector {
256        self.clone().into_shared()
257    }
258
259    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
260        Some(ConnectorMetadata::new("recording-client", None))
261    }
262}