1use aws_smithy_types::config_bag::{Storable, StoreReplace};
7use bytes::{Bytes, BytesMut};
8use http_02x::{HeaderMap, HeaderValue};
9use http_body_04x::{Body, SizeHint};
10use pin_project_lite::pin_project;
11
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
/// Line terminator written after the chunk-size line, the chunk data, every trailer and the
/// end of the encoding.
const CRLF: &str = "\r\n";
/// Zero-length chunk header that marks the end of the chunk data.
const CHUNK_TERMINATOR: &str = "0\r\n";
/// Separator written between a trailer's name and its value.
const TRAILER_SEPARATOR: &[u8] = b":";
18
/// Header values related to aws-chunked encoding.
pub mod header_value {
    /// The value naming the aws-chunked encoding (presumably used as a `Content-Encoding`
    /// header value by callers; not referenced within this file).
    pub const AWS_CHUNKED: &str = "aws-chunked";
}
24
/// Options describing the body that an [`AwsChunkedBody`] will encode.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct AwsChunkedBodyOptions {
    // Length in bytes of the un-encoded inner body. The whole body is written as a single
    // chunk of this size, and the inner body must yield exactly this many bytes.
    stream_length: u64,
    // Length of each rendered trailer line (`name:value`), *excluding* the CRLF that is
    // written after each one.
    trailer_lengths: Vec<u64>,
    // Returned by `disabled()`. This type does not act on the flag itself; presumably callers
    // check it to decide whether to wrap a body at all (TODO confirm).
    disabled: bool,
}
40
// Allows the options to be stored in an `aws_smithy_types` config bag. `StoreReplace`
// presumably means a newly stored value replaces any previous one (confirm against the
// config_bag docs).
impl Storable for AwsChunkedBodyOptions {
    type Storer = StoreReplace<Self>;
}
44
45impl AwsChunkedBodyOptions {
46 pub fn new(stream_length: u64, trailer_lengths: Vec<u64>) -> Self {
48 Self {
49 stream_length,
50 trailer_lengths,
51 disabled: false,
52 }
53 }
54
55 fn total_trailer_length(&self) -> u64 {
56 self.trailer_lengths.iter().sum::<u64>()
57 + (self.trailer_lengths.len() * CRLF.len()) as u64
59 }
60
61 pub fn with_stream_length(mut self, stream_length: u64) -> Self {
63 self.stream_length = stream_length;
64 self
65 }
66
67 pub fn with_trailer_len(mut self, trailer_len: u64) -> Self {
69 self.trailer_lengths.push(trailer_len);
70 self
71 }
72
73 pub fn disable_chunked_encoding() -> Self {
77 Self {
78 disabled: true,
79 ..Default::default()
80 }
81 }
82
83 pub fn disabled(&self) -> bool {
85 self.disabled
86 }
87
88 pub fn encoded_length(&self) -> u64 {
90 let mut length = 0;
91 if self.stream_length != 0 {
92 length += get_unsigned_chunk_bytes_length(self.stream_length);
93 }
94
95 length += CHUNK_TERMINATOR.len() as u64;
97
98 for len in self.trailer_lengths.iter() {
100 length += len + CRLF.len() as u64;
101 }
102
103 length += CRLF.len() as u64;
105
106 length
107 }
108}
109
/// Progress of an [`AwsChunkedBody`] through the encoding it produces.
#[derive(Debug, PartialEq, Eq)]
enum AwsChunkedBodyState {
    /// The next poll writes the hex chunk-size line, or the chunk terminator for an empty body.
    WritingChunkSize,
    /// Data from the inner body is being forwarded. Once the inner body is exhausted, CRLF and
    /// the chunk terminator are written.
    WritingChunk,
    /// The next poll renders the inner body's trailers followed by the final CRLF.
    WritingTrailers,
    /// Encoding is finished; polling for data yields `None`.
    Closed,
}
125
pin_project! {
    /// A body wrapper that encodes its inner body using aws-chunked encoding.
    ///
    /// The inner body is emitted as one chunk whose size comes from
    /// [`AwsChunkedBodyOptions`]. The output contains, in order:
    /// - that chunk;
    /// - the zero-length terminator chunk;
    /// - the inner body's trailers rendered as `name:value\r\n` lines;
    /// - a final CRLF.
    ///
    /// For example, the 11-byte body `Hello world` without trailers encodes to
    /// `B\r\nHello world\r\n0\r\n\r\n`.
    ///
    /// Because the trailers are written into the data stream, this body reports no HTTP
    /// trailers of its own.
    #[derive(Debug)]
    pub struct AwsChunkedBody<InnerBody> {
        #[pin]
        inner: InnerBody,
        // Current position in the encoding; see `AwsChunkedBodyState`.
        #[pin]
        state: AwsChunkedBodyState,
        options: AwsChunkedBodyOptions,
        // Running count of data bytes forwarded from `inner`, checked against
        // `options.stream_length`.
        // NOTE(review): this is a `usize` while `stream_length` is a `u64`, so on 32-bit
        // targets it cannot count past 4 GiB.
        inner_body_bytes_read_so_far: usize,
    }
}
157
158impl<Inner> AwsChunkedBody<Inner> {
159 pub fn new(body: Inner, options: AwsChunkedBodyOptions) -> Self {
161 Self {
162 inner: body,
163 state: AwsChunkedBodyState::WritingChunkSize,
164 options,
165 inner_body_bytes_read_so_far: 0,
166 }
167 }
168}
169
170fn get_unsigned_chunk_bytes_length(payload_length: u64) -> u64 {
171 let hex_repr_len = int_log16(payload_length);
172 hex_repr_len + CRLF.len() as u64 + payload_length + CRLF.len() as u64
173}
174
175fn trailers_as_aws_chunked_bytes(
182 trailer_map: Option<HeaderMap>,
183 estimated_length: u64,
184) -> BytesMut {
185 if let Some(trailer_map) = trailer_map {
186 let mut current_header_name = None;
187 let mut trailers = BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
188
189 for (header_name, header_value) in trailer_map.into_iter() {
190 current_header_name = header_name.or(current_header_name);
194
195 if let Some(header_name) = current_header_name.as_ref() {
197 trailers.extend_from_slice(header_name.as_ref());
198 trailers.extend_from_slice(TRAILER_SEPARATOR);
199 trailers.extend_from_slice(header_value.as_bytes());
200 trailers.extend_from_slice(CRLF.as_bytes());
201 }
202 }
203
204 trailers
205 } else {
206 BytesMut::new()
207 }
208}
209
210fn total_rendered_length_of_trailers(trailer_map: Option<&HeaderMap>) -> u64 {
217 match trailer_map {
218 Some(trailer_map) => trailer_map
219 .iter()
220 .map(|(trailer_name, trailer_value)| {
221 trailer_name.as_str().len()
222 + TRAILER_SEPARATOR.len()
223 + trailer_value.len()
224 + CRLF.len()
225 })
226 .sum::<usize>() as u64,
227 None => 0,
228 }
229}
230
231impl<Inner> Body for AwsChunkedBody<Inner>
232where
233 Inner: Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
234{
235 type Data = Bytes;
236 type Error = aws_smithy_types::body::Error;
237
238 fn poll_data(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
242 tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
243 let mut this = self.project();
244
245 match *this.state {
246 AwsChunkedBodyState::WritingChunkSize => {
247 if this.options.stream_length == 0 {
248 *this.state = AwsChunkedBodyState::WritingTrailers;
250 tracing::trace!("stream is empty, writing chunk terminator");
251 Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
252 } else {
253 *this.state = AwsChunkedBodyState::WritingChunk;
254 let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
256 tracing::trace!(%chunk_size, "writing chunk size");
257 let chunk_size = Bytes::from(chunk_size);
258 Poll::Ready(Some(Ok(chunk_size)))
259 }
260 }
261 AwsChunkedBodyState::WritingChunk => match this.inner.poll_data(cx) {
262 Poll::Ready(Some(Ok(data))) => {
263 tracing::trace!(len = data.len(), "writing chunk data");
264 *this.inner_body_bytes_read_so_far += data.len();
265 Poll::Ready(Some(Ok(data)))
266 }
267 Poll::Ready(None) => {
268 let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
269 let expected_stream_length = this.options.stream_length;
270 if actual_stream_length != expected_stream_length {
271 let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
272 actual: actual_stream_length,
273 expected: expected_stream_length,
274 });
275 return Poll::Ready(Some(Err(err)));
276 };
277
278 tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
279 *this.state = AwsChunkedBodyState::WritingTrailers;
280 Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
283 }
284 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
285 Poll::Pending => Poll::Pending,
286 },
287 AwsChunkedBodyState::WritingTrailers => {
288 return match this.inner.poll_trailers(cx) {
289 Poll::Ready(Ok(trailers)) => {
290 *this.state = AwsChunkedBodyState::Closed;
291 let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
292 let actual_length = this.options.total_trailer_length();
293
294 if expected_length != actual_length {
295 let err =
296 Box::new(AwsChunkedBodyError::ReportedTrailerLengthMismatch {
297 actual: actual_length,
298 expected: expected_length,
299 });
300 return Poll::Ready(Some(Err(err)));
301 }
302
303 let mut trailers =
304 trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
305 trailers.extend_from_slice(CRLF.as_bytes());
307
308 Poll::Ready(Some(Ok(trailers.into())))
309 }
310 Poll::Pending => Poll::Pending,
311 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
312 };
313 }
314 AwsChunkedBodyState::Closed => Poll::Ready(None),
315 }
316 }
317
318 fn poll_trailers(
319 self: Pin<&mut Self>,
320 _cx: &mut Context<'_>,
321 ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
322 Poll::Ready(Ok(None))
324 }
325
326 fn is_end_stream(&self) -> bool {
327 self.state == AwsChunkedBodyState::Closed
328 }
329
330 fn size_hint(&self) -> SizeHint {
331 SizeHint::with_exact(self.options.encoded_length())
332 }
333}
334
/// Errors produced while encoding an [`AwsChunkedBody`].
#[derive(Debug)]
enum AwsChunkedBodyError {
    /// The trailers rendered during encoding do not match the trailer lengths supplied in
    /// [`AwsChunkedBodyOptions`].
    /// NOTE(review): as constructed by the encoder, `actual` holds the length reported via the
    /// options and `expected` holds the rendered length. This is the reverse of what the names
    /// suggest.
    ReportedTrailerLengthMismatch { actual: u64, expected: u64 },
    /// The inner body produced a number of bytes (`actual`) different from the stream length
    /// supplied in [`AwsChunkedBodyOptions`] (`expected`).
    StreamLengthMismatch { actual: u64, expected: u64 },
}
348
349impl std::fmt::Display for AwsChunkedBodyError {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 match self {
352 Self::ReportedTrailerLengthMismatch { actual, expected } => {
353 write!(f, "When creating this AwsChunkedBody, length of trailers was reported as {expected}. However, when double checking during trailer encoding, length was found to be {actual} instead.")
354 }
355 Self::StreamLengthMismatch { actual, expected } => {
356 write!(f, "When creating this AwsChunkedBody, stream length was reported as {expected}. However, when double checking during body encoding, length was found to be {actual} instead.")
357 }
358 }
359 }
360}
361
// Relies on the default `Error` methods (no `source`); the message comes from `Display`.
impl std::error::Error for AwsChunkedBodyError {}
363
/// Returns the number of hexadecimal digits needed to write `i`.
///
/// Despite the name, this is a digit count (`floor(log16(i)) + 1`) for positive `i`. It
/// returns 0 when `i` is not greater than zero, including 0 itself.
fn int_log16<T>(mut i: T) -> u64
where
    T: std::ops::DivAssign + PartialOrd + From<u8> + Copy,
{
    if i > T::from(0u8) {
        // One digit for the lowest nibble, plus however many the remaining value needs.
        i /= T::from(16u8);
        1 + int_log16(i)
    } else {
        0
    }
}
380
#[cfg(test)]
mod tests {
    use super::{
        total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes, AwsChunkedBody,
        AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF,
    };

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    pin_project! {
        /// Test body that yields `parts` one at a time. A `None` part makes the body return
        /// `Poll::Pending` and schedule a wake-up after `delay_in_millis`, simulating data
        /// that arrives in bursts.
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        /// Total number of data bytes across all parts.
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(|b| b.len()).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            if self.cursor == self.parts.len() {
                return Poll::Ready(None);
            }

            let this = self.project();
            let delay_in_millis = *this.delay_in_millis;
            let next_part = this.parts.get_mut(*this.cursor).unwrap().take();

            match next_part {
                None => {
                    *this.cursor += 1;
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_in_millis)).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
                Some(data) => {
                    *this.cursor += 1;
                    Poll::Ready(Some(Ok(data)))
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let test_fut = async {
            let input_str = "Hello world";
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            // The advertised size must match the number of bytes actually produced.
            assert_eq!(
                Some(expected_output.len() as u64),
                body.size_hint().exact()
            );
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!("test_aws_chunked_encoding timed out after {timeout_duration:?}");
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let test_fut = async {
            let input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let opts = AwsChunkedBodyOptions::new(input.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(input, opts);

            let mut output = SegmentedBuf::new();
            while let Some(buf) = body.data().await {
                output.push(buf.unwrap());
            }

            let mut actual_output = String::new();
            output
                .reader()
                .read_to_string(&mut actual_output)
                .expect("Doesn't cause IO errors");

            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            // The advertised size must match the number of bytes actually produced.
            assert_eq!(
                Some(expected_output.len() as u64),
                body.size_hint().exact()
            );
            assert!(
                body.trailers()
                    .await
                    .expect("no errors occurred during trailer polling")
                    .is_none(),
                "aws-chunked encoded bodies don't have normal HTTP trailers"
            );
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!(
                "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
            );
        }
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let input_str = "Hello world";
        // A trailer of length 42 (44 with its CRLF) is reported, but `SdkBody::from` produces
        // no trailers, so trailer validation must fail.
        let wrong_trailer_len = 42;
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, vec![wrong_trailer_len]);
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        while let Some(buf) = body.data().await {
            // Only the error matters here; the encoded bytes are discarded.
            drop(buf.unwrap());
        }

        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let input_str = "";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut actual_output = String::new();
        output
            .reader()
            .read_to_string(&mut actual_output)
            .expect("Doesn't cause IO errors");

        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        // The advertised size must match the number of bytes actually produced.
        assert_eq!(
            Some(expected_output.len() as u64),
            body.size_hint().exact()
        );
        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        let trailers = Some(trailers);
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_total_rendered_length_of_empty_trailers() {
        let trailers = Some(HeaderMap::new());
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = (trailers_as_aws_chunked_bytes(trailers, actual_length).len()) as u64;

        assert_eq!(expected_length, actual_length);
    }
}