aws_smithy_checksums/body/
validate.rs1use crate::http::HttpChecksum;
10
11use aws_smithy_types::body::SdkBody;
12
13use bytes::Bytes;
14use pin_project_lite::pin_project;
15
16use std::fmt::Display;
17use std::pin::Pin;
18use std::task::{Context, Poll};
19
20pin_project! {
21 pub struct ChecksumBody<InnerBody> {
24 #[pin]
25 inner: InnerBody,
26 checksum: Option<Box<dyn HttpChecksum>>,
27 precalculated_checksum: Bytes,
28 }
29}
30
31impl ChecksumBody<SdkBody> {
32 pub fn new(
35 body: SdkBody,
36 checksum: Box<dyn HttpChecksum>,
37 precalculated_checksum: Bytes,
38 ) -> Self {
39 Self {
40 inner: body,
41 checksum: Some(checksum),
42 precalculated_checksum,
43 }
44 }
45
46 fn poll_inner(
47 self: Pin<&mut Self>,
48 cx: &mut Context<'_>,
49 ) -> Poll<Option<Result<http_body_1x::Frame<Bytes>, aws_smithy_types::body::Error>>> {
50 use http_body_1x::Body;
51
52 let this = self.project();
53 let checksum = this.checksum;
54
55 match this.inner.poll_frame(cx) {
56 Poll::Ready(Some(Ok(frame))) => {
57 let data = frame.data_ref().expect("Data frame should have data");
58 tracing::trace!(
59 "reading {} bytes from the body and updating the checksum calculation",
60 data.len()
61 );
62 let checksum = match checksum.as_mut() {
63 Some(checksum) => checksum,
64 None => {
65 unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled.");
66 }
67 };
68
69 checksum.update(data);
70 Poll::Ready(Some(Ok(frame)))
71 }
72 Poll::Ready(None) => {
75 tracing::trace!("finished reading from body, calculating final checksum");
76 let checksum = match checksum.take() {
77 Some(checksum) => checksum,
78 None => {
79 return Poll::Ready(None);
82 }
83 };
84
85 let actual_checksum = checksum.finalize();
86 if *this.precalculated_checksum == actual_checksum {
87 Poll::Ready(None)
88 } else {
89 Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
91 expected: this.precalculated_checksum.clone(),
92 actual: actual_checksum,
93 }))))
94 }
95 }
96 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
97 Poll::Pending => Poll::Pending,
98 }
99 }
100}
101
102#[derive(Debug, Eq, PartialEq)]
104#[non_exhaustive]
105pub enum Error {
106 ChecksumMismatch { expected: Bytes, actual: Bytes },
109}
110
111impl Display for Error {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
113 match self {
114 Error::ChecksumMismatch { expected, actual } => write!(
115 f,
116 "body checksum mismatch. expected body checksum to be {} but it was {}",
117 hex::encode(expected),
118 hex::encode(actual)
119 ),
120 }
121 }
122}
123
124impl std::error::Error for Error {}
125
126impl http_body_1x::Body for ChecksumBody<SdkBody> {
127 type Data = Bytes;
128 type Error = aws_smithy_types::body::Error;
129
130 fn poll_frame(
131 self: Pin<&mut Self>,
132 cx: &mut Context<'_>,
133 ) -> Poll<Option<Result<http_body_1x::Frame<Self::Data>, Self::Error>>> {
134 self.poll_inner(cx)
135 }
136}
137
#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::ChecksumAlgorithm;
    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body_util::BodyExt;
    use std::io::Read;

    /// Payload shared by both tests.
    const INPUT_TEXT: &str = "This is some test text for an SdkBody";

    /// Computes the CRC-32 (ISO-HDLC) checksum of `input` as 4 big-endian bytes.
    fn calculate_crc32_checksum(input: &str) -> Bytes {
        // A CRC-32 value fits in 32 bits, so narrow it before serializing.
        let crc =
            crc_fast::checksum(crc_fast::CrcAlgorithm::Crc32IsoHdlc, input.as_bytes()) as u32;
        Bytes::copy_from_slice(&crc.to_be_bytes())
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let actual_checksum = calculate_crc32_checksum(INPUT_TEXT);
        let non_matching_checksum = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let mut body = ChecksumBody::new(
            SdkBody::from(INPUT_TEXT),
            "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl(),
            non_matching_checksum.clone(),
        );

        // Drain the body until the first error. Data frames are ignored, and reaching the end
        // of the stream without an error fails the test.
        let err = loop {
            match body.frame().await {
                Some(Ok(_)) => continue,
                Some(Err(err)) => break err,
                None => panic!("didn't hit expected error condition"),
            }
        };

        match err.downcast_ref::<Error>().unwrap() {
            Error::ChecksumMismatch { expected, actual } => {
                assert_eq!(expected, &non_matching_checksum);
                assert_eq!(actual, &actual_checksum);
            }
        }
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let matching_checksum = calculate_crc32_checksum(INPUT_TEXT);
        let algorithm = "crc32".parse::<ChecksumAlgorithm>().unwrap();
        let mut body =
            ChecksumBody::new(SdkBody::from(INPUT_TEXT), algorithm.into_impl(), matching_checksum);

        let mut collected = SegmentedBuf::new();
        while let Some(frame) = body.frame().await {
            collected.push(frame.unwrap().into_data().unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(INPUT_TEXT, output_text);
    }
}