use crate::http::HttpChecksum;
use aws_smithy_types::body::SdkBody;
use bytes::Bytes;
use http::{HeaderMap, HeaderValue};
use http_body::SizeHint;
use pin_project_lite::pin_project;
use std::fmt::Display;
use std::pin::Pin;
use std::task::{Context, Poll};
pin_project! {
    /// A body wrapper that feeds every data chunk read from `inner` into a running
    /// checksum and, once `inner` reports end-of-stream, compares the finalized
    /// checksum against `precalculated_checksum`.
    pub struct ChecksumBody<InnerBody> {
        // The wrapped body. Marked `#[pin]` so it can be polled through the
        // projection of `Pin<&mut Self>`.
        #[pin]
        inner: InnerBody,
        // Running checksum. `Some` while data is still being read. It is taken
        // (left as `None`) when the inner body signals end-of-stream and the
        // final comparison is performed, so `None` also means "validation done".
        checksum: Option<Box<dyn HttpChecksum>>,
        // The expected checksum bytes that the computed checksum must equal.
        precalculated_checksum: Bytes,
    }
}
impl ChecksumBody<SdkBody> {
    /// Wraps `body` so that each chunk read from it is fed into `checksum`.
    ///
    /// When the wrapped body reports end-of-stream, the finalized checksum is
    /// compared byte-for-byte with `precalculated_checksum`. A mismatch is
    /// surfaced as an [`Error::ChecksumMismatch`] from that final poll.
    pub fn new(
        body: SdkBody,
        checksum: Box<dyn HttpChecksum>,
        precalculated_checksum: Bytes,
    ) -> Self {
        let checksum = Some(checksum);
        Self {
            checksum,
            precalculated_checksum,
            inner: body,
        }
    }

    /// Polls the inner body for its next chunk, updating or verifying the checksum.
    ///
    /// - A data chunk is fed to the running checksum and then passed through unchanged.
    /// - An error from the inner body is passed through unchanged.
    /// - At end-of-stream, the checksum is taken out, so the comparison runs only once.
    ///   It is then finalized and compared with the expected bytes. A mismatch
    ///   yields one final `Err`; a match (or an already-validated body) yields `None`.
    fn poll_inner(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, aws_smithy_types::body::Error>>> {
        use http_body::Body;

        let projected = self.project();
        let running_checksum = projected.checksum;

        // Nothing to do until the inner body has produced something.
        let next = match projected.inner.poll_data(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(next) => next,
        };

        match next {
            Some(Err(err)) => Poll::Ready(Some(Err(err))),
            Some(Ok(chunk)) => {
                tracing::trace!(
                    "reading {} bytes from the body and updating the checksum calculation",
                    chunk.len()
                );
                match running_checksum.as_mut() {
                    Some(hasher) => hasher.update(&chunk),
                    None => unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled."),
                }
                Poll::Ready(Some(Ok(chunk)))
            }
            None => {
                tracing::trace!("finished reading from body, calculating final checksum");
                let hasher = match running_checksum.take() {
                    Some(hasher) => hasher,
                    // Validation already happened on an earlier poll; just report the end.
                    None => return Poll::Ready(None),
                };
                let computed = hasher.finalize();
                if *projected.precalculated_checksum != computed {
                    Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
                        expected: projected.precalculated_checksum.clone(),
                        actual: computed,
                    }))))
                } else {
                    Poll::Ready(None)
                }
            }
        }
    }
}
/// Errors raised while validating a body against an expected checksum.
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Error {
    /// The checksum computed over the body's bytes did not equal the expected one.
    ChecksumMismatch {
        /// The checksum that was supplied up front as the expected value.
        expected: Bytes,
        /// The checksum actually computed from the body's data.
        actual: Bytes,
    },
}

impl Display for Error {
    /// Renders both checksums as lowercase hex strings.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ChecksumMismatch { expected, actual } => {
                let expected_hex = hex::encode(expected);
                let actual_hex = hex::encode(actual);
                write!(
                    f,
                    "body checksum mismatch. expected body checksum to be {} but it was {}",
                    expected_hex, actual_hex
                )
            }
        }
    }
}

impl std::error::Error for Error {}
impl http_body::Body for ChecksumBody<SdkBody> {
    type Data = Bytes;
    type Error = aws_smithy_types::body::Error;

    /// The stream only counts as finished once the checksum has been taken out,
    /// which happens after the inner body is exhausted and the comparison is done.
    /// This forces callers to keep polling until validation has happened.
    fn is_end_stream(&self) -> bool {
        self.checksum.is_none()
    }

    /// Size information comes straight from the wrapped body; this wrapper does
    /// not add or remove data bytes.
    fn size_hint(&self) -> SizeHint {
        self.inner.size_hint()
    }

    /// Yields the wrapped body's data while maintaining and finally checking the checksum.
    fn poll_data(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        Self::poll_inner(self, ctx)
    }

    /// Trailers are forwarded untouched from the wrapped body.
    fn poll_trailers(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
        let projection = self.project();
        projection.inner.poll_trailers(ctx)
    }
}
#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::ChecksumAlgorithm;
    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::io::Read;

    /// Computes the CRC32 of `input` and returns it as big-endian bytes.
    fn calculate_crc32_checksum(input: &str) -> Bytes {
        let digest = crc32fast::hash(input.as_bytes()).to_be_bytes();
        Bytes::copy_from_slice(&digest)
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let input_text = "This is some test text for an SdkBody";
        let actual_checksum = calculate_crc32_checksum(input_text);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let crc32 = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(
            SdkBody::from(input_text),
            crc32,
            non_matching_checksum.clone(),
        );

        // Drain the body, ignoring data chunks until the validation error appears.
        let err = loop {
            match body.data().await {
                Some(Ok(_)) => continue,
                Some(Err(err)) => break err,
                None => panic!("didn't hit expected error condition"),
            }
        };

        match err.downcast_ref::<Error>().unwrap() {
            Error::ChecksumMismatch { expected, actual } => {
                assert_eq!(expected, &non_matching_checksum);
                assert_eq!(actual, &actual_checksum);
            }
        }
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let input_text = "This is some test text for an SdkBody";
        let expected_checksum = calculate_crc32_checksum(input_text);
        let crc32 = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), crc32, expected_checksum);

        // Reassemble every chunk the validating body yields.
        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);
    }
}