aws_smithy_checksums/body/
validate.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Functionality for validating an HTTP body against a given precalculated checksum and emitting an
7//! error if it doesn't match.
8
9use crate::http::HttpChecksum;
10
11use aws_smithy_types::body::SdkBody;
12
13use bytes::Bytes;
14use pin_project_lite::pin_project;
15
16use std::fmt::Display;
17use std::pin::Pin;
18use std::task::{Context, Poll};
19
20pin_project! {
21    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit an error if it
22    /// doesn't match the precalculated checksum.
23    pub struct ChecksumBody<InnerBody> {
24        #[pin]
25        inner: InnerBody,
26        checksum: Option<Box<dyn HttpChecksum>>,
27        precalculated_checksum: Bytes,
28    }
29}
30
31impl ChecksumBody<SdkBody> {
32    /// Given an `SdkBody`, a `Box<dyn HttpChecksum>`, and a precalculated checksum represented
33    /// as `Bytes`, create a new `ChecksumBody<SdkBody>`.
34    pub fn new(
35        body: SdkBody,
36        checksum: Box<dyn HttpChecksum>,
37        precalculated_checksum: Bytes,
38    ) -> Self {
39        Self {
40            inner: body,
41            checksum: Some(checksum),
42            precalculated_checksum,
43        }
44    }
45
46    fn poll_inner(
47        self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49    ) -> Poll<Option<Result<http_body_1x::Frame<Bytes>, aws_smithy_types::body::Error>>> {
50        use http_body_1x::Body;
51
52        let this = self.project();
53        let checksum = this.checksum;
54
55        match this.inner.poll_frame(cx) {
56            Poll::Ready(Some(Ok(frame))) => {
57                let data = frame.data_ref().expect("Data frame should have data");
58                tracing::trace!(
59                    "reading {} bytes from the body and updating the checksum calculation",
60                    data.len()
61                );
62                let checksum = match checksum.as_mut() {
63                    Some(checksum) => checksum,
64                    None => {
65                        unreachable!("The checksum must exist because it's only taken out once the inner body has been completely polled.");
66                    }
67                };
68
69                checksum.update(data);
70                Poll::Ready(Some(Ok(frame)))
71            }
72            // Once the inner body has stopped returning data, check the checksum
73            // and return an error if it doesn't match.
74            Poll::Ready(None) => {
75                tracing::trace!("finished reading from body, calculating final checksum");
76                let checksum = match checksum.take() {
77                    Some(checksum) => checksum,
78                    None => {
79                        // If the checksum was already taken and this was polled again anyways,
80                        // then return nothing
81                        return Poll::Ready(None);
82                    }
83                };
84
85                let actual_checksum = checksum.finalize();
86                if *this.precalculated_checksum == actual_checksum {
87                    Poll::Ready(None)
88                } else {
89                    // So many parens it's starting to look like LISP
90                    Poll::Ready(Some(Err(Box::new(Error::ChecksumMismatch {
91                        expected: this.precalculated_checksum.clone(),
92                        actual: actual_checksum,
93                    }))))
94                }
95            }
96            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
97            Poll::Pending => Poll::Pending,
98        }
99    }
100}
101
102/// Errors related to checksum calculation and validation
103#[derive(Debug, Eq, PartialEq)]
104#[non_exhaustive]
105pub enum Error {
106    /// The actual checksum didn't match the expected checksum. The checksummed data has been
107    /// altered since the expected checksum was calculated.
108    ChecksumMismatch { expected: Bytes, actual: Bytes },
109}
110
111impl Display for Error {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
113        match self {
114            Error::ChecksumMismatch { expected, actual } => write!(
115                f,
116                "body checksum mismatch. expected body checksum to be {} but it was {}",
117                hex::encode(expected),
118                hex::encode(actual)
119            ),
120        }
121    }
122}
123
124impl std::error::Error for Error {}
125
126impl http_body_1x::Body for ChecksumBody<SdkBody> {
127    type Data = Bytes;
128    type Error = aws_smithy_types::body::Error;
129
130    fn poll_frame(
131        self: Pin<&mut Self>,
132        cx: &mut Context<'_>,
133    ) -> Poll<Option<Result<http_body_1x::Frame<Self::Data>, Self::Error>>> {
134        self.poll_inner(cx)
135    }
136}
137
#[cfg(test)]
mod tests {
    use crate::body::validate::{ChecksumBody, Error};
    use crate::ChecksumAlgorithm;
    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_body_util::BodyExt;
    use std::io::Read;

    /// Payload shared by every test in this module.
    const INPUT_TEXT: &str = "This is some test text for an SdkBody";

    /// CRC32 (ISO-HDLC) of `input`, returned as its four big-endian bytes.
    fn calculate_crc32_checksum(input: &str) -> Bytes {
        let crc =
            crc_fast::checksum(crc_fast::CrcAlgorithm::Crc32IsoHdlc, input.as_bytes()) as u32;
        Bytes::copy_from_slice(&crc.to_be_bytes())
    }

    #[tokio::test]
    async fn test_checksum_validated_body_errors_on_mismatch() {
        let computed_crc = calculate_crc32_checksum(INPUT_TEXT);
        let bogus_crc = Bytes::copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        let mut validated = ChecksumBody::new(
            SdkBody::from(INPUT_TEXT),
            "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl(),
            bogus_crc.clone(),
        );

        // Drain data frames until the validator reports the mismatch.
        let failure = loop {
            match validated.frame().await {
                Some(Ok(_)) => {}
                Some(Err(err)) => break err,
                None => panic!("didn't hit expected error condition"),
            }
        };

        match failure.downcast_ref::<Error>().unwrap() {
            Error::ChecksumMismatch { expected, actual } => {
                assert_eq!(expected, &bogus_crc);
                assert_eq!(actual, &computed_crc);
            }
        }
    }

    #[tokio::test]
    async fn test_checksum_validated_body_succeeds_on_match() {
        let correct_crc = calculate_crc32_checksum(INPUT_TEXT);
        let crc32 = "crc32".parse::<ChecksumAlgorithm>().unwrap().into_impl();
        let mut validated = ChecksumBody::new(SdkBody::from(INPUT_TEXT), crc32, correct_crc);

        let mut collected = SegmentedBuf::new();
        while let Some(frame) = validated.frame().await {
            collected.push(frame.unwrap().into_data().unwrap());
        }

        // Read everything back out to confirm the data came through complete and unaltered.
        let mut round_tripped = String::new();
        collected
            .reader()
            .read_to_string(&mut round_tripped)
            .expect("Doesn't cause IO errors");
        assert_eq!(INPUT_TEXT, round_tripped);
    }
}