aws_smithy_checksums/body/
calculate.rs1use super::ChecksumCache;
9use crate::http::HttpChecksum;
10
11use aws_smithy_http::header::append_merge_header_maps;
12use aws_smithy_types::body::SdkBody;
13
14use http::HeaderMap;
15use http_body::SizeHint;
16use pin_project_lite::pin_project;
17
18use std::pin::Pin;
19use std::task::{Context, Poll};
20use tracing::warn;
21
pin_project! {
    /// A body wrapper that updates a checksum with every data chunk polled from the wrapped
    /// body, and emits the resulting checksum headers as trailers (see the `http_body::Body`
    /// impl for `ChecksumBody<SdkBody>`).
    pub struct ChecksumBody<InnerBody> {
        // The wrapped body; pinned because it is polled through the pin projection.
        #[pin]
        body: InnerBody,
        // Running checksum. `Some` until trailers are produced, at which point it is taken,
        // so `None` means the checksum headers have already been emitted.
        checksum: Option<Box<dyn HttpChecksum>>,
        // Optional cache of checksum headers. When it already holds headers, those are
        // emitted instead of the freshly calculated ones; otherwise it is filled with them.
        cache: Option<ChecksumCache>
    }
}
31
32impl ChecksumBody<SdkBody> {
33 pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
35 Self {
36 body,
37 checksum: Some(checksum),
38 cache: None,
39 }
40 }
41
42 pub fn with_cache(self, cache: ChecksumCache) -> Self {
47 Self {
48 body: self.body,
49 checksum: self.checksum,
50 cache: Some(cache),
51 }
52 }
53}
54
55impl http_body::Body for ChecksumBody<SdkBody> {
56 type Data = bytes::Bytes;
57 type Error = aws_smithy_types::body::Error;
58
59 fn poll_data(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
63 let this = self.project();
64 match this.checksum {
65 Some(checksum) => {
66 let poll_res = this.body.poll_data(cx);
67 if let Poll::Ready(Some(Ok(data))) = &poll_res {
68 checksum.update(data);
69 }
70
71 poll_res
72 }
73 None => unreachable!("This can only fail if poll_data is called again after poll_trailers, which is invalid"),
74 }
75 }
76
77 fn poll_trailers(
78 self: Pin<&mut Self>,
79 cx: &mut Context<'_>,
80 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
81 let this = self.project();
82 let poll_res = this.body.poll_trailers(cx);
83
84 if let Poll::Ready(Ok(maybe_inner_trailers)) = poll_res {
85 let checksum_headers = if let Some(checksum) = this.checksum.take() {
86 let calculated_headers = checksum.headers();
87
88 if let Some(cache) = this.cache {
89 if let Some(cached_headers) = cache.get() {
90 if cached_headers != calculated_headers {
91 warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
92 }
93 cached_headers
94 } else {
95 cache.set(calculated_headers.clone());
96 calculated_headers
97 }
98 } else {
99 calculated_headers
100 }
101 } else {
102 return Poll::Ready(Ok(None));
103 };
104
105 return match maybe_inner_trailers {
106 Some(inner_trailers) => Poll::Ready(Ok(Some(append_merge_header_maps(
107 inner_trailers,
108 checksum_headers,
109 )))),
110 None => Poll::Ready(Ok(Some(checksum_headers))),
111 };
112 }
113
114 poll_res
115 }
116
117 fn is_end_stream(&self) -> bool {
118 self.body.is_end_stream() && self.checksum.is_none()
121 }
122
123 fn size_hint(&self) -> SizeHint {
124 self.body.size_hint()
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_body::Body;
    use std::fmt::Write;
    use std::io::Read;

    /// Base64-decodes a checksum header value and renders the raw bytes as uppercase hex,
    /// prefixed with `0x`.
    fn header_value_as_checksum_string(header_value: &http::HeaderValue) -> String {
        let raw_bytes = base64::decode(header_value.to_str().unwrap()).unwrap();
        let mut rendered = String::from("0x");
        for byte in raw_bytes {
            write!(rendered, "{byte:02X?}").expect("string will always be writeable");
        }
        rendered
    }

    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let crc32 = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), crc32);

        // Draining the data chunks is what drives the checksum calculation.
        let mut collected = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            collected.push(chunk.unwrap());
        }

        let mut output_text = String::new();
        collected
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);

        // With the data exhausted, the trailers carry the CRC32 of everything read above.
        let trailers = body
            .trailers()
            .await
            .expect("checksum generation was without error")
            .expect("trailers were set");
        let checksum_trailer = header_value_as_checksum_string(
            trailers
                .get(CRC_32_HEADER_NAME)
                .expect("trailers contain crc32 checksum"),
        );

        assert_eq!("0x99B01F72", checksum_trailer);
    }
}