use super::{Action, ConnectionId, Direction, Event, NetworkTraffic};
use crate::client::http::test_util::replay::DEFAULT_RELAXED_HEADERS;
use aws_smithy_protocol_test::MediaType;
use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
use aws_smithy_runtime_api::client::http::{
HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::error::display::DisplayErrorContext;
use bytes::{Bytes, BytesMut};
use std::collections::{HashMap, VecDeque};
use std::error::Error;
use std::fmt;
use std::ops::DerefMut;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
#[derive(Debug)]
enum Waitable<T> {
Loading(JoinHandle<T>),
Value(T),
}
impl<T> Waitable<T> {
async fn take(self) -> T {
match self {
Waitable::Loading(f) => f.await.expect("join failed"),
Waitable::Value(value) => value,
}
}
async fn wait(&mut self) {
match self {
Waitable::Loading(f) => *self = Waitable::Value(f.await.expect("join failed")),
Waitable::Value(_) => {}
}
}
}
#[derive(Clone)]
pub struct ReplayingClient {
live_events: Arc<Mutex<HashMap<ConnectionId, VecDeque<Event>>>>,
verifiable_events: Arc<HashMap<ConnectionId, http_02x::Request<Bytes>>>,
num_events: Arc<AtomicUsize>,
recorded_requests: Arc<Mutex<HashMap<ConnectionId, Waitable<http_02x::Request<Bytes>>>>>,
}
impl fmt::Debug for ReplayingClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("test_util::dvr::ReplayingClient")
}
}
enum HeadersToCheck<'a> {
Include(&'a [&'a str]),
Exclude(Option<&'a [&'a str]>),
}
impl ReplayingClient {
fn next_id(&self) -> ConnectionId {
ConnectionId(self.num_events.fetch_add(1, Ordering::Relaxed))
}
pub async fn full_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
self.validate_body_and_headers(None, media_type).await
}
pub async fn relaxed_validate(self, media_type: &str) -> Result<(), Box<dyn Error>> {
self.validate_body_and_headers_except(DEFAULT_RELAXED_HEADERS, media_type)
.await
}
pub async fn validate(
self,
checked_headers: &[&str],
body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
) -> Result<(), Box<dyn Error>> {
self.validate_base(HeadersToCheck::Include(checked_headers), body_comparer)
.await
}
pub async fn validate_body_and_headers(
self,
checked_headers: Option<&[&str]>,
media_type: &str,
) -> Result<(), Box<dyn Error>> {
let headers_to_check = match checked_headers {
Some(headers) => HeadersToCheck::Include(headers),
None => HeadersToCheck::Exclude(None),
};
self.validate_base(headers_to_check, |b1, b2| {
aws_smithy_protocol_test::validate_body(
b1,
std::str::from_utf8(b2).unwrap(),
MediaType::from(media_type),
)
.map_err(|e| Box::new(e) as _)
})
.await
}
pub async fn validate_body_and_headers_except(
self,
excluded_headers: &[&str],
media_type: &str,
) -> Result<(), Box<dyn Error>> {
self.validate_base(HeadersToCheck::Exclude(Some(excluded_headers)), |b1, b2| {
aws_smithy_protocol_test::validate_body(
b1,
std::str::from_utf8(b2).unwrap(),
MediaType::from(media_type),
)
.map_err(|e| Box::new(e) as _)
})
.await
}
async fn validate_base(
self,
checked_headers: HeadersToCheck<'_>,
body_comparer: impl Fn(&[u8], &[u8]) -> Result<(), Box<dyn Error>>,
) -> Result<(), Box<dyn Error>> {
let mut actual_requests =
std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
for conn_id in 0..self.verifiable_events.len() {
let conn_id = ConnectionId(conn_id);
let expected = self.verifiable_events.get(&conn_id).unwrap();
let actual = actual_requests
.remove(&conn_id)
.ok_or(format!(
"expected connection {:?} but request was never sent",
conn_id
))?
.take()
.await;
body_comparer(expected.body().as_ref(), actual.body().as_ref())?;
let actual: HttpRequest = actual.map(SdkBody::from).try_into()?;
aws_smithy_protocol_test::assert_uris_match(&expected.uri().to_string(), actual.uri());
let expected_headers = expected
.headers()
.keys()
.map(|k| k.as_str())
.filter(|k| match checked_headers {
HeadersToCheck::Include(headers) => headers.contains(k),
HeadersToCheck::Exclude(excluded) => match excluded {
Some(headers) => !headers.contains(k),
None => true,
},
})
.flat_map(|key| {
let _ = expected.headers().get(key)?;
Some((
key,
expected
.headers()
.get_all(key)
.iter()
.map(|h| h.to_str().unwrap())
.collect::<Vec<_>>()
.join(", "),
))
})
.collect::<Vec<_>>();
aws_smithy_protocol_test::validate_headers(actual.headers(), expected_headers)
.map_err(|err| {
format!(
"event {} validation failed with: {}",
conn_id.0,
DisplayErrorContext(&err)
)
})?;
}
Ok(())
}
pub async fn take_requests(self) -> Vec<http_02x::Request<Bytes>> {
let mut recorded_requests =
std::mem::take(self.recorded_requests.lock().unwrap().deref_mut());
let mut out = Vec::with_capacity(recorded_requests.len());
for conn_id in 0..recorded_requests.len() {
out.push(
recorded_requests
.remove(&ConnectionId(conn_id))
.expect("should exist")
.take()
.await,
)
}
out
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
let events: NetworkTraffic =
serde_json::from_str(&std::fs::read_to_string(path.as_ref())?)?;
Ok(Self::new(events.events))
}
pub fn new(events: Vec<Event>) -> Self {
let mut event_map: HashMap<_, VecDeque<_>> = HashMap::new();
for event in events {
let event_buffer = event_map.entry(event.connection_id).or_default();
event_buffer.push_back(event);
}
let verifiable_events = event_map
.iter()
.map(|(id, events)| {
let mut body = BytesMut::new();
for event in events {
if let Action::Data {
direction: Direction::Request,
data,
} = &event.action
{
body.extend_from_slice(&data.copy_to_vec());
}
}
let initial_request = events.iter().next().expect("must have one event");
let request = match &initial_request.action {
Action::Request { request } => {
http_02x::Request::from(request).map(|_| Bytes::from(body))
}
_ => panic!("invalid first event"),
};
(*id, request)
})
.collect();
let verifiable_events = Arc::new(verifiable_events);
ReplayingClient {
live_events: Arc::new(Mutex::new(event_map)),
num_events: Arc::new(AtomicUsize::new(0)),
recorded_requests: Default::default(),
verifiable_events,
}
}
}
async fn replay_body(events: VecDeque<Event>, mut sender: hyper_0_14::body::Sender) {
for event in events {
match event.action {
Action::Request { .. } => panic!(),
Action::Response { .. } => panic!(),
Action::Data {
data,
direction: Direction::Response,
} => {
sender
.send_data(Bytes::from(data.into_bytes()))
.await
.expect("this is in memory traffic that should not fail to send");
}
Action::Data {
data: _data,
direction: Direction::Request,
} => {}
Action::Eof {
direction: Direction::Request,
..
} => {}
Action::Eof {
direction: Direction::Response,
ok: true,
..
} => {
drop(sender);
break;
}
Action::Eof {
direction: Direction::Response,
ok: false,
..
} => {
sender.abort();
break;
}
}
}
}
impl HttpConnector for ReplayingClient {
fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture {
use http_body_04x::Body;
let event_id = self.next_id();
tracing::debug!("received event {}: {request:?}", event_id.0);
let mut events = match self.live_events.lock().unwrap().remove(&event_id) {
Some(traffic) => traffic,
None => {
return HttpConnectorFuture::ready(Err(ConnectorError::other(
format!("no data for event {}. request: {:?}", event_id.0, request).into(),
None,
)));
}
};
let _initial_request = events.pop_front().unwrap();
let (sender, response_body) = hyper_0_14::Body::channel();
let body = SdkBody::from_body_0_4(response_body);
let recording = self.recorded_requests.clone();
let recorded_request = tokio::spawn(async move {
let mut data_read = vec![];
while let Some(data) = request.body_mut().data().await {
data_read
.extend_from_slice(data.expect("in memory request should not fail").as_ref())
}
request
.try_into_http02x()
.unwrap()
.map(|_body| Bytes::from(data_read))
});
let mut recorded_request = Waitable::Loading(recorded_request);
let fut = async move {
let resp: Result<_, ConnectorError> = loop {
let event = events
.pop_front()
.expect("no events, needed a response event");
match event.action {
Action::Eof {
direction: Direction::Request,
..
} => {
recorded_request.wait().await;
}
Action::Request { .. } => panic!("invalid"),
Action::Response {
response: Err(error),
} => break Err(ConnectorError::other(error.0.into(), None)),
Action::Response {
response: Ok(response),
} => {
let mut builder = http_02x::Response::builder().status(response.status);
for (name, values) in response.headers {
for value in values {
builder = builder.header(&name, &value);
}
}
tokio::spawn(async move {
replay_body(events, sender).await;
});
break Ok(HttpResponse::try_from(
builder.body(body).expect("valid builder"),
)
.unwrap());
}
Action::Data {
direction: Direction::Request,
data: _data,
} => {
tracing::info!("get request data");
}
Action::Eof {
direction: Direction::Response,
..
} => panic!("got eof before response"),
Action::Data {
data: _,
direction: Direction::Response,
} => panic!("got response data before response"),
}
};
recording.lock().unwrap().insert(event_id, recorded_request);
resp
};
HttpConnectorFuture::new(fut)
}
}
impl HttpClient for ReplayingClient {
fn http_connector(
&self,
_: &HttpConnectorSettings,
_: &RuntimeComponents,
) -> SharedHttpConnector {
self.clone().into_shared()
}
fn connector_metadata(&self) -> Option<ConnectorMetadata> {
Some(ConnectorMetadata::new("replaying-client", None))
}
}