aws_smithy_http_client/test_util/
dvr.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Extremely Experimental Test Connection
7//!
8//! Warning: Extremely experimental, API likely to change.
9//!
10//! DVR is an extremely experimental record & replay framework that supports multi-frame HTTP request / response traffic.
11
12use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
13use aws_smithy_runtime_api::http::Headers;
14use aws_smithy_types::base64;
15use bytes::Bytes;
16use indexmap::IndexMap;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::path::Path;
20
21mod record;
22mod replay;
23
24pub use record::RecordingClient;
25pub use replay::ReplayingClient;
26
27/// A complete traffic recording
28///
29/// A traffic recording can be replayed with [`RecordingClient`].
30#[derive(Debug, Serialize, Deserialize)]
31pub struct NetworkTraffic {
32    events: Vec<Event>,
33    docs: Option<String>,
34    version: Version,
35}
36
37impl NetworkTraffic {
38    /// Network events
39    pub fn events(&self) -> &Vec<Event> {
40        &self.events
41    }
42
43    /// Create a NetworkTraffic instance from a file
44    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn std::error::Error>> {
45        let contents = std::fs::read_to_string(path)?;
46        Ok(serde_json::from_str(&contents)?)
47    }
48
49    /// Create a NetworkTraffic instance from a file
50    pub fn write_to_file(&self, path: impl AsRef<Path>) -> Result<(), Box<dyn std::error::Error>> {
51        let serialized = serde_json::to_string_pretty(&self)?;
52        Ok(std::fs::write(path, serialized)?)
53    }
54
55    /// Update the network traffic with all `content-length` fields fixed to match the contents
56    pub fn correct_content_lengths(&mut self) {
57        let mut content_lengths: HashMap<(ConnectionId, Direction), usize> = HashMap::new();
58        for event in &self.events {
59            if let Action::Data { data, direction } = &event.action {
60                let entry = content_lengths.entry((event.connection_id, *direction));
61                *entry.or_default() += data.copy_to_vec().len();
62            }
63        }
64        for event in &mut self.events {
65            let (headers, direction) = match &mut event.action {
66                Action::Request {
67                    request: Request { headers, .. },
68                } => (headers, Direction::Request),
69                Action::Response {
70                    response: Ok(Response { headers, .. }),
71                } => (headers, Direction::Response),
72                _ => continue,
73            };
74            let Some(computed_content_length) =
75                content_lengths.get(&(event.connection_id, direction))
76            else {
77                continue;
78            };
79            if headers.contains_key("content-length") {
80                headers.insert(
81                    "content-length".to_string(),
82                    vec![computed_content_length.to_string()],
83                );
84            }
85        }
86    }
87}
88
89/// Serialization version of DVR data
90#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
91pub enum Version {
92    /// Initial network traffic version
93    V0,
94}
95
96/// A network traffic recording may contain multiple different connections occurring simultaneously
97#[derive(Copy, Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
98pub struct ConnectionId(usize);
99
100/// A network event
101///
102/// Network events consist of a connection identifier and an action. An event is sufficient to
103/// reproduce traffic later during replay
104#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
105pub struct Event {
106    connection_id: ConnectionId,
107    action: Action,
108}
109
110/// An initial HTTP request, roughly equivalent to `http::Request<()>`
111///
112/// The initial request phase of an HTTP request. The body will be
113/// sent later as a separate action.
114#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
115pub struct Request {
116    uri: String,
117    headers: IndexMap<String, Vec<String>>,
118    method: String,
119}
120
121/// An initial HTTP response roughly equivalent to `http::Response<()>`
122///
123/// The initial response phase of an HTTP request. The body will be
124/// sent later as a separate action.
125#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
126pub struct Response {
127    status: u16,
128    headers: IndexMap<String, Vec<String>>,
129}
130
#[cfg(feature = "legacy-test-util")]
impl From<&Request> for http_02x::Request<()> {
    /// Rebuild an `http` 0.2 request head (with an empty body) from a recorded [`Request`].
    ///
    /// # Panics
    /// `From` cannot report errors, so this panics if the recorded URI, method, or any
    /// header name/value is not valid HTTP.
    fn from(request: &Request) -> Self {
        let mut builder = http_02x::Request::builder().uri(request.uri.as_str());
        for (name, values) in &request.headers {
            for value in values {
                builder = builder.header(name, value);
            }
        }
        // The builder defers all validation errors until `body` is called.
        builder
            .method(request.method.as_str())
            .body(())
            .expect("recorded request should contain a valid URI, method, and headers")
    }
}
143
144impl From<&Request> for http_1x::Request<()> {
145    fn from(request: &Request) -> Self {
146        let mut builder = http_1x::Request::builder().uri(request.uri.as_str());
147        for (k, values) in request.headers.iter() {
148            for v in values {
149                builder = builder.header(k, v);
150            }
151        }
152        builder.method(request.method.as_str()).body(()).unwrap()
153    }
154}
155
156impl<'a> From<&'a HttpRequest> for Request {
157    fn from(req: &'a HttpRequest) -> Self {
158        let uri = req.uri().to_string();
159        let headers = headers_to_map_http(req.headers());
160        let method = req.method().to_string();
161        Self {
162            uri,
163            headers,
164            method,
165        }
166    }
167}
168
169fn headers_to_map_http(headers: &Headers) -> IndexMap<String, Vec<String>> {
170    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
171    for (header_name, header_value) in headers.iter() {
172        let entry = out.entry(header_name.to_string()).or_default();
173        entry.push(header_value.to_string());
174    }
175    out
176}
177
178fn headers_to_map(headers: &Headers) -> IndexMap<String, Vec<String>> {
179    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
180    for (header_name, header_value) in headers.iter() {
181        let entry = out.entry(header_name.to_string()).or_default();
182        entry.push(
183            std::str::from_utf8(header_value.as_ref())
184                .unwrap()
185                .to_string(),
186        );
187    }
188    out
189}
190
/// Group an `http` 0.2 `HeaderMap` into an ordered map from header name to all of its values.
///
/// Header values are raw bytes and are not guaranteed to be UTF-8. Invalid sequences are
/// replaced with U+FFFD instead of panicking in the middle of a recording.
#[cfg(feature = "legacy-test-util")]
fn headers_to_map_02x(headers: &http_02x::HeaderMap) -> IndexMap<String, Vec<String>> {
    let mut out: IndexMap<String, Vec<String>> = IndexMap::new();
    for (header_name, header_value) in headers.iter() {
        out.entry(header_name.to_string())
            .or_default()
            .push(String::from_utf8_lossy(header_value.as_bytes()).into_owned());
    }
    out
}
204
#[cfg(feature = "legacy-test-util")]
impl<'a, B> From<&'a http_02x::Response<B>> for Response {
    /// Capture the status and headers of an `http` 0.2 response, ignoring its body.
    fn from(resp: &'a http_02x::Response<B>) -> Self {
        Self {
            status: resp.status().as_u16(),
            headers: headers_to_map_02x(resp.headers()),
        }
    }
}
213
214fn headers_to_map_1x(headers: &http_1x::HeaderMap) -> IndexMap<String, Vec<String>> {
215    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
216    for (header_name, header_value) in headers.iter() {
217        let entry = out.entry(header_name.to_string()).or_default();
218        entry.push(
219            std::str::from_utf8(header_value.as_ref())
220                .unwrap()
221                .to_string(),
222        );
223    }
224    out
225}
226
227impl<'a, B> From<&'a http_1x::Response<B>> for Response {
228    fn from(resp: &'a http_1x::Response<B>) -> Self {
229        let status = resp.status().as_u16();
230        let headers = headers_to_map_1x(resp.headers());
231        Self { status, headers }
232    }
233}
234
235impl From<&HttpResponse> for Response {
236    fn from(resp: &HttpResponse) -> Self {
237        Self {
238            status: resp.status().into(),
239            headers: headers_to_map(resp.headers()),
240        }
241    }
242}
243
244/// Error response wrapper
245#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
246pub struct Error(String);
247
248/// Network Action
249#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
250#[non_exhaustive]
251pub enum Action {
252    /// Initial HTTP Request
253    Request {
254        /// HTTP Request headers, method, and URI
255        request: Request,
256    },
257
258    /// Initial HTTP response or failure
259    Response {
260        /// HTTP response or failure
261        response: Result<Response, Error>,
262    },
263
264    /// Data segment
265    Data {
266        /// Body Data
267        data: BodyData,
268        /// Direction: request vs. response
269        direction: Direction,
270    },
271
272    /// End of data
273    Eof {
274        /// Succesful vs. failed termination
275        ok: bool,
276        /// Direction: request vs. response
277        direction: Direction,
278    },
279}
280
281/// Event direction
282///
283/// During replay, this is used to replay data in the right direction
284#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
285pub enum Direction {
286    /// Request phase
287    Request,
288    /// Response phase
289    Response,
290}
291
292impl Direction {
293    /// The opposite of a given direction
294    pub fn opposite(self) -> Self {
295        match self {
296            Direction::Request => Direction::Response,
297            Direction::Response => Direction::Request,
298        }
299    }
300}
301
302/// HTTP Body Data Abstraction
303///
304/// When the data is a UTF-8 encoded string, it will be serialized as a string for readability.
305/// Otherwise, it will be base64 encoded.
306#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
307#[non_exhaustive]
308pub enum BodyData {
309    /// UTF-8 encoded data
310    Utf8(String),
311
312    /// Base64 encoded binary data
313    Base64(String),
314}
315
316impl BodyData {
317    /// Convert [`BodyData`] into Bytes.
318    pub fn into_bytes(self) -> Vec<u8> {
319        match self {
320            BodyData::Utf8(string) => string.into_bytes(),
321            BodyData::Base64(string) => base64::decode(string).unwrap(),
322        }
323    }
324
325    /// Copy [`BodyData`] into a `Vec<u8>`.
326    pub fn copy_to_vec(&self) -> Vec<u8> {
327        match self {
328            BodyData::Utf8(string) => string.as_bytes().into(),
329            BodyData::Base64(string) => base64::decode(string).unwrap(),
330        }
331    }
332}
333
334impl From<Bytes> for BodyData {
335    fn from(data: Bytes) -> Self {
336        match std::str::from_utf8(data.as_ref()) {
337            Ok(string) => BodyData::Utf8(string.to_string()),
338            Err(_) => BodyData::Base64(base64::encode(data)),
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;
    use std::fs;

    #[tokio::test]
    async fn correctly_fixes_content_lengths() -> Result<(), Box<dyn Error>> {
        let network_traffic = fs::read_to_string("test-data/example.com.json")?;
        let mut network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;
        network_traffic.correct_content_lengths();
        let Action::Request {
            request: Request { headers, .. },
        } = &network_traffic.events[0].action
        else {
            panic!("unexpected event")
        };
        // content length is not added when it wasn't initially present
        assert_eq!(headers.get("content-length"), None);

        let Action::Response {
            response: Ok(Response { headers, .. }),
        } = &network_traffic.events[3].action
        else {
            panic!("unexpected event: {:?}", network_traffic.events[3].action);
        };
        // an existing content length is rewritten to match the recorded response body
        let expected_length = "hello from example.com".len();
        assert_eq!(
            headers.get("content-length"),
            Some(&vec![expected_length.to_string()])
        );
        Ok(())
    }

    #[cfg(feature = "legacy-test-util")]
    #[tokio::test]
    async fn turtles_all_the_way_down() -> Result<(), Box<dyn Error>> {
        // These are only needed by this feature-gated test; importing them here avoids
        // unused-import warnings when `legacy-test-util` is disabled.
        use aws_smithy_runtime_api::client::http::{HttpConnector, SharedHttpConnector};
        use aws_smithy_types::body::SdkBody;
        use aws_smithy_types::byte_stream::ByteStream;
        use bytes::Bytes;

        // create a replaying connection from a recording, wrap a recording connection around it,
        // make a request, then verify that the same traffic was recorded.
        let network_traffic = fs::read_to_string("test-data/example.com.json")?;
        let mut network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;
        network_traffic.correct_content_lengths();
        let inner = ReplayingClient::new(network_traffic.events.clone());
        let connection = RecordingClient::new(SharedHttpConnector::new(inner.clone()));
        let req = http_02x::Request::post("https://www.example.com")
            .body(SdkBody::from("hello world"))
            .unwrap();
        let mut resp = connection.call(req.try_into().unwrap()).await.expect("ok");
        let body = std::mem::replace(resp.body_mut(), SdkBody::taken());
        let data = ByteStream::new(body).collect().await.unwrap().into_bytes();
        assert_eq!(
            String::from_utf8(data.to_vec()).unwrap(),
            "hello from example.com"
        );
        assert_eq!(
            connection.events().as_slice(),
            network_traffic.events.as_slice()
        );
        let requests = inner.take_requests().await;
        assert_eq!(
            requests[0].uri(),
            &http_02x::Uri::from_static("https://www.example.com")
        );
        assert_eq!(
            requests[0].body(),
            &Bytes::from_static("hello world".as_bytes())
        );
        Ok(())
    }
}