aws_smithy_http_client/test_util/
dvr.rs1use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
13use aws_smithy_runtime_api::http::Headers;
14use aws_smithy_types::base64;
15use bytes::Bytes;
16use indexmap::IndexMap;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::path::Path;
20
21mod record;
22mod replay;
23
24pub use record::RecordingClient;
25pub use replay::ReplayingClient;
26
27#[derive(Debug, Serialize, Deserialize)]
31pub struct NetworkTraffic {
32    events: Vec<Event>,
33    docs: Option<String>,
34    version: Version,
35}
36
37impl NetworkTraffic {
38    pub fn events(&self) -> &Vec<Event> {
40        &self.events
41    }
42
43    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn std::error::Error>> {
45        let contents = std::fs::read_to_string(path)?;
46        Ok(serde_json::from_str(&contents)?)
47    }
48
49    pub fn write_to_file(&self, path: impl AsRef<Path>) -> Result<(), Box<dyn std::error::Error>> {
51        let serialized = serde_json::to_string_pretty(&self)?;
52        Ok(std::fs::write(path, serialized)?)
53    }
54
55    pub fn correct_content_lengths(&mut self) {
57        let mut content_lengths: HashMap<(ConnectionId, Direction), usize> = HashMap::new();
58        for event in &self.events {
59            if let Action::Data { data, direction } = &event.action {
60                let entry = content_lengths.entry((event.connection_id, *direction));
61                *entry.or_default() += data.copy_to_vec().len();
62            }
63        }
64        for event in &mut self.events {
65            let (headers, direction) = match &mut event.action {
66                Action::Request {
67                    request: Request { headers, .. },
68                } => (headers, Direction::Request),
69                Action::Response {
70                    response: Ok(Response { headers, .. }),
71                } => (headers, Direction::Response),
72                _ => continue,
73            };
74            let Some(computed_content_length) =
75                content_lengths.get(&(event.connection_id, direction))
76            else {
77                continue;
78            };
79            if headers.contains_key("content-length") {
80                headers.insert(
81                    "content-length".to_string(),
82                    vec![computed_content_length.to_string()],
83                );
84            }
85        }
86    }
87}
88
89#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
91pub enum Version {
92    V0,
94}
95
96#[derive(Copy, Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
98pub struct ConnectionId(usize);
99
100#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
105pub struct Event {
106    connection_id: ConnectionId,
107    action: Action,
108}
109
110#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
115pub struct Request {
116    uri: String,
117    headers: IndexMap<String, Vec<String>>,
118    method: String,
119}
120
121#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
126pub struct Response {
127    status: u16,
128    headers: IndexMap<String, Vec<String>>,
129}
130
#[cfg(feature = "legacy-test-util")]
impl From<&Request> for http_02x::Request<()> {
    /// Rebuilds a body-less `http` 0.2 request from a recorded [`Request`].
    ///
    /// # Panics
    /// Panics if the recorded URI, method, or any header name/value is rejected by the
    /// `http` request builder.
    fn from(request: &Request) -> Self {
        let with_uri = http_02x::Request::builder().uri(request.uri.as_str());
        // Flatten `name -> [values]` into `(name, value)` pairs so every value becomes its own
        // header entry, in the recorded order.
        let with_headers = request
            .headers
            .iter()
            .flat_map(|(name, values)| values.iter().map(move |value| (name, value)))
            .fold(with_uri, |builder, (name, value)| builder.header(name, value));
        with_headers
            .method(request.method.as_str())
            .body(())
            .unwrap()
    }
}
143
144impl From<&Request> for http_1x::Request<()> {
145    fn from(request: &Request) -> Self {
146        let mut builder = http_1x::Request::builder().uri(request.uri.as_str());
147        for (k, values) in request.headers.iter() {
148            for v in values {
149                builder = builder.header(k, v);
150            }
151        }
152        builder.method(request.method.as_str()).body(()).unwrap()
153    }
154}
155
156impl<'a> From<&'a HttpRequest> for Request {
157    fn from(req: &'a HttpRequest) -> Self {
158        let uri = req.uri().to_string();
159        let headers = headers_to_map_http(req.headers());
160        let method = req.method().to_string();
161        Self {
162            uri,
163            headers,
164            method,
165        }
166    }
167}
168
169fn headers_to_map_http(headers: &Headers) -> IndexMap<String, Vec<String>> {
170    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
171    for (header_name, header_value) in headers.iter() {
172        let entry = out.entry(header_name.to_string()).or_default();
173        entry.push(header_value.to_string());
174    }
175    out
176}
177
178fn headers_to_map(headers: &Headers) -> IndexMap<String, Vec<String>> {
179    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
180    for (header_name, header_value) in headers.iter() {
181        let entry = out.entry(header_name.to_string()).or_default();
182        entry.push(
183            std::str::from_utf8(header_value.as_ref())
184                .unwrap()
185                .to_string(),
186        );
187    }
188    out
189}
190
/// Converts an `http` 0.2 [`HeaderMap`](http_02x::HeaderMap) into the serializable multimap
/// form, grouping repeated header names under one key in first-seen order.
///
/// HTTP allows header values that are not valid UTF-8 (obs-text). Such bytes are replaced
/// with U+FFFD instead of panicking, so recording real-world traffic cannot abort on them.
#[cfg(feature = "legacy-test-util")]
fn headers_to_map_02x(headers: &http_02x::HeaderMap) -> IndexMap<String, Vec<String>> {
    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
    for (header_name, header_value) in headers.iter() {
        out.entry(header_name.to_string())
            .or_default()
            .push(String::from_utf8_lossy(header_value.as_bytes()).into_owned());
    }
    out
}
204
#[cfg(feature = "legacy-test-util")]
impl<'a, B> From<&'a http_02x::Response<B>> for Response {
    /// Captures the status code and headers of an `http` 0.2 response; the body is ignored.
    fn from(resp: &'a http_02x::Response<B>) -> Self {
        Response {
            status: resp.status().as_u16(),
            headers: headers_to_map_02x(resp.headers()),
        }
    }
}
213
214fn headers_to_map_1x(headers: &http_1x::HeaderMap) -> IndexMap<String, Vec<String>> {
215    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
216    for (header_name, header_value) in headers.iter() {
217        let entry = out.entry(header_name.to_string()).or_default();
218        entry.push(
219            std::str::from_utf8(header_value.as_ref())
220                .unwrap()
221                .to_string(),
222        );
223    }
224    out
225}
226
227impl<'a, B> From<&'a http_1x::Response<B>> for Response {
228    fn from(resp: &'a http_1x::Response<B>) -> Self {
229        let status = resp.status().as_u16();
230        let headers = headers_to_map_1x(resp.headers());
231        Self { status, headers }
232    }
233}
234
235impl From<&HttpResponse> for Response {
236    fn from(resp: &HttpResponse) -> Self {
237        Self {
238            status: resp.status().into(),
239            headers: headers_to_map(resp.headers()),
240        }
241    }
242}
243
244#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
246pub struct Error(String);
247
248#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
250#[non_exhaustive]
251pub enum Action {
252    Request {
254        request: Request,
256    },
257
258    Response {
260        response: Result<Response, Error>,
262    },
263
264    Data {
266        data: BodyData,
268        direction: Direction,
270    },
271
272    Eof {
274        ok: bool,
276        direction: Direction,
278    },
279}
280
281#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
285pub enum Direction {
286    Request,
288    Response,
290}
291
292impl Direction {
293    pub fn opposite(self) -> Self {
295        match self {
296            Direction::Request => Direction::Response,
297            Direction::Response => Direction::Request,
298        }
299    }
300}
301
302#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
307#[non_exhaustive]
308pub enum BodyData {
309    Utf8(String),
311
312    Base64(String),
314}
315
316impl BodyData {
317    pub fn into_bytes(self) -> Vec<u8> {
319        match self {
320            BodyData::Utf8(string) => string.into_bytes(),
321            BodyData::Base64(string) => base64::decode(string).unwrap(),
322        }
323    }
324
325    pub fn copy_to_vec(&self) -> Vec<u8> {
327        match self {
328            BodyData::Utf8(string) => string.as_bytes().into(),
329            BodyData::Base64(string) => base64::decode(string).unwrap(),
330        }
331    }
332}
333
334impl From<Bytes> for BodyData {
335    fn from(data: Bytes) -> Self {
336        match std::str::from_utf8(data.as_ref()) {
337            Ok(string) => BodyData::Utf8(string.to_string()),
338            Err(_) => BodyData::Base64(base64::encode(data)),
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    use std::error::Error;
    use std::fs;

    use aws_smithy_runtime_api::client::http::HttpConnector;
    use aws_smithy_runtime_api::client::http::SharedHttpConnector;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;

    /// Loads the `example.com` fixture and runs content-length correction over it.
    fn corrected_example_traffic() -> Result<NetworkTraffic, Box<dyn Error>> {
        let raw = fs::read_to_string("test-data/example.com.json")?;
        let mut traffic: NetworkTraffic = serde_json::from_str(&raw)?;
        traffic.correct_content_lengths();
        Ok(traffic)
    }

    #[tokio::test]
    async fn correctly_fixes_content_lengths() -> Result<(), Box<dyn Error>> {
        let traffic = corrected_example_traffic()?;

        // The recorded request had no content-length header, so correction must not add one.
        match &traffic.events[0].action {
            Action::Request {
                request: Request { headers, .. },
            } => assert_eq!(headers.get("content-length"), None),
            _ => panic!("unexpected event"),
        }

        // The recorded response's content-length must be rewritten to the real body size.
        let response_action = &traffic.events[3].action;
        let Action::Response {
            response: Ok(Response { headers, .. }),
        } = response_action
        else {
            panic!("unexpected event: {:?}", response_action);
        };
        let expected = "hello from example.com".len().to_string();
        assert_eq!(headers.get("content-length"), Some(&vec![expected]));
        Ok(())
    }

    #[cfg(feature = "legacy-test-util")]
    #[tokio::test]
    async fn turtles_all_the_way_down() -> Result<(), Box<dyn Error>> {
        let traffic = corrected_example_traffic()?;
        // Record what a replaying client produces; the recording should match the input.
        let replayer = ReplayingClient::new(traffic.events.clone());
        let recorder = RecordingClient::new(SharedHttpConnector::new(replayer.clone()));

        let request = http_02x::Request::post("https://www.example.com")
            .body(SdkBody::from("hello world"))
            .unwrap();
        let mut response = recorder
            .call(request.try_into().unwrap())
            .await
            .expect("ok");
        let body = std::mem::replace(response.body_mut(), SdkBody::taken());
        let collected = ByteStream::new(body).collect().await.unwrap().into_bytes();
        assert_eq!(
            String::from_utf8(collected.to_vec()).unwrap(),
            "hello from example.com"
        );

        assert_eq!(recorder.events().as_slice(), traffic.events.as_slice());

        let captured = replayer.take_requests().await;
        let first = &captured[0];
        assert_eq!(
            first.uri(),
            &http_02x::Uri::from_static("https://www.example.com")
        );
        assert_eq!(
            first.body(),
            &Bytes::from_static("hello world".as_bytes())
        );
        Ok(())
    }
}