// aws_runtime/content_encoding.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use bytes::{Bytes, BytesMut};
8use http_02x::{HeaderMap, HeaderValue};
9use http_body_04x::{Body, SizeHint};
10use pin_project_lite::pin_project;
11
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
/// Line ending used by every line of the aws-chunked framing.
const CRLF: &str = "\r\n";
/// The zero-size "last chunk" line that ends the sequence of data chunks.
const CHUNK_TERMINATOR: &str = "0\r\n";
/// Separates a trailer name from its value; no whitespace is emitted around it.
const TRAILER_SEPARATOR: &[u8] = b":";
18
/// Content encoding header value constants
pub mod header_value {
    /// The `Content-Encoding` value that marks a body as using `aws-chunked` encoding.
    pub const AWS_CHUNKED: &str = "aws-chunked";
}
24
/// Options used when constructing an [`AwsChunkedBody`].
///
/// The values are used both to pre-compute the encoded body length and to validate the inner
/// body while it is being encoded.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct AwsChunkedBodyOptions {
    /// Size in bytes of the un-encoded payload. Only unsigned encoding is supported, so the
    /// whole payload is written as one chunk of this size.
    stream_length: u64,
    /// Rendered length of each trailer line (`name:value`, not counting the CRLF that ends the
    /// line). Required to compute the total encoded body size accurately.
    trailer_lengths: Vec<u64>,
    /// Whether aws-chunked encoding is turned off. This can happen, for instance, when a user
    /// supplies a custom checksum, which makes aws-chunked encoding unnecessary.
    disabled: bool,
}
40
41impl Storable for AwsChunkedBodyOptions {
42    type Storer = StoreReplace<Self>;
43}
44
45impl AwsChunkedBodyOptions {
46    /// Create a new [`AwsChunkedBodyOptions`].
47    pub fn new(stream_length: u64, trailer_lengths: Vec<u64>) -> Self {
48        Self {
49            stream_length,
50            trailer_lengths,
51            disabled: false,
52        }
53    }
54
55    fn total_trailer_length(&self) -> u64 {
56        self.trailer_lengths.iter().sum::<u64>()
57            // We need to account for a CRLF after each trailer name/value pair
58            + (self.trailer_lengths.len() * CRLF.len()) as u64
59    }
60
61    /// Set the stream length in the options
62    pub fn with_stream_length(mut self, stream_length: u64) -> Self {
63        self.stream_length = stream_length;
64        self
65    }
66
67    /// Append a trailer length to the options
68    pub fn with_trailer_len(mut self, trailer_len: u64) -> Self {
69        self.trailer_lengths.push(trailer_len);
70        self
71    }
72
73    /// Create a new [`AwsChunkedBodyOptions`] with aws-chunked encoding disabled.
74    ///
75    /// When the option is disabled, the body must not be wrapped in an `AwsChunkedBody`.
76    pub fn disable_chunked_encoding() -> Self {
77        Self {
78            disabled: true,
79            ..Default::default()
80        }
81    }
82
83    /// Return whether aws-chunked encoding is disabled.
84    pub fn disabled(&self) -> bool {
85        self.disabled
86    }
87
88    /// Return the length of the body after `aws-chunked` encoding is applied
89    pub fn encoded_length(&self) -> u64 {
90        let mut length = 0;
91        if self.stream_length != 0 {
92            length += get_unsigned_chunk_bytes_length(self.stream_length);
93        }
94
95        // End chunk
96        length += CHUNK_TERMINATOR.len() as u64;
97
98        // Trailers
99        for len in self.trailer_lengths.iter() {
100            length += len + CRLF.len() as u64;
101        }
102
103        // Encoding terminator
104        length += CRLF.len() as u64;
105
106        length
107    }
108}
109
/// The stages an `AwsChunkedBody` moves through while encoding. Each call to `poll_data`
/// looks at the current state, emits at most one frame, and may advance the state.
#[derive(Debug, PartialEq, Eq)]
enum AwsChunkedBodyState {
    /// Write out the size (in hexadecimal) of the chunk that will follow, then transition into
    /// the `WritingChunk` state. If the configured stream length is zero, write the chunk
    /// terminator instead and transition directly into the `WritingTrailers` state.
    WritingChunkSize,
    /// Pass through the inner body's data. Multiple polls of the inner body may be needed
    /// before all data is written out. Once the inner body is exhausted, write the CRLF that
    /// ends the chunk plus the chunk terminator, then transition into `WritingTrailers`.
    WritingChunk,
    /// Write out all trailers associated with this `AwsChunkedBody`, followed by the CRLF that
    /// terminates the whole body, and then transition into the `Closed` state.
    WritingTrailers,
    /// The final state. Nothing more is written (the body terminator was already emitted in
    /// `WritingTrailers`); polling yields no data and `is_end_stream` reports `true`.
    Closed,
}
125
126pin_project! {
127    /// A request body compatible with `Content-Encoding: aws-chunked`. This implementation is only
128    /// capable of writing a single chunk and does not support signed chunks.
129    ///
130    /// Chunked-Body grammar is defined in [ABNF] as:
131    ///
132    /// ```txt
133    /// Chunked-Body    = *chunk
134    ///                   last-chunk
135    ///                   chunked-trailer
136    ///                   CRLF
137    ///
138    /// chunk           = chunk-size CRLF chunk-data CRLF
139    /// chunk-size      = 1*HEXDIG
140    /// last-chunk      = 1*("0") CRLF
141    /// chunked-trailer = *( entity-header CRLF )
142    /// entity-header   = field-name ":" OWS field-value OWS
143    /// ```
144    /// For more info on what the abbreviations mean, see https://datatracker.ietf.org/doc/html/rfc7230#section-1.2
145    ///
146    /// [ABNF]:https://en.wikipedia.org/wiki/Augmented_Backus%E2%80%93Naur_form
147    #[derive(Debug)]
148    pub struct AwsChunkedBody<InnerBody> {
149        #[pin]
150        inner: InnerBody,
151        #[pin]
152        state: AwsChunkedBodyState,
153        options: AwsChunkedBodyOptions,
154        inner_body_bytes_read_so_far: usize,
155    }
156}
157
158impl<Inner> AwsChunkedBody<Inner> {
159    /// Wrap the given body in an outer body compatible with `Content-Encoding: aws-chunked`
160    pub fn new(body: Inner, options: AwsChunkedBodyOptions) -> Self {
161        Self {
162            inner: body,
163            state: AwsChunkedBodyState::WritingChunkSize,
164            options,
165            inner_body_bytes_read_so_far: 0,
166        }
167    }
168}
169
170fn get_unsigned_chunk_bytes_length(payload_length: u64) -> u64 {
171    let hex_repr_len = int_log16(payload_length);
172    hex_repr_len + CRLF.len() as u64 + payload_length + CRLF.len() as u64
173}
174
175/// Writes trailers out into a `string` and then converts that `String` to a `Bytes` before
176/// returning.
177///
178/// - Trailer names are separated by a single colon only, no space.
179/// - Trailer names with multiple values will be written out one line per value, with the name
180///   appearing on each line.
181fn trailers_as_aws_chunked_bytes(
182    trailer_map: Option<HeaderMap>,
183    estimated_length: u64,
184) -> BytesMut {
185    if let Some(trailer_map) = trailer_map {
186        let mut current_header_name = None;
187        let mut trailers = BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
188
189        for (header_name, header_value) in trailer_map.into_iter() {
190            // When a header has multiple values, the name only comes up in iteration the first time
191            // we see it. Therefore, we need to keep track of the last name we saw and fall back to
192            // it when `header_name == None`.
193            current_header_name = header_name.or(current_header_name);
194
195            // In practice, this will always exist, but `if let` is nicer than unwrap
196            if let Some(header_name) = current_header_name.as_ref() {
197                trailers.extend_from_slice(header_name.as_ref());
198                trailers.extend_from_slice(TRAILER_SEPARATOR);
199                trailers.extend_from_slice(header_value.as_bytes());
200                trailers.extend_from_slice(CRLF.as_bytes());
201            }
202        }
203
204        trailers
205    } else {
206        BytesMut::new()
207    }
208}
209
210/// Given an optional `HeaderMap`, calculate the total number of bytes required to represent the
211/// `HeaderMap`. If no `HeaderMap` is given as input, return 0.
212///
213/// - Trailer names are separated by a single colon only, no space.
214/// - Trailer names with multiple values will be written out one line per value, with the name
215///   appearing on each line.
216fn total_rendered_length_of_trailers(trailer_map: Option<&HeaderMap>) -> u64 {
217    match trailer_map {
218        Some(trailer_map) => trailer_map
219            .iter()
220            .map(|(trailer_name, trailer_value)| {
221                trailer_name.as_str().len()
222                    + TRAILER_SEPARATOR.len()
223                    + trailer_value.len()
224                    + CRLF.len()
225            })
226            .sum::<usize>() as u64,
227        None => 0,
228    }
229}
230
231impl<Inner> Body for AwsChunkedBody<Inner>
232where
233    Inner: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
234{
235    type Data = Bytes;
236    type Error = aws_smithy_types::body::Error;
237
238    fn poll_data(
239        self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
242        tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
243        let mut this = self.project();
244
245        match *this.state {
246            AwsChunkedBodyState::WritingChunkSize => {
247                if this.options.stream_length == 0 {
248                    // If the stream is empty, we skip to writing trailers after writing the CHUNK_TERMINATOR.
249                    *this.state = AwsChunkedBodyState::WritingTrailers;
250                    tracing::trace!("stream is empty, writing chunk terminator");
251                    Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
252                } else {
253                    *this.state = AwsChunkedBodyState::WritingChunk;
254                    // A chunk must be prefixed by chunk size in hexadecimal
255                    let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
256                    tracing::trace!(%chunk_size, "writing chunk size");
257                    let chunk_size = Bytes::from(chunk_size);
258                    Poll::Ready(Some(Ok(chunk_size)))
259                }
260            }
261            AwsChunkedBodyState::WritingChunk => match this.inner.poll_data(cx) {
262                Poll::Ready(Some(Ok(data))) => {
263                    tracing::trace!(len = data.len(), "writing chunk data");
264                    *this.inner_body_bytes_read_so_far += data.len();
265                    Poll::Ready(Some(Ok(data)))
266                }
267                Poll::Ready(None) => {
268                    let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
269                    let expected_stream_length = this.options.stream_length;
270                    if actual_stream_length != expected_stream_length {
271                        let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
272                            actual: actual_stream_length,
273                            expected: expected_stream_length,
274                        });
275                        return Poll::Ready(Some(Err(err)));
276                    };
277
278                    tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
279                    *this.state = AwsChunkedBodyState::WritingTrailers;
280                    // Since we wrote chunk data, we end it with a CRLF and since we only write
281                    // a single chunk, we write the CHUNK_TERMINATOR immediately after
282                    Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
283                }
284                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
285                Poll::Pending => Poll::Pending,
286            },
287            AwsChunkedBodyState::WritingTrailers => {
288                return match this.inner.poll_trailers(cx) {
289                    Poll::Ready(Ok(trailers)) => {
290                        *this.state = AwsChunkedBodyState::Closed;
291                        let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
292                        let actual_length = this.options.total_trailer_length();
293
294                        if expected_length != actual_length {
295                            let err =
296                                Box::new(AwsChunkedBodyError::ReportedTrailerLengthMismatch {
297                                    actual: actual_length,
298                                    expected: expected_length,
299                                });
300                            return Poll::Ready(Some(Err(err)));
301                        }
302
303                        let mut trailers =
304                            trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
305                        // Insert the final CRLF to close the body
306                        trailers.extend_from_slice(CRLF.as_bytes());
307
308                        Poll::Ready(Some(Ok(trailers.into())))
309                    }
310                    Poll::Pending => Poll::Pending,
311                    Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
312                };
313            }
314            AwsChunkedBodyState::Closed => Poll::Ready(None),
315        }
316    }
317
318    fn poll_trailers(
319        self: Pin<&mut Self>,
320        _cx: &mut Context<'_>,
321    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
322        // Trailers were already appended to the body because of the content encoding scheme
323        Poll::Ready(Ok(None))
324    }
325
326    fn is_end_stream(&self) -> bool {
327        self.state == AwsChunkedBodyState::Closed
328    }
329
330    fn size_hint(&self) -> SizeHint {
331        SizeHint::with_exact(self.options.encoded_length())
332    }
333}
334
/// Errors related to `AwsChunkedBody`
#[derive(Debug)]
enum AwsChunkedBodyError {
    /// Error that occurs when the sum of `trailer_lengths` set when creating an `AwsChunkedBody` is
    /// not equal to the actual length of the trailers returned by the inner `http_body::Body`
    /// implementor. These trailer lengths are necessary in order to correctly calculate the total
    /// size of the body for setting the content length header.
    ///
    /// `actual` holds the trailer length reported through the options (including one CRLF per
    /// trailer). `expected` holds the rendered length of the trailers the inner body produced.
    ReportedTrailerLengthMismatch { actual: u64, expected: u64 },
    /// Error that occurs when the `stream_length` set when creating an `AwsChunkedBody` is not
    /// equal to the actual length of the body returned by the inner `http_body::Body` implementor.
    /// `stream_length` must be correct in order to set an accurate content length header.
    ///
    /// `actual` is the number of bytes read from the inner body; `expected` is `stream_length`.
    StreamLengthMismatch { actual: u64, expected: u64 },
}

impl std::fmt::Display for AwsChunkedBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // For this variant the *reported* (options) value lives in `actual`, so it must be
            // printed first.
            Self::ReportedTrailerLengthMismatch { actual, expected } => {
                write!(f, "When creating this AwsChunkedBody, length of trailers was reported as {actual}. However, when double checking during trailer encoding, length was found to be {expected} instead.")
            }
            Self::StreamLengthMismatch { actual, expected } => {
                write!(f, "When creating this AwsChunkedBody, stream length was reported as {expected}. However, when double checking during body encoding, length was found to be {actual} instead.")
            }
        }
    }
}

impl std::error::Error for AwsChunkedBodyError {}
363
/// Counts the hexadecimal digits needed to write `i` (e.g. `0xABC` needs 3).
///
/// Note: `0` (and any value not greater than zero) yields `0`, not `1`.
fn int_log16<T>(mut i: T) -> u64
where
    T: std::ops::DivAssign + PartialOrd + From<u8> + Copy,
{
    let (zero, sixteen) = (T::from(0u8), T::from(16u8));
    let mut digits = 0u64;

    while i > zero {
        digits += 1;
        i /= sixteen;
    }

    digits
}
380
#[cfg(test)]
mod tests {
    use super::{
        total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes, AwsChunkedBody,
        AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF,
    };

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    pin_project! {
        // Test body that yields `parts` in order. A `None` entry makes the body return
        // `Poll::Pending` once and schedule a wake-up after `delay_in_millis`.
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(|b| b.len()).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            if self.cursor == self.parts.len() {
                return Poll::Ready(None);
            }

            let this = self.project();
            let delay_in_millis = *this.delay_in_millis;
            let next_part = this.parts.get_mut(*this.cursor).unwrap().take();

            match next_part {
                None => {
                    *this.cursor += 1;
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_in_millis)).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
                Some(data) => {
                    *this.cursor += 1;
                    Poll::Ready(Some(Ok(data)))
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let test_fut = async {
            let input_str = "Hello world";
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );

            // You can insert a `tokio::time::sleep` here to verify the timeout works as intended
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!("test_aws_chunked_encoding timed out after {timeout_duration:?}");
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let test_fut = async {
            let input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let opts = AwsChunkedBodyOptions::new(input.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(input, opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!(
                "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
            );
        }
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let input_str = "Hello world";
        // The test body has no trailers, so this reported length is wrong. Reading the body
        // returns a `ReportedTrailerLengthMismatch` error, which `unwrap()` turns into a panic.
        // The error's `actual` is 44 rather than 42 because the options add 2 bytes for the CRLF
        // that ends each trailer line. Its `expected` is 0 because no trailers are rendered.
        let wrong_trailer_len = 42;
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, vec![wrong_trailer_len]);
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        // We don't care about the body contents but we have to read it all before checking for trailers
        while let Some(buf) = body.data().await {
            drop(buf.unwrap());
        }

        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let input_str = "";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut actual_output = String::new();
        output
            .reader()
            .read_to_string(&mut actual_output)
            .expect("Doesn't cause IO errors");

        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    // The exact size hint (used for the content-length header) must equal the number of bytes
    // actually produced by the encoder.
    #[tokio::test]
    async fn test_size_hint_matches_encoded_output() {
        for input_str in ["", "Hello world"] {
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);
            let hinted_len = body.size_hint().exact();

            let mut produced_len = 0u64;
            while let Some(buf) = body.data().await {
                produced_len += buf.unwrap().len() as u64;
            }

            assert_eq!(
                hinted_len,
                Some(produced_len),
                "size hint mismatch for input {input_str:?}"
            );
        }
    }

    #[test]
    fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        let trailers = Some(trailers);
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }

    #[test]
    fn test_total_rendered_length_of_empty_trailers() {
        let trailers = Some(HeaderMap::new());
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }
}