aws_smithy_checksums/body/
calculate.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Functionality for calculating the checksum of an HTTP body and emitting it as trailers.
7
8use super::ChecksumCache;
9use crate::http::HttpChecksum;
10
11use aws_smithy_http::header::append_merge_header_maps;
12use aws_smithy_types::body::SdkBody;
13
14use http::HeaderMap;
15use http_body::SizeHint;
16use pin_project_lite::pin_project;
17
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tracing::warn;
21
22pin_project! {
23    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
24    pub struct ChecksumBody<InnerBody> {
25            #[pin]
26            body: InnerBody,
27            checksum: Option<Box<dyn HttpChecksum>>,
28            cache: Option<ChecksumCache>
29    }
30}
31
32impl ChecksumBody<SdkBody> {
33    /// Given an `SdkBody` and a `Box<dyn HttpChecksum>`, create a new `ChecksumBody<SdkBody>`.
34    pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
35        Self {
36            body,
37            checksum: Some(checksum),
38            cache: None,
39        }
40    }
41
42    /// Configure a cache for this body.
43    ///
44    /// When used across multiple requests (e.g. retries) a cached checksum previously
45    /// calculated will be favored if available.
46    pub fn with_cache(self, cache: ChecksumCache) -> Self {
47        Self {
48            body: self.body,
49            checksum: self.checksum,
50            cache: Some(cache),
51        }
52    }
53}
54
55impl http_body::Body for ChecksumBody<SdkBody> {
56    type Data = bytes::Bytes;
57    type Error = aws_smithy_types::body::Error;
58
59    fn poll_data(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
63        let this = self.project();
64        match this.checksum {
65            Some(checksum) => {
66                let poll_res = this.body.poll_data(cx);
67                if let Poll::Ready(Some(Ok(data))) = &poll_res {
68                    checksum.update(data);
69                }
70
71                poll_res
72            }
73            None => unreachable!("This can only fail if poll_data is called again after poll_trailers, which is invalid"),
74        }
75    }
76
77    fn poll_trailers(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
81        let this = self.project();
82        let poll_res = this.body.poll_trailers(cx);
83
84        if let Poll::Ready(Ok(maybe_inner_trailers)) = poll_res {
85            let checksum_headers = if let Some(checksum) = this.checksum.take() {
86                let calculated_headers = checksum.headers();
87
88                if let Some(cache) = this.cache {
89                    if let Some(cached_headers) = cache.get() {
90                        if cached_headers != calculated_headers {
91                            warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
92                        }
93                        cached_headers
94                    } else {
95                        cache.set(calculated_headers.clone());
96                        calculated_headers
97                    }
98                } else {
99                    calculated_headers
100                }
101            } else {
102                return Poll::Ready(Ok(None));
103            };
104
105            return match maybe_inner_trailers {
106                Some(inner_trailers) => Poll::Ready(Ok(Some(append_merge_header_maps(
107                    inner_trailers,
108                    checksum_headers,
109                )))),
110                None => Poll::Ready(Ok(Some(checksum_headers))),
111            };
112        }
113
114        poll_res
115    }
116
117    fn is_end_stream(&self) -> bool {
118        // If inner body is finished and we've already consumed the checksum then we must be
119        // at the end of the stream.
120        self.body.is_end_stream() && self.checksum.is_none()
121    }
122
123    fn size_hint(&self) -> SizeHint {
124        self.body.size_hint()
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::fmt::Write;
    use std::io::Read;

    /// Decode a base64 header value and render its bytes as an upper-case hex string
    /// prefixed with `0x`.
    fn header_value_as_checksum_string(header_value: &http::HeaderValue) -> String {
        let raw_bytes = base64::decode(header_value.to_str().unwrap()).unwrap();

        let mut hex_digits = String::new();
        for byte in raw_bytes {
            write!(hex_digits, "{byte:02X?}").expect("string will always be writeable");
        }

        format!("0x{}", hex_digits)
    }

    /// Streaming a body through `ChecksumBody` must leave the data untouched and emit the
    /// CRC32 checksum of that data as a trailer.
    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let checksum = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), checksum);

        // Drain every data chunk before asking for the trailers.
        let mut output = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            output.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        output
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        // The wrapper must pass the data through complete and unmodified
        assert_eq!(input_text, output_text);

        let trailers = body
            .trailers()
            .await
            .expect("checksum generation was without error")
            .expect("trailers were set");
        let checksum_trailer = header_value_as_checksum_string(
            trailers
                .get(CRC_32_HEADER_NAME)
                .expect("trailers contain crc32 checksum"),
        );

        // Known correct checksum for the input "This is some test text for an SdkBody"
        assert_eq!("0x99B01F72", checksum_trailer);
    }
}
188}