aws_smithy_checksums/body/
calculate.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Functionality for calculating the checksum of an HTTP body and emitting it as trailers.
7
8use super::ChecksumCache;
9use crate::http::HttpChecksum;
10
11use aws_smithy_http::header::append_merge_header_maps_http_1x;
12use aws_smithy_types::body::SdkBody;
13use pin_project_lite::pin_project;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use tracing::warn;
17
18pin_project! {
19    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
20    pub struct ChecksumBody<InnerBody> {
21            #[pin]
22            body: InnerBody,
23            checksum: Option<Box<dyn HttpChecksum>>,
24            written_trailers: bool,
25            cache: Option<ChecksumCache>
26    }
27}
28
29impl ChecksumBody<SdkBody> {
30    /// Given an `SdkBody` and a `Box<dyn HttpChecksum>`, create a new `ChecksumBody<SdkBody>`.
31    pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
32        Self {
33            body,
34            checksum: Some(checksum),
35            written_trailers: false,
36            cache: None,
37        }
38    }
39
40    /// Configure a cache for this body.
41    ///
42    /// When used across multiple requests (e.g. retries) a cached checksum previously
43    /// calculated will be favored if available.
44    pub fn with_cache(self, cache: ChecksumCache) -> Self {
45        Self {
46            body: self.body,
47            checksum: self.checksum,
48            written_trailers: false,
49            cache: Some(cache),
50        }
51    }
52
53    // It would be nicer if this could take &self, but I couldn't make that
54    // work out with the Pin/Projection types, so its a static method for now
55    fn extract_or_set_cached_headers(
56        maybe_cache: &Option<ChecksumCache>,
57        checksum: Box<dyn HttpChecksum>,
58    ) -> http_1x::HeaderMap {
59        let calculated_headers = checksum.headers();
60        if let Some(cache) = maybe_cache {
61            if let Some(cached_headers) = cache.get() {
62                if cached_headers != calculated_headers {
63                    warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
64                }
65                cached_headers
66            } else {
67                cache.set(calculated_headers.clone());
68                calculated_headers
69            }
70        } else {
71            calculated_headers
72        }
73    }
74}
75
76impl http_body_1x::Body for ChecksumBody<SdkBody> {
77    type Data = bytes::Bytes;
78    type Error = aws_smithy_types::body::Error;
79
80    fn poll_frame(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83    ) -> Poll<Option<Result<http_body_1x::Frame<Self::Data>, Self::Error>>> {
84        let this = self.project();
85        let poll_res = this.body.poll_frame(cx);
86
87        match &poll_res {
88            Poll::Ready(Some(Ok(frame))) => {
89                // Update checksum for data frames
90                if frame.is_data() {
91                    if let Some(checksum) = this.checksum {
92                        checksum.update(frame.data_ref().expect("Data frame has data"));
93                    }
94                } else {
95                    // Add checksum trailer to other trailers if necessary
96                    let checksum_headers = if let Some(checksum) = this.checksum.take() {
97                        ChecksumBody::extract_or_set_cached_headers(this.cache, checksum)
98                    } else {
99                        return Poll::Ready(None);
100                    };
101                    let trailers = frame
102                        .trailers_ref()
103                        .expect("Trailers frame has trailers")
104                        .clone();
105                    *this.written_trailers = true;
106                    return Poll::Ready(Some(Ok(http_body_1x::Frame::trailers(
107                        append_merge_header_maps_http_1x(trailers, checksum_headers),
108                    ))));
109                }
110            }
111            Poll::Ready(None) => {
112                // If the trailers have not already been written (because there were no existing
113                // trailers on the body) we write them here
114                if !*this.written_trailers {
115                    let checksum_headers = if let Some(checksum) = this.checksum.take() {
116                        ChecksumBody::extract_or_set_cached_headers(this.cache, checksum)
117                    } else {
118                        return Poll::Ready(None);
119                    };
120                    let trailers = http_1x::HeaderMap::new();
121                    return Poll::Ready(Some(Ok(http_body_1x::Frame::trailers(
122                        append_merge_header_maps_http_1x(trailers, checksum_headers),
123                    ))));
124                }
125            }
126            _ => {}
127        };
128        poll_res
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_1x::HeaderMap;
    use http_body_util::BodyExt;
    use std::fmt::Write;
    use std::io::Read;

    /// Base64-decodes a checksum header value and renders the raw bytes as an uppercase hex
    /// string prefixed with `0x`.
    fn header_value_as_checksum_string(header_value: &http_1x::HeaderValue) -> String {
        let raw_bytes = base64::decode(header_value.to_str().unwrap()).unwrap();
        let mut hex = String::new();
        for byte in raw_bytes {
            write!(hex, "{byte:02X?}").expect("string will always be writeable");
        }

        format!("0x{}", hex)
    }

    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let checksum = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), checksum);

        // Drain the body, splitting it into payload bytes and trailer headers.
        let mut output_data = SegmentedBuf::new();
        let mut trailers = HeaderMap::new();
        while let Some(frame) = body.frame().await {
            match frame.unwrap().into_data() {
                Ok(data) => {
                    output_data.push(data);
                }
                Err(non_data) => {
                    if let Ok(map) = non_data.into_trailers() {
                        for (name, value) in map {
                            trailers.insert(name.unwrap(), value);
                        }
                    }
                }
            }
        }

        let mut output_text = String::new();
        output_data
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        // The wrapper must pass the payload through unaltered.
        assert_eq!(input_text, output_text);

        let checksum_trailer = header_value_as_checksum_string(
            trailers
                .get(CRC_32_HEADER_NAME)
                .expect("trailers contain crc32 checksum"),
        );

        // Known-good CRC32 of "This is some test text for an SdkBody".
        assert_eq!("0x99B01F72", checksum_trailer);
    }
}
197}