aws_smithy_checksums/body/
calculate.rs1use super::ChecksumCache;
9use crate::http::HttpChecksum;
10
11use aws_smithy_http::header::append_merge_header_maps_http_1x;
12use aws_smithy_types::body::SdkBody;
13use pin_project_lite::pin_project;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use tracing::warn;
17
pin_project! {
    /// A body wrapper that feeds every data frame of the inner body into a checksum and
    /// emits the resulting checksum headers as a trailers frame.
    pub struct ChecksumBody<InnerBody> {
        // The wrapped body; pinned because it is polled through a pinned projection.
        #[pin]
        body: InnerBody,
        // The running checksum. It is `take`n (leaving `None`) when the checksum headers
        // are produced, so the headers are emitted at most once.
        checksum: Option<Box<dyn HttpChecksum>>,
        // Set to `true` once the checksum headers have been merged into a trailers frame
        // yielded by the inner body.
        written_trailers: bool,
        // Optional cache of previously calculated checksum headers; when it holds a value,
        // that value is emitted instead of the freshly calculated headers.
        cache: Option<ChecksumCache>
    }
}
28
29impl ChecksumBody<SdkBody> {
30 pub fn new(body: SdkBody, checksum: Box<dyn HttpChecksum>) -> Self {
32 Self {
33 body,
34 checksum: Some(checksum),
35 written_trailers: false,
36 cache: None,
37 }
38 }
39
40 pub fn with_cache(self, cache: ChecksumCache) -> Self {
45 Self {
46 body: self.body,
47 checksum: self.checksum,
48 written_trailers: false,
49 cache: Some(cache),
50 }
51 }
52
53 fn extract_or_set_cached_headers(
56 maybe_cache: &Option<ChecksumCache>,
57 checksum: Box<dyn HttpChecksum>,
58 ) -> http_1x::HeaderMap {
59 let calculated_headers = checksum.headers();
60 if let Some(cache) = maybe_cache {
61 if let Some(cached_headers) = cache.get() {
62 if cached_headers != calculated_headers {
63 warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
64 }
65 cached_headers
66 } else {
67 cache.set(calculated_headers.clone());
68 calculated_headers
69 }
70 } else {
71 calculated_headers
72 }
73 }
74}
75
76impl http_body_1x::Body for ChecksumBody<SdkBody> {
77 type Data = bytes::Bytes;
78 type Error = aws_smithy_types::body::Error;
79
80 fn poll_frame(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 ) -> Poll<Option<Result<http_body_1x::Frame<Self::Data>, Self::Error>>> {
84 let this = self.project();
85 let poll_res = this.body.poll_frame(cx);
86
87 match &poll_res {
88 Poll::Ready(Some(Ok(frame))) => {
89 if frame.is_data() {
91 if let Some(checksum) = this.checksum {
92 checksum.update(frame.data_ref().expect("Data frame has data"));
93 }
94 } else {
95 let checksum_headers = if let Some(checksum) = this.checksum.take() {
97 ChecksumBody::extract_or_set_cached_headers(this.cache, checksum)
98 } else {
99 return Poll::Ready(None);
100 };
101 let trailers = frame
102 .trailers_ref()
103 .expect("Trailers frame has trailers")
104 .clone();
105 *this.written_trailers = true;
106 return Poll::Ready(Some(Ok(http_body_1x::Frame::trailers(
107 append_merge_header_maps_http_1x(trailers, checksum_headers),
108 ))));
109 }
110 }
111 Poll::Ready(None) => {
112 if !*this.written_trailers {
115 let checksum_headers = if let Some(checksum) = this.checksum.take() {
116 ChecksumBody::extract_or_set_cached_headers(this.cache, checksum)
117 } else {
118 return Poll::Ready(None);
119 };
120 let trailers = http_1x::HeaderMap::new();
121 return Poll::Ready(Some(Ok(http_body_1x::Frame::trailers(
122 append_merge_header_maps_http_1x(trailers, checksum_headers),
123 ))));
124 }
125 }
126 _ => {}
127 };
128 poll_res
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::ChecksumBody;
    use crate::{http::CRC_32_HEADER_NAME, ChecksumAlgorithm, CRC_32_NAME};
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use bytes::Buf;
    use bytes_utils::SegmentedBuf;
    use http_1x::HeaderMap;
    use http_body_util::BodyExt;
    use std::fmt::Write;
    use std::io::Read;

    /// Base64-decodes `header_value` and renders the bytes as an uppercase hex string
    /// prefixed with `0x`.
    fn header_value_as_checksum_string(header_value: &http_1x::HeaderValue) -> String {
        let raw_bytes = base64::decode(header_value.to_str().unwrap()).unwrap();

        let mut rendered = String::from("0x");
        for byte in raw_bytes {
            write!(rendered, "{byte:02X?}").expect("string will always be writeable");
        }
        rendered
    }

    /// Drains a `ChecksumBody` and checks that the data passes through unchanged and that
    /// the CRC32 checksum arrives as a trailer.
    #[tokio::test]
    async fn test_checksum_body() {
        let input_text = "This is some test text for an SdkBody";
        let crc32 = CRC_32_NAME
            .parse::<ChecksumAlgorithm>()
            .unwrap()
            .into_impl();
        let mut body = ChecksumBody::new(SdkBody::from(input_text), crc32);

        // Collect data frames and trailers separately.
        let mut output_data = SegmentedBuf::new();
        let mut trailers = HeaderMap::new();
        while let Some(next) = body.frame().await {
            match next.unwrap().into_data() {
                Ok(chunk) => output_data.push(chunk),
                Err(other) => {
                    if let Ok(map) = other.into_trailers() {
                        for (name, value) in map {
                            trailers.insert(name.unwrap(), value);
                        }
                    }
                }
            }
        }

        // The body content must be forwarded byte-for-byte.
        let mut output_text = String::new();
        output_data
            .reader()
            .read_to_string(&mut output_text)
            .expect("Doesn't cause IO errors");
        assert_eq!(input_text, output_text);

        // The checksum must be present in the trailers with the expected value.
        let trailer_value = trailers
            .get(CRC_32_HEADER_NAME)
            .expect("trailers contain crc32 checksum");
        let checksum_trailer = header_value_as_checksum_string(trailer_value);

        assert_eq!("0x99B01F72", checksum_trailer);
    }
}