aws_smithy_http_client/test_util/
dvr.rs1use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
13use aws_smithy_runtime_api::http::Headers;
14use aws_smithy_types::base64;
15use bytes::Bytes;
16use indexmap::IndexMap;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::path::Path;
20
21mod record;
22mod replay;
23
24pub use record::RecordingClient;
25pub use replay::ReplayingClient;
26
/// A complete network traffic recording: a list of [`Event`]s plus metadata.
///
/// Recordings are (de)serialized with serde; see [`NetworkTraffic::from_file`] and
/// [`NetworkTraffic::write_to_file`] for the JSON file helpers.
#[derive(Debug, Serialize, Deserialize)]
pub struct NetworkTraffic {
    /// Recorded events, presumably in the order they were observed (TODO confirm with the recorder).
    events: Vec<Event>,
    /// Optional human-readable notes about the recording; not interpreted by this module.
    docs: Option<String>,
    /// Serialization format version of this recording.
    version: Version,
}
36
37impl NetworkTraffic {
38 pub fn events(&self) -> &Vec<Event> {
40 &self.events
41 }
42
43 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn std::error::Error>> {
45 let contents = std::fs::read_to_string(path)?;
46 Ok(serde_json::from_str(&contents)?)
47 }
48
49 pub fn write_to_file(&self, path: impl AsRef<Path>) -> Result<(), Box<dyn std::error::Error>> {
51 let serialized = serde_json::to_string_pretty(&self)?;
52 Ok(std::fs::write(path, serialized)?)
53 }
54
55 pub fn correct_content_lengths(&mut self) {
57 let mut content_lengths: HashMap<(ConnectionId, Direction), usize> = HashMap::new();
58 for event in &self.events {
59 if let Action::Data { data, direction } = &event.action {
60 let entry = content_lengths.entry((event.connection_id, *direction));
61 *entry.or_default() += data.copy_to_vec().len();
62 }
63 }
64 for event in &mut self.events {
65 let (headers, direction) = match &mut event.action {
66 Action::Request {
67 request: Request { headers, .. },
68 } => (headers, Direction::Request),
69 Action::Response {
70 response: Ok(Response { headers, .. }),
71 } => (headers, Direction::Response),
72 _ => continue,
73 };
74 let Some(computed_content_length) =
75 content_lengths.get(&(event.connection_id, direction))
76 else {
77 continue;
78 };
79 if headers.contains_key("content-length") {
80 headers.insert(
81 "content-length".to_string(),
82 vec![computed_content_length.to_string()],
83 );
84 }
85 }
86 }
87}
88
/// Serialization format version of a [`NetworkTraffic`] recording.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Version {
    /// The initial (and currently only) recording format.
    V0,
}
95
/// Identifies which connection an [`Event`] belongs to.
///
/// Together with [`Direction`], this is the key used to group body data when correcting
/// content lengths.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct ConnectionId(usize);
99
/// A single recorded network event: an [`Action`] tagged with the connection it occurred on.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct Event {
    /// Connection this event belongs to.
    connection_id: ConnectionId,
    /// What happened on that connection.
    action: Action,
}
109
/// The head of a recorded HTTP request (URI, headers and method), without its body.
///
/// Body bytes are recorded separately as [`Action::Data`] events.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Request {
    /// Request URI as a string.
    uri: String,
    /// Header values grouped by header name, in insertion order; repeated headers keep every value.
    headers: IndexMap<String, Vec<String>>,
    /// HTTP method as a string.
    method: String,
}
120
/// The head of a recorded HTTP response (status code and headers), without its body.
///
/// Body bytes are recorded separately as [`Action::Data`] events.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Response {
    /// Numeric HTTP status code.
    status: u16,
    /// Header values grouped by header name, in insertion order; repeated headers keep every value.
    headers: IndexMap<String, Vec<String>>,
}
130
#[cfg(feature = "legacy-test-util")]
impl From<&Request> for http_02x::Request<()> {
    /// Rebuilds an `http` 0.2 request with an empty body from a recorded [`Request`].
    ///
    /// Every recorded header value is appended, so repeated headers are preserved.
    ///
    /// # Panics
    ///
    /// Panics if the recorded URI, method, or any header name/value is rejected by the `http`
    /// crate. `From` cannot report failure, so a malformed recording is treated as a fixture bug.
    fn from(request: &Request) -> Self {
        let mut builder = http_02x::Request::builder().uri(request.uri.as_str());
        for (name, values) in &request.headers {
            for value in values {
                builder = builder.header(name, value);
            }
        }
        builder
            .method(request.method.as_str())
            .body(())
            .expect("recorded request must have a valid URI, method, and headers")
    }
}
143
144impl From<&Request> for http_1x::Request<()> {
145 fn from(request: &Request) -> Self {
146 let mut builder = http_1x::Request::builder().uri(request.uri.as_str());
147 for (k, values) in request.headers.iter() {
148 for v in values {
149 builder = builder.header(k, v);
150 }
151 }
152 builder.method(request.method.as_str()).body(()).unwrap()
153 }
154}
155
156impl<'a> From<&'a HttpRequest> for Request {
157 fn from(req: &'a HttpRequest) -> Self {
158 let uri = req.uri().to_string();
159 let headers = headers_to_map_http(req.headers());
160 let method = req.method().to_string();
161 Self {
162 uri,
163 headers,
164 method,
165 }
166 }
167}
168
169fn headers_to_map_http(headers: &Headers) -> IndexMap<String, Vec<String>> {
170 let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
171 for (header_name, header_value) in headers.iter() {
172 let entry = out.entry(header_name.to_string()).or_default();
173 entry.push(header_value.to_string());
174 }
175 out
176}
177
178fn headers_to_map(headers: &Headers) -> IndexMap<String, Vec<String>> {
179 let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
180 for (header_name, header_value) in headers.iter() {
181 let entry = out.entry(header_name.to_string()).or_default();
182 entry.push(
183 std::str::from_utf8(header_value.as_ref())
184 .unwrap()
185 .to_string(),
186 );
187 }
188 out
189}
190
/// Groups `http` 0.2 headers by name, keeping first-seen name order and every value per name.
///
/// Values are stored as strings. Bytes that are not valid UTF-8 are replaced with U+FFFD
/// rather than causing a panic.
#[cfg(feature = "legacy-test-util")]
fn headers_to_map_02x(headers: &http_02x::HeaderMap) -> IndexMap<String, Vec<String>> {
    let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
    for (header_name, header_value) in headers.iter() {
        // `HeaderValue` may legally hold opaque non-UTF-8 bytes, but the recording format stores
        // values as JSON strings. Decode lossily so an unusual server header cannot abort recording.
        out.entry(header_name.to_string())
            .or_default()
            .push(String::from_utf8_lossy(header_value.as_bytes()).into_owned());
    }
    out
}
204
#[cfg(feature = "legacy-test-util")]
impl<'a, B> From<&'a http_02x::Response<B>> for Response {
    /// Captures the status code and headers of an `http` 0.2 response; the body is ignored.
    fn from(resp: &'a http_02x::Response<B>) -> Self {
        Response {
            status: resp.status().as_u16(),
            headers: headers_to_map_02x(resp.headers()),
        }
    }
}
213
214fn headers_to_map_1x(headers: &http_1x::HeaderMap) -> IndexMap<String, Vec<String>> {
215 let mut out: IndexMap<_, Vec<_>> = IndexMap::new();
216 for (header_name, header_value) in headers.iter() {
217 let entry = out.entry(header_name.to_string()).or_default();
218 entry.push(
219 std::str::from_utf8(header_value.as_ref())
220 .unwrap()
221 .to_string(),
222 );
223 }
224 out
225}
226
227impl<'a, B> From<&'a http_1x::Response<B>> for Response {
228 fn from(resp: &'a http_1x::Response<B>) -> Self {
229 let status = resp.status().as_u16();
230 let headers = headers_to_map_1x(resp.headers());
231 Self { status, headers }
232 }
233}
234
235impl From<&HttpResponse> for Response {
236 fn from(resp: &HttpResponse) -> Self {
237 Self {
238 status: resp.status().into(),
239 headers: headers_to_map(resp.headers()),
240 }
241 }
242}
243
/// A recorded failure, stored as a string, in place of a [`Response`].
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub struct Error(String);
247
/// A single step of recorded network activity on a connection.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Action {
    /// The head of an HTTP request.
    Request {
        /// Request URI, headers and method.
        request: Request,
    },

    /// The head of an HTTP response, or the error recorded instead of one.
    Response {
        /// Response status and headers, or the recorded error.
        response: Result<Response, Error>,
    },

    /// A chunk of body data; all chunks for a connection/direction together form that body.
    Data {
        /// The body bytes for this chunk.
        data: BodyData,
        /// Whether this chunk belongs to the request body or the response body.
        direction: Direction,
    },

    /// End of a body stream.
    Eof {
        /// Whether the stream ended successfully; presumably `false` when it ended with an error (confirm against the recorder).
        ok: bool,
        /// Which body stream ended.
        direction: Direction,
    },
}
280
/// Which half of an HTTP exchange an event belongs to.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum Direction {
    /// The request half (request head and request body).
    Request,
    /// The response half (response head and response body).
    Response,
}
291
292impl Direction {
293 pub fn opposite(self) -> Self {
295 match self {
296 Direction::Request => Direction::Response,
297 Direction::Response => Direction::Request,
298 }
299 }
300}
301
/// Recorded body bytes in a form that can be stored in JSON.
///
/// Data that is valid UTF-8 is kept as text; anything else is stored base64-encoded.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[non_exhaustive]
pub enum BodyData {
    /// Body bytes that were valid UTF-8, stored as text.
    Utf8(String),

    /// Arbitrary body bytes, stored as a base64 string.
    Base64(String),
}
315
316impl BodyData {
317 pub fn into_bytes(self) -> Vec<u8> {
319 match self {
320 BodyData::Utf8(string) => string.into_bytes(),
321 BodyData::Base64(string) => base64::decode(string).unwrap(),
322 }
323 }
324
325 pub fn copy_to_vec(&self) -> Vec<u8> {
327 match self {
328 BodyData::Utf8(string) => string.as_bytes().into(),
329 BodyData::Base64(string) => base64::decode(string).unwrap(),
330 }
331 }
332}
333
334impl From<Bytes> for BodyData {
335 fn from(data: Bytes) -> Self {
336 match std::str::from_utf8(data.as_ref()) {
337 Ok(string) => BodyData::Utf8(string.to_string()),
338 Err(_) => BodyData::Base64(base64::encode(data)),
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::client::http::{HttpConnector, SharedHttpConnector};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::Bytes;
    // Explicit import shadows the module's own `Error` type brought in by `use super::*`.
    use std::error::Error;
    use std::fs;

    /// `correct_content_lengths` should rewrite existing content-length headers to the
    /// recorded body size without adding new ones.
    #[tokio::test]
    async fn correctly_fixes_content_lengths() -> Result<(), Box<dyn Error>> {
        let network_traffic = fs::read_to_string("test-data/example.com.json")?;
        let mut network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;
        network_traffic.correct_content_lengths();
        // Event 0 in the fixture is the request. Correction never adds headers, so it must
        // still have no content-length.
        let Action::Request {
            request: Request { headers, .. },
        } = &network_traffic.events[0].action
        else {
            panic!("unexpected event")
        };
        assert_eq!(headers.get("content-length"), None);

        // Event 3 in the fixture is the response. Its content-length must equal the length of
        // the recorded response body.
        let Action::Response {
            response: Ok(Response { headers, .. }),
        } = &network_traffic.events[3].action
        else {
            panic!("unexpected event: {:?}", network_traffic.events[3].action);
        };
        let expected_length = "hello from example.com".len();
        assert_eq!(
            headers.get("content-length"),
            Some(&vec![expected_length.to_string()])
        );
        Ok(())
    }

    /// Replays the fixture through a `ReplayingClient` wrapped in a `RecordingClient`. Checks
    /// that the re-recorded events match the original recording exactly.
    #[cfg(feature = "legacy-test-util")]
    #[tokio::test]
    async fn turtles_all_the_way_down() -> Result<(), Box<dyn Error>> {
        let network_traffic = fs::read_to_string("test-data/example.com.json")?;
        let mut network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;
        network_traffic.correct_content_lengths();
        let inner = ReplayingClient::new(network_traffic.events.clone());
        let connection = RecordingClient::new(SharedHttpConnector::new(inner.clone()));
        let req = http_02x::Request::post("https://www.example.com")
            .body(SdkBody::from("hello world"))
            .unwrap();
        // Convert the `http` 0.2 request into the orchestrator request type expected by `call`.
        let mut resp = connection.call(req.try_into().unwrap()).await.expect("ok");
        // Move the body out of the response (leaving a taken placeholder) so it can be collected.
        let body = std::mem::replace(resp.body_mut(), SdkBody::taken());
        let data = ByteStream::new(body).collect().await.unwrap().into_bytes();
        assert_eq!(
            String::from_utf8(data.to_vec()).unwrap(),
            "hello from example.com"
        );
        // Recording the replayed traffic must reproduce the source recording.
        assert_eq!(
            connection.events().as_slice(),
            network_traffic.events.as_slice()
        );
        // Presumably the requests the replaying client received; check the first one's URI and body.
        let requests = inner.take_requests().await;
        assert_eq!(
            requests[0].uri(),
            &http_02x::Uri::from_static("https://www.example.com")
        );
        assert_eq!(
            requests[0].body(),
            &Bytes::from_static("hello world".as_bytes())
        );
        Ok(())
    }
}