AWS SDK

AWS SDK

rev. b2b472479d5b632758f738dab8641345aa5f60f2 (ignoring whitespace)

Files changed:

tmp-codegen-diff/aws-sdk/sdk/aws-runtime/src/content_encoding.rs

@@ -1,1 +166,197 @@
    1      1   
/*
    2      2   
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    3      3   
 * SPDX-License-Identifier: Apache-2.0
    4      4   
 */
    5      5   
           6  +
use aws_smithy_types::config_bag::{Storable, StoreReplace};
    6      7   
use bytes::{Bytes, BytesMut};
    7      8   
use http_02x::{HeaderMap, HeaderValue};
    8      9   
use http_body_04x::{Body, SizeHint};
    9     10   
use pin_project_lite::pin_project;
   10     11   
   11     12   
use std::pin::Pin;
   12     13   
use std::task::{Context, Poll};
   13     14   
   14     15   
/// Line terminator used by every line of the `aws-chunked` framing.
const CRLF: &str = "\r\n";
/// The `last-chunk` marker: a zero chunk-size followed by CRLF.
const CHUNK_TERMINATOR: &str = concat!("0", "\r\n");
/// Byte separating a trailer's field name from its value.
const TRAILER_SEPARATOR: &[u8] = &[b':'];

/// Values for the `Content-Encoding` header.
pub mod header_value {
    /// The `aws-chunked` content-encoding token.
    pub const AWS_CHUNKED: &str = "aws-chunked";
}
   23     24   
   24     25   
/// Options used when constructing an [`AwsChunkedBody`].
   25         -
#[derive(Debug, Default)]
          26  +
#[derive(Clone, Debug, Default)]
   26     27   
#[non_exhaustive]
   27     28   
pub struct AwsChunkedBodyOptions {
   28     29   
    /// The total size of the stream. Because we only support unsigned encoding
   29     30   
    /// this implies that there will only be a single chunk containing the
   30     31   
    /// underlying payload.
   31     32   
    stream_length: u64,
   32     33   
    /// The length of each trailer sent within an `AwsChunkedBody`. Necessary in
   33     34   
    /// order to correctly calculate the total size of the body accurately.
   34     35   
    trailer_lengths: Vec<u64>,
          36  +
    /// Whether the aws-chunked encoding is disabled. This could occur, for instance,
          37  +
    /// if a user specifies a custom checksum, rendering aws-chunked encoding unnecessary.
          38  +
    disabled: bool,
          39  +
}
          40  +
          41  +
impl Storable for AwsChunkedBodyOptions {
          42  +
    type Storer = StoreReplace<Self>;
   35     43   
}
   36     44   
   37     45   
impl AwsChunkedBodyOptions {
   38     46   
    /// Create a new [`AwsChunkedBodyOptions`].
   39     47   
    pub fn new(stream_length: u64, trailer_lengths: Vec<u64>) -> Self {
   40     48   
        Self {
   41     49   
            stream_length,
   42     50   
            trailer_lengths,
          51  +
            disabled: false,
   43     52   
        }
   44     53   
    }
   45     54   
   46     55   
    fn total_trailer_length(&self) -> u64 {
   47     56   
        self.trailer_lengths.iter().sum::<u64>()
   48     57   
            // We need to account for a CRLF after each trailer name/value pair
   49     58   
            + (self.trailer_lengths.len() * CRLF.len()) as u64
   50     59   
    }
   51     60   
   52         -
    /// Set a trailer len
          61  +
    /// Set the stream length in the options
          62  +
    pub fn with_stream_length(mut self, stream_length: u64) -> Self {
          63  +
        self.stream_length = stream_length;
          64  +
        self
          65  +
    }
          66  +
          67  +
    /// Append a trailer length to the options
   53     68   
    pub fn with_trailer_len(mut self, trailer_len: u64) -> Self {
   54     69   
        self.trailer_lengths.push(trailer_len);
   55     70   
        self
   56     71   
    }
          72  +
          73  +
    /// Create a new [`AwsChunkedBodyOptions`] with aws-chunked encoding disabled.
          74  +
    ///
          75  +
    /// When the option is disabled, the body must not be wrapped in an `AwsChunkedBody`.
          76  +
    pub fn disable_chunked_encoding() -> Self {
          77  +
        Self {
          78  +
            disabled: true,
          79  +
            ..Default::default()
          80  +
        }
          81  +
    }
          82  +
          83  +
    /// Return whether aws-chunked encoding is disabled.
          84  +
    pub fn disabled(&self) -> bool {
          85  +
        self.disabled
          86  +
    }
          87  +
          88  +
    /// Return the length of the body after `aws-chunked` encoding is applied
          89  +
    pub fn encoded_length(&self) -> u64 {
          90  +
        let mut length = 0;
          91  +
        if self.stream_length != 0 {
          92  +
            length += get_unsigned_chunk_bytes_length(self.stream_length);
          93  +
        }
          94  +
          95  +
        // End chunk
          96  +
        length += CHUNK_TERMINATOR.len() as u64;
          97  +
          98  +
        // Trailers
          99  +
        for len in self.trailer_lengths.iter() {
         100  +
            length += len + CRLF.len() as u64;
         101  +
        }
         102  +
         103  +
        // Encoding terminator
         104  +
        length += CRLF.len() as u64;
         105  +
         106  +
        length
         107  +
    }
   57    108   
}
   58    109   
   59    110   
/// Phases of the `aws-chunked` writer. A body moves through them in declaration order:
/// `WritingChunkSize` -> `WritingChunk` -> `WritingTrailers` -> `Closed`.
#[derive(Eq, PartialEq, Debug)]
enum AwsChunkedBodyState {
    /// Emit the size line of the upcoming chunk, then move on to `WritingChunk`.
    WritingChunkSize,
    /// Emit the chunk's data. The inner body may need several polls before it is drained;
    /// once it has no more data, move on to `WritingTrailers`.
    WritingChunk,
    /// Emit every trailer of this `AwsChunkedBody`, then move on to `Closed`.
    WritingTrailers,
    /// Terminal phase: emit the body terminator and stay here.
    Closed,
}
   74    125   
   75    126   
pin_project! {
    /// A request body compatible with `Content-Encoding: aws-chunked`. This implementation is only
    /// capable of writing a single chunk and does not support signed chunks.
    ///
    /// Chunked-Body grammar is defined in [ABNF] as:
    ///
    /// ```txt
    /// Chunked-Body    = *chunk
    ///                   last-chunk
    ///                   chunked-trailer
    ///                   CRLF
    ///
    /// chunk           = chunk-size CRLF chunk-data CRLF
    /// chunk-size      = 1*HEXDIG
    /// last-chunk      = 1*("0") CRLF
    /// chunked-trailer = *( entity-header CRLF )
    /// entity-header   = field-name ":" OWS field-value OWS
    /// ```
    ///
    /// For more info on what the abbreviations mean, see
    /// <https://datatracker.ietf.org/doc/html/rfc7230#section-1.2>.
    ///
    /// [ABNF]: https://en.wikipedia.org/wiki/Augmented_Backus%E2%80%93Naur_form
    #[derive(Debug)]
    pub struct AwsChunkedBody<InnerBody> {
        // The wrapped body; its data becomes the single chunk of the encoded output.
        #[pin]
        inner: InnerBody,
        // Current phase of the encoding state machine (see `AwsChunkedBodyState`).
        #[pin]
        state: AwsChunkedBodyState,
        // Stream and trailer lengths used to size the encoded body.
        options: AwsChunkedBodyOptions,
        // Presumably a running count of bytes read from `inner` (per the field name) - confirm
        // against `poll_data`.
        inner_body_bytes_read_so_far: usize,
    }
}
  106    157   
  107    158   
impl<Inner> AwsChunkedBody<Inner> {
  108    159   
    /// Wrap the given body in an outer body compatible with `Content-Encoding: aws-chunked`
  109    160   
    pub fn new(body: Inner, options: AwsChunkedBodyOptions) -> Self {
  110    161   
        Self {
  111    162   
            inner: body,
  112    163   
            state: AwsChunkedBodyState::WritingChunkSize,
  113    164   
            options,
  114    165   
            inner_body_bytes_read_so_far: 0,
  115    166   
        }
  116    167   
    }
  117         -
  118         -
    fn encoded_length(&self) -> u64 {
  119         -
        let mut length = 0;
  120         -
        if self.options.stream_length != 0 {
  121         -
            length += get_unsigned_chunk_bytes_length(self.options.stream_length);
  122         -
        }
  123         -
  124         -
        // End chunk
  125         -
        length += CHUNK_TERMINATOR.len() as u64;
  126         -
  127         -
        // Trailers
  128         -
        for len in self.options.trailer_lengths.iter() {
  129         -
            length += len + CRLF.len() as u64;
  130         -
        }
  131         -
  132         -
        // Encoding terminator
  133         -
        length += CRLF.len() as u64;
  134         -
  135         -
        length
  136         -
    }
  137    168   
}
  138    169   
  139    170   
fn get_unsigned_chunk_bytes_length(payload_length: u64) -> u64 {
  140    171   
    let hex_repr_len = int_log16(payload_length);
  141    172   
    hex_repr_len + CRLF.len() as u64 + payload_length + CRLF.len() as u64
  142    173   
}
  143    174   
  144    175   
/// Writes trailers out into a `string` and then converts that `String` to a `Bytes` before
  145    176   
/// returning.
  146    177   
///
@@ -270,301 +330,361 @@
  290    321   
    ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
  291    322   
        // Trailers were already appended to the body because of the content encoding scheme
  292    323   
        Poll::Ready(Ok(None))
  293    324   
    }
  294    325   
  295    326   
    fn is_end_stream(&self) -> bool {
  296    327   
        self.state == AwsChunkedBodyState::Closed
  297    328   
    }
  298    329   
  299    330   
    fn size_hint(&self) -> SizeHint {
  300         -
        SizeHint::with_exact(self.encoded_length())
         331  +
        SizeHint::with_exact(self.options.encoded_length())
  301    332   
    }
  302    333   
}
  303    334   
  304    335   
/// Errors related to `AwsChunkedBody`
  305    336   
#[derive(Debug)]
  306    337   
enum AwsChunkedBodyError {
  307    338   
    /// Error that occurs when the sum of `trailer_lengths` set when creating an `AwsChunkedBody` is
  308    339   
    /// not equal to the actual length of the trailers returned by the inner `http_body::Body`
  309    340   
    /// implementor. These trailer lengths are necessary in order to correctly calculate the total
  310    341   
    /// size of the body for setting the content length header.

tmp-codegen-diff/aws-sdk/sdk/s3/src/aws_chunked.rs

@@ -0,1 +0,288 @@
           1  +
// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
           2  +
/*
           3  +
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
           4  +
 * SPDX-License-Identifier: Apache-2.0
           5  +
 */
           6  +
           7  +
#![allow(dead_code)]
           8  +
           9  +
use std::fmt;
          10  +
          11  +
use aws_runtime::{
          12  +
    auth::PayloadSigningOverride,
          13  +
    content_encoding::{header_value::AWS_CHUNKED, AwsChunkedBody, AwsChunkedBodyOptions},
          14  +
};
          15  +
use aws_smithy_runtime_api::{
          16  +
    box_error::BoxError,
          17  +
    client::{
          18  +
        interceptors::{context::BeforeTransmitInterceptorContextMut, Intercept},
          19  +
        runtime_components::RuntimeComponents,
          20  +
    },
          21  +
    http::Request,
          22  +
};
          23  +
use aws_smithy_types::{body::SdkBody, config_bag::ConfigBag, error::operation::BuildError};
          24  +
use http::{header, HeaderValue};
          25  +
use http_body::Body;
          26  +
          27  +
/// Header that carries the size of the body before `aws-chunked` encoding is applied.
const X_AMZ_DECODED_CONTENT_LENGTH: &str = "x-amz-decoded-content-length";

/// Errors related to constructing aws-chunked encoded HTTP requests.
#[derive(Debug)]
enum Error {
    /// The request body's size could not be determined, so its encoded length is unknown.
    UnsizedRequestBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::UnsizedRequestBody => "Only request bodies with a known size can be aws-chunk encoded.",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
          44  +
          45  +
/// Interceptor that applies `Content-Encoding: aws-chunked` to requests with streaming bodies.
///
/// `modify_before_signing` sets the headers the encoding requires and records the
/// [`AwsChunkedBodyOptions`] in the config bag. `modify_before_transmit` wraps the request body
/// in an [`AwsChunkedBody`]. Both hooks do nothing when the body is in-memory or when
/// chunked encoding has been disabled through [`AwsChunkedBodyOptions`].
#[derive(Debug)]
pub(crate) struct AwsChunkedContentEncodingInterceptor;
          47  +
          48  +
impl Intercept for AwsChunkedContentEncodingInterceptor {
          49  +
    fn name(&self) -> &'static str {
          50  +
        "AwsChunkedContentEncodingInterceptor"
          51  +
    }
          52  +
          53  +
    fn modify_before_signing(
          54  +
        &self,
          55  +
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
          56  +
        _runtime_components: &RuntimeComponents,
          57  +
        cfg: &mut ConfigBag,
          58  +
    ) -> Result<(), BoxError> {
          59  +
        if must_not_use_chunked_encoding(context.request(), cfg) {
          60  +
            tracing::debug!("short-circuiting modify_before_signing because chunked encoding must not be used");
          61  +
            return Ok(());
          62  +
        }
          63  +
          64  +
        let original_body_size = if let Some(size) = context
          65  +
            .request()
          66  +
            .headers()
          67  +
            .get(header::CONTENT_LENGTH)
          68  +
            .and_then(|s| s.parse::<u64>().ok())
          69  +
            .or_else(|| context.request().body().size_hint().exact())
          70  +
        {
          71  +
            size
          72  +
        } else {
          73  +
            return Err(BuildError::other(Error::UnsizedRequestBody))?;
          74  +
        };
          75  +
          76  +
        let chunked_body_options = if let Some(chunked_body_options) = cfg.get_mut_from_interceptor_state::<AwsChunkedBodyOptions>() {
          77  +
            let chunked_body_options = std::mem::take(chunked_body_options);
          78  +
            chunked_body_options.with_stream_length(original_body_size)
          79  +
        } else {
          80  +
            AwsChunkedBodyOptions::default().with_stream_length(original_body_size)
          81  +
        };
          82  +
          83  +
        let request = context.request_mut();
          84  +
        // For for aws-chunked encoding, `x-amz-decoded-content-length` must be set to the original body size.
          85  +
        request.headers_mut().insert(
          86  +
            header::HeaderName::from_static(X_AMZ_DECODED_CONTENT_LENGTH),
          87  +
            HeaderValue::from(original_body_size),
          88  +
        );
          89  +
        // Other than `x-amz-decoded-content-length`, either `content-length` or `transfer-encoding`
          90  +
        // must be set, but not both. For uses cases we support, we know the original body size and
          91  +
        // can calculate the encoded size, so we set `content-length`.
          92  +
        request
          93  +
            .headers_mut()
          94  +
            .insert(header::CONTENT_LENGTH, HeaderValue::from(chunked_body_options.encoded_length()));
          95  +
        // Setting `content-length` above means we must unset `transfer-encoding`.
          96  +
        request.headers_mut().remove(header::TRANSFER_ENCODING);
          97  +
        request.headers_mut().append(
          98  +
            header::CONTENT_ENCODING,
          99  +
            HeaderValue::from_str(AWS_CHUNKED)
         100  +
                .map_err(BuildError::other)
         101  +
                .expect("\"aws-chunked\" will always be a valid HeaderValue"),
         102  +
        );
         103  +
         104  +
        cfg.interceptor_state().store_put(chunked_body_options);
         105  +
        cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
         106  +
         107  +
        Ok(())
         108  +
    }
         109  +
         110  +
    fn modify_before_transmit(
         111  +
        &self,
         112  +
        ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
         113  +
        _runtime_components: &RuntimeComponents,
         114  +
        cfg: &mut ConfigBag,
         115  +
    ) -> Result<(), BoxError> {
         116  +
        if must_not_use_chunked_encoding(ctx.request(), cfg) {
         117  +
            tracing::debug!("short-circuiting modify_before_transmit because chunked encoding must not be used");
         118  +
            return Ok(());
         119  +
        }
         120  +
         121  +
        let request = ctx.request_mut();
         122  +
         123  +
        let mut body = {
         124  +
            let body = std::mem::replace(request.body_mut(), SdkBody::taken());
         125  +
            let opt = cfg
         126  +
                .get_mut_from_interceptor_state::<AwsChunkedBodyOptions>()
         127  +
                .ok_or_else(|| BuildError::other("AwsChunkedBodyOptions missing from config bag"))?;
         128  +
            let aws_chunked_body_options = std::mem::take(opt);
         129  +
            body.map(move |body| {
         130  +
                let body = AwsChunkedBody::new(body, aws_chunked_body_options.clone());
         131  +
                SdkBody::from_body_0_4(body)
         132  +
            })
         133  +
        };
         134  +
         135  +
        std::mem::swap(request.body_mut(), &mut body);
         136  +
         137  +
        Ok(())
         138  +
    }
         139  +
}
         140  +
         141  +
// Determine if chunked encoding must not be used; returns true when any of the following is true:
         142  +
// - If the body is in-memory
         143  +
// - If chunked encoding is disabled via `AwsChunkedBodyOptions`
         144  +
fn must_not_use_chunked_encoding(request: &Request, cfg: &ConfigBag) -> bool {
         145  +
    match (request.body().bytes(), cfg.load::<AwsChunkedBodyOptions>()) {
         146  +
        (Some(_), _) => true,
         147  +
        (_, Some(options)) if options.disabled() => true,
         148  +
        _ => false,
         149  +
    }
         150  +
}
         151  +
         152  +
#[cfg(test)]
mod tests {
    //! Unit tests for `AwsChunkedContentEncodingInterceptor` and `must_not_use_chunked_encoding`.

    use super::*;
    use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, Input, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    // Wrapping a file-backed (streaming) body must keep it cloneable, and the wrapped body must
    // end with the aws-chunked epilogue.
    #[tokio::test]
    async fn test_aws_chunked_body_is_retryable() {
        use std::io::Write;
        let mut file = NamedTempFile::new().unwrap();

        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", i);
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
        }

        let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        let interceptor = AwsChunkedContentEncodingInterceptor;
        let mut cfg = ConfigBag::base();
        cfg.interceptor_state().store_put(AwsChunkedBodyOptions::default());
        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
        // Walk the context through the phases that must precede `modify_before_transmit`.
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        let _ = ctx.take_input();
        ctx.set_request(request);
        ctx.enter_before_transmit_phase();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
        interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();

        // ensure wrapped SdkBody is retryable
        let mut body = ctx.request().body().try_clone().expect("body is retryable");

        // Drain the cloned body completely.
        let mut body_data = BytesMut::new();
        while let Some(data) = body.data().await {
            body_data.extend_from_slice(&data.unwrap())
        }
        let body_str = std::str::from_utf8(&body_data).unwrap();
        // No trailers are configured, so the body ends with the last chunk ("0\r\n") followed by
        // the final CRLF.
        assert!(body_str.ends_with("0\r\n\r\n"));
    }

    // An in-memory body must leave the request headers untouched in `modify_before_signing`.
    #[tokio::test]
    async fn test_short_circuit_modify_before_signing() {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        let _ = ctx.take_input();
        let request = HttpRequest::new(SdkBody::from("in-memory body, must not use chunked encoding"));
        ctx.set_request(request);
        ctx.enter_before_transmit_phase();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();

        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();

        let mut cfg = ConfigBag::base();
        cfg.interceptor_state().store_put(AwsChunkedBodyOptions::default());

        let interceptor = AwsChunkedContentEncodingInterceptor;
        interceptor.modify_before_signing(&mut ctx, &runtime_components, &mut cfg).unwrap();

        // Neither aws-chunked header may have been added.
        let request = ctx.request();
        assert!(request.headers().get(header::CONTENT_ENCODING).is_none());
        assert!(request
            .headers()
            .get(header::HeaderName::from_static(X_AMZ_DECODED_CONTENT_LENGTH))
            .is_none());
    }

    // An in-memory body must pass through `modify_before_transmit` unwrapped.
    #[tokio::test]
    async fn test_short_circuit_modify_before_transmit() {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        let _ = ctx.take_input();
        let request = HttpRequest::new(SdkBody::from("in-memory body, must not use chunked encoding"));
        ctx.set_request(request);
        ctx.enter_before_transmit_phase();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();

        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();

        let mut cfg = ConfigBag::base();
        // Don't need to set the stream length properly because we expect the body won't be wrapped by `AwsChunkedBody`.
        cfg.interceptor_state().store_put(AwsChunkedBodyOptions::default());

        let interceptor = AwsChunkedContentEncodingInterceptor;
        interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();

        let mut body = ctx.request().body().try_clone().expect("body is retryable");

        let mut body_data = BytesMut::new();
        while let Some(data) = body.data().await {
            body_data.extend_from_slice(&data.unwrap())
        }
        let body_str = std::str::from_utf8(&body_data).unwrap();
        // Also implies that `assert!(!body_str.ends_with("0\r\n\r\n"));`, i.e., shouldn't see chunked encoding epilogue.
        assert_eq!("in-memory body, must not use chunked encoding", body_str);
    }

    // In-memory bodies never use chunked encoding, even with no options in the config bag.
    #[test]
    fn test_must_not_use_chunked_encoding_with_in_memory_body() {
        let request = HttpRequest::new(SdkBody::from("test body"));
        let cfg = ConfigBag::base();

        assert!(must_not_use_chunked_encoding(&request, &cfg));
    }

    // Builds a streaming (file-backed) `SdkBody` from the file at `path`.
    async fn streaming_body(path: impl AsRef<std::path::Path>) -> SdkBody {
        let file = path.as_ref();
        ByteStream::read_from().path(&file).build().await.unwrap().into_inner()
    }

    // A streaming body with chunked encoding explicitly disabled must not use chunked encoding.
    #[tokio::test]
    async fn test_must_not_use_chunked_encoding_with_disabled_option() {
        let file = NamedTempFile::new().unwrap();
        let request = HttpRequest::new(streaming_body(&file).await);
        let mut cfg = ConfigBag::base();
        cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());

        assert!(must_not_use_chunked_encoding(&request, &cfg));
    }

    // A streaming body with no disabling options in the config bag uses chunked encoding.
    #[tokio::test]
    async fn test_chunked_encoding_is_used() {
        let file = NamedTempFile::new().unwrap();
        let request = HttpRequest::new(streaming_body(&file).await);
        let cfg = ConfigBag::base();

        assert!(!must_not_use_chunked_encoding(&request, &cfg));
    }
}

tmp-codegen-diff/aws-sdk/sdk/s3/src/http_request_checksum.rs

@@ -1,1 +98,104 @@
    2      2   
/*
    3      3   
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    4      4   
 * SPDX-License-Identifier: Apache-2.0
    5      5   
 */
    6      6   
    7      7   
#![allow(dead_code)]
    8      8   
    9      9   
//! Interceptor for handling Smithy `@httpChecksum` request checksumming with AWS SigV4
   10     10   
   11     11   
use crate::presigning::PresigningMarker;
   12         -
use aws_runtime::auth::PayloadSigningOverride;
   13         -
use aws_runtime::content_encoding::header_value::AWS_CHUNKED;
   14         -
use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
          12  +
use aws_runtime::content_encoding::AwsChunkedBodyOptions;
          13  +
use aws_smithy_checksums::body::calculate;
   15     14   
use aws_smithy_checksums::body::ChecksumCache;
          15  +
use aws_smithy_checksums::http::HttpChecksum;
   16     16   
use aws_smithy_checksums::ChecksumAlgorithm;
   17         -
use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
   18     17   
use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
   19     18   
use aws_smithy_runtime_api::box_error::BoxError;
   20     19   
use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Input};
   21     20   
use aws_smithy_runtime_api::client::interceptors::Intercept;
   22         -
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
   23     21   
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
   24     22   
use aws_smithy_runtime_api::http::Request;
   25     23   
use aws_smithy_types::body::SdkBody;
   26     24   
use aws_smithy_types::checksum_config::RequestChecksumCalculation;
   27         -
use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
   28         -
use aws_smithy_types::error::operation::BuildError;
   29         -
use http::HeaderValue;
   30         -
use http_body::Body;
          25  +
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
          26  +
use http::HeaderMap;
   31     27   
use std::str::FromStr;
   32     28   
use std::sync::atomic::AtomicBool;
   33     29   
use std::sync::atomic::Ordering;
   34     30   
use std::sync::Arc;
   35     31   
use std::{fmt, mem};
   36     32   
   37     33   
/// Errors related to constructing checksum-validated HTTP requests
#[derive(Debug)]
pub(crate) enum Error {
    /// A checksum header was requested for a streaming body. Streaming bodies must carry their
    /// checksums as trailers instead.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
   59     52   
   60         -
#[derive(Debug, Clone)]
          53  +
#[derive(Debug, Default, Clone)]
   61     54   
struct RequestChecksumInterceptorState {
   62     55   
    /// The checksum algorithm to calculate
   63     56   
    checksum_algorithm: Option<String>,
   64     57   
    /// This value is set in the model on the `httpChecksum` trait
   65     58   
    request_checksum_required: bool,
   66     59   
    calculate_checksum: Arc<AtomicBool>,
   67     60   
    checksum_cache: ChecksumCache,
   68     61   
}
          62  +
          63  +
impl RequestChecksumInterceptorState {
          64  +
    fn checksum_algorithm(&self) -> Option<ChecksumAlgorithm> {
          65  +
        self.checksum_algorithm
          66  +
            .as_ref()
          67  +
            .and_then(|s| ChecksumAlgorithm::from_str(s.as_str()).ok())
          68  +
    }
          69  +
          70  +
    fn calculate_checksum(&self) -> bool {
          71  +
        self.calculate_checksum.load(Ordering::SeqCst)
          72  +
    }
          73  +
}
          74  +
   69     75   
impl Storable for RequestChecksumInterceptorState {
   70     76   
    type Storer = StoreReplace<Self>;
   71     77   
}
   72     78   
   73     79   
type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
   74     80   
   75     81   
pub(crate) struct DefaultRequestChecksumOverride {
   76     82   
    custom_default: CustomDefaultFn,
   77     83   
}
   78     84   
impl fmt::Debug for DefaultRequestChecksumOverride {
@@ -107,113 +482,448 @@
  127    133   
    }
  128    134   
  129    135   
    fn modify_before_serialization(
  130    136   
        &self,
  131    137   
        context: &mut BeforeSerializationInterceptorContextMut<'_>,
  132    138   
        _runtime_components: &RuntimeComponents,
  133    139   
        cfg: &mut ConfigBag,
  134    140   
    ) -> Result<(), BoxError> {
  135    141   
        let (checksum_algorithm, request_checksum_required) = (self.algorithm_provider)(context.input());
  136    142   
  137         -
        let mut layer = Layer::new("RequestChecksumInterceptor");
  138         -
        layer.store_put(RequestChecksumInterceptorState {
         143  +
        cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
  139    144   
            checksum_algorithm,
  140    145   
            request_checksum_required,
  141    146   
            checksum_cache: ChecksumCache::new(),
  142    147   
            calculate_checksum: Arc::new(AtomicBool::new(false)),
  143    148   
        });
  144         -
        cfg.push_layer(layer);
  145    149   
  146    150   
        Ok(())
  147    151   
    }
  148    152   
  149    153   
    /// Setup state for calculating checksum and setting UA features
  150    154   
    fn modify_before_retry_loop(
  151    155   
        &self,
  152    156   
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
  153    157   
        _runtime_components: &RuntimeComponents,
  154    158   
        cfg: &mut ConfigBag,
  155    159   
    ) -> Result<(), BoxError> {
  156         -
        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
  157         -
  158    160   
        let user_set_checksum_value = (self.checksum_mutator)(context.request_mut(), cfg).expect("Checksum header mutation should not fail");
         161  +
        let is_presigned = cfg.load::<PresigningMarker>().is_some();
  159    162   
  160         -
        // If the user manually set a checksum header we short circuit
  161         -
        if user_set_checksum_value {
         163  +
        // If the user manually set a checksum header or if this is a presigned request, we short circuit
         164  +
        if user_set_checksum_value || is_presigned {
         165  +
            // Disable aws-chunked encoding since either the user has set a custom checksum
         166  +
            cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
  162    167   
            return Ok(());
  163    168   
        }
  164    169   
  165         -
        // This value is from the trait, but is needed for runtime logic
  166         -
        let request_checksum_required = state.request_checksum_required;
         170  +
        let state = cfg
         171  +
            .get_mut_from_interceptor_state::<RequestChecksumInterceptorState>()
         172  +
            .expect("set in `read_before_serialization`");
  167    173   
  168    174   
        // If the algorithm fails to parse it is not one we support and we error
  169    175   
        let checksum_algorithm = state
  170    176   
            .checksum_algorithm
  171    177   
            .clone()
  172    178   
            .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
  173    179   
            .transpose()?;
  174    180   
  175         -
        // This value is set by the user on the SdkConfig to indicate their preference
  176         -
        // We provide a default here for users that use a client config instead of the SdkConfig
  177         -
        let request_checksum_calculation = cfg
  178         -
            .load::<RequestChecksumCalculation>()
  179         -
            .unwrap_or(&RequestChecksumCalculation::WhenSupported);
  180         -
  181         -
        // Need to know if this is a presigned req because we do not calculate checksums for those.
  182         -
        let is_presigned_req = cfg.load::<PresigningMarker>().is_some();
         181  +
        let mut state = std::mem::take(state);
  183    182   
  184         -
        // Determine if we actually calculate the checksum. If this is a presigned request we do not
  185         -
        // If the user setting is WhenSupported (the default) we always calculate it (because this interceptor
  186         -
        // isn't added if it isn't supported). If it is WhenRequired we only calculate it if the checksum
  187         -
        // is marked required on the trait.
  188         -
        let calculate_checksum = match (request_checksum_calculation, is_presigned_req) {
  189         -
            (_, true) => false,
  190         -
            (RequestChecksumCalculation::WhenRequired, false) => request_checksum_required,
  191         -
            (RequestChecksumCalculation::WhenSupported, false) => true,
  192         -
            _ => true,
  193         -
        };
         183  +
        if calculate_checksum(cfg, &state) {
         184  +
            state.calculate_checksum.store(true, Ordering::Release);
  194    185   
  195    186   
            // If a checksum override is set in the ConfigBag we use that instead (currently only used by S3Express)
  196    187   
            // If we have made it this far without a checksum being set we set the default (currently Crc32)
  197    188   
            let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
         189  +
            state.checksum_algorithm = Some(checksum_algorithm.as_str().to_owned());
  198    190   
  199         -
        if calculate_checksum {
  200         -
            state.calculate_checksum.store(true, Ordering::Release);
  201         -
  202         -
            // Set the user-agent metric for the selected checksum algorithm
  203    191   
            // NOTE: We have to do this in modify_before_retry_loop since UA interceptor also runs
  204    192   
            // in modify_before_signing but is registered before this interceptor (client level vs operation level).
  205         -
            match checksum_algorithm {
  206         -
                ChecksumAlgorithm::Crc32 => {
  207         -
                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
  208         -
                }
  209         -
                ChecksumAlgorithm::Crc32c => {
  210         -
                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
  211         -
                }
  212         -
                ChecksumAlgorithm::Crc64Nvme => {
  213         -
                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
  214         -
                }
  215         -
                #[allow(deprecated)]
  216         -
                ChecksumAlgorithm::Md5 => {
  217         -
                    tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
  218         -
                }
  219         -
                ChecksumAlgorithm::Sha1 => {
  220         -
                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
  221         -
                }
  222         -
                ChecksumAlgorithm::Sha256 => {
  223         -
                    cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
  224         -
                }
  225         -
                unsupported => tracing::warn!(
  226         -
                        more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
  227         -
                        unsupported = ?unsupported),
  228         -
            }
         193  +
            track_metric_for_selected_checksum_algorithm(cfg, &checksum_algorithm);
         194  +
        } else {
         195  +
            // No checksum calculation needed so disable aws-chunked encoding
         196  +
            cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
  229    197   
        }
  230    198   
         199  +
        cfg.interceptor_state().store_put(state);
         200  +
  231    201   
        Ok(())
  232    202   
    }
  233    203   
  234         -
    /// Calculate a checksum and modify the request to include the checksum as a header
  235         -
    /// (for in-memory request bodies) or a trailer (for streaming request bodies).
  236         -
    /// Streaming bodies must be sized or this will return an error.
         204  +
    /// Calculate a checksum and modify the request to do either of the following:
         205  +
    /// - include the checksum as a header for signing with in-memory request bodies.
         206  +
    /// - include the checksum as a trailer for streaming request bodies.
  237    207   
    fn modify_before_signing(
  238    208   
        &self,
  239    209   
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
  240    210   
        _runtime_components: &RuntimeComponents,
  241    211   
        cfg: &mut ConfigBag,
  242    212   
    ) -> Result<(), BoxError> {
  243    213   
        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
  244    214   
  245         -
        let checksum_cache = state.checksum_cache.clone();
         215  +
        if !state.calculate_checksum() {
         216  +
            return Ok(());
         217  +
        }
  246    218   
  247         -
        let checksum_algorithm = state
  248         -
            .checksum_algorithm
  249         -
            .clone()
  250         -
            .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
  251         -
            .transpose()?;
         219  +
        let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
         220  +
        let mut checksum = checksum_algorithm.into_impl();
  252    221   
  253         -
        let calculate_checksum = state.calculate_checksum.load(Ordering::SeqCst);
         222  +
        match context.request().body().bytes() {
         223  +
            Some(data) => {
         224  +
                tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
         225  +
                checksum.update(data);
  254    226   
  255         -
        // Calculate the checksum if necessary
  256         -
        if calculate_checksum {
  257         -
            // If a checksum override is set in the ConfigBag we use that instead (currently only used by S3Express)
  258         -
            // If we have made it this far without a checksum being set we set the default (currently Crc32)
  259         -
            let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
         227  +
                for (hdr_name, hdr_value) in get_or_cache_headers(checksum.headers(), &state.checksum_cache).iter() {
         228  +
                    context.request_mut().headers_mut().insert(hdr_name.clone(), hdr_value.clone());
         229  +
                }
         230  +
            }
         231  +
            None => {
         232  +
                tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
         233  +
                context
         234  +
                    .request_mut()
         235  +
                    .headers_mut()
         236  +
                    .insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
  260    237   
  261         -
            let request = context.request_mut();
  262         -
            add_checksum_for_request_body(request, checksum_algorithm, checksum_cache, cfg)?;
         238  +
                // Take checksum header into account for `AwsChunkedBodyOptions`'s trailer length
         239  +
                let trailer_len = HttpChecksum::size(checksum.as_ref());
         240  +
                let chunked_body_options = AwsChunkedBodyOptions::default().with_trailer_len(trailer_len);
         241  +
                cfg.interceptor_state().store_put(chunked_body_options);
         242  +
            }
  263    243   
        }
  264    244   
  265    245   
        Ok(())
  266    246   
    }
  267    247   
  268         -
    /// Set the user-agent metrics for `RequestChecksumCalculation` here to avoid ownership issues
  269         -
    /// with the mutable borrow of cfg in `modify_before_signing`
  270         -
    fn read_after_serialization(
         248  +
    fn modify_before_transmit(
  271    249   
        &self,
  272         -
        _context: &aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef<'_>,
         250  +
        ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
  273    251   
        _runtime_components: &RuntimeComponents,
  274    252   
        cfg: &mut ConfigBag,
  275    253   
    ) -> Result<(), BoxError> {
  276         -
        let request_checksum_calculation = cfg
  277         -
            .load::<RequestChecksumCalculation>()
  278         -
            .unwrap_or(&RequestChecksumCalculation::WhenSupported);
  279         -
  280         -
        match request_checksum_calculation {
  281         -
            RequestChecksumCalculation::WhenSupported => {
  282         -
                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
         254  +
        if ctx.request().body().bytes().is_some() {
         255  +
            // Nothing to do for non-streaming bodies since the checksum was added to the the header
         256  +
            // in `modify_before_signing` and signing has already been done by the time this hook is called.
         257  +
            return Ok(());
  283    258   
        }
  284         -
            RequestChecksumCalculation::WhenRequired => {
  285         -
                cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
         259  +
         260  +
        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
         261  +
         262  +
        if !state.calculate_checksum() {
         263  +
            return Ok(());
  286    264   
        }
  287         -
            unsupported => tracing::warn!(
  288         -
                    more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
  289         -
                    unsupported = ?unsupported),
         265  +
         266  +
        let request = ctx.request_mut();
         267  +
         268  +
        let mut body = {
         269  +
            let body = mem::replace(request.body_mut(), SdkBody::taken());
         270  +
         271  +
            let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
         272  +
            let checksum_cache = state.checksum_cache.clone();
         273  +
         274  +
            body.map(move |body| {
         275  +
                let checksum = checksum_algorithm.into_impl();
         276  +
                let body = calculate::ChecksumBody::new(body, checksum).with_cache(checksum_cache.clone());
         277  +
         278  +
                SdkBody::from_body_0_4(body)
         279  +
            })
  290    280   
        };
  291    281   
         282  +
        mem::swap(request.body_mut(), &mut body);
         283  +
  292    284   
        Ok(())
  293    285   
    }
  294    286   
}
  295    287   
  296    288   
fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
  297    289   
    match cfg.load::<DefaultRequestChecksumOverride>() {
  298    290   
        Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
  299    291   
        None => checksum,
  300    292   
    }
  301    293   
}
  302    294   
  303         -
fn add_checksum_for_request_body(
  304         -
    request: &mut HttpRequest,
  305         -
    checksum_algorithm: ChecksumAlgorithm,
  306         -
    checksum_cache: ChecksumCache,
  307         -
    cfg: &mut ConfigBag,
  308         -
) -> Result<(), BoxError> {
  309         -
    match request.body().bytes() {
  310         -
        // Body is in-memory: read it and insert the checksum as a header.
  311         -
        Some(data) => {
  312         -
            let mut checksum = checksum_algorithm.into_impl();
  313         -
  314         -
            // If the header has not already been set we set it. If it was already set by the user
  315         -
            // we do nothing and maintain their set value.
  316         -
            if request.headers().get(checksum.header_name()).is_none() {
  317         -
                tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
  318         -
                checksum.update(data);
  319         -
  320         -
                let calculated_headers = checksum.headers();
  321         -
                let checksum_headers = if let Some(cached_headers) = checksum_cache.get() {
         295  +
fn get_or_cache_headers(calculated_headers: HeaderMap, checksum_cache: &ChecksumCache) -> HeaderMap {
         296  +
    if let Some(cached_headers) = checksum_cache.get() {
  322    297   
        if cached_headers != calculated_headers {
  323    298   
            tracing::warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
  324    299   
        }
  325    300   
        cached_headers
  326    301   
    } else {
  327    302   
        checksum_cache.set(calculated_headers.clone());
  328    303   
        calculated_headers
  329         -
                };
  330         -
  331         -
                for (hdr_name, hdr_value) in checksum_headers.iter() {
  332         -
                    request.headers_mut().insert(hdr_name.clone(), hdr_value.clone());
  333    304   
    }
         305  +
}
         306  +
         307  +
// Determine if we actually calculate the checksum
         308  +
fn calculate_checksum(cfg: &mut ConfigBag, state: &RequestChecksumInterceptorState) -> bool {
         309  +
    // This value is set by the user on the SdkConfig to indicate their preference
         310  +
    // We provide a default here for users that use a client config instead of the SdkConfig
         311  +
    let request_checksum_calculation = cfg
         312  +
        .load::<RequestChecksumCalculation>()
         313  +
        .unwrap_or(&RequestChecksumCalculation::WhenSupported);
         314  +
         315  +
    // If the user setting is WhenSupported (the default) we always calculate it (because this interceptor
         316  +
    // isn't added if it isn't supported). If it is WhenRequired we only calculate it if the checksum
         317  +
    // is marked required on the trait.
         318  +
    match request_checksum_calculation {
         319  +
        RequestChecksumCalculation::WhenRequired => {
         320  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
         321  +
            state.request_checksum_required
  334    322   
        }
         323  +
        RequestChecksumCalculation::WhenSupported => {
         324  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
         325  +
            true
  335    326   
        }
  336         -
        // Body is streaming: wrap the body so it will emit a checksum as a trailer.
  337         -
        None => {
  338         -
            tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
  339         -
            cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
  340         -
            wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm, checksum_cache.clone())?;
         327  +
        unsupported => {
         328  +
            tracing::warn!(
         329  +
                more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
         330  +
                unsupported = ?unsupported
         331  +
            );
         332  +
            true
  341    333   
        }
  342    334   
    }
  343         -
    Ok(())
  344    335   
}
  345    336   
  346         -
fn wrap_streaming_request_body_in_checksum_calculating_body(
  347         -
    request: &mut HttpRequest,
  348         -
    checksum_algorithm: ChecksumAlgorithm,
  349         -
    checksum_cache: ChecksumCache,
  350         -
) -> Result<(), BuildError> {
  351         -
    let checksum = checksum_algorithm.into_impl();
  352         -
  353         -
    // If the user already set the header value then do nothing and return early
  354         -
    if request.headers().get(checksum.header_name()).is_some() {
  355         -
        return Ok(());
         337  +
// Set the user-agent metric for the selected checksum algorithm
         338  +
fn track_metric_for_selected_checksum_algorithm(cfg: &mut ConfigBag, checksum_algorithm: &ChecksumAlgorithm) {
         339  +
    match checksum_algorithm {
         340  +
        ChecksumAlgorithm::Crc32 => {
         341  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
         342  +
        }
         343  +
        ChecksumAlgorithm::Crc32c => {
         344  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
         345  +
        }
         346  +
        ChecksumAlgorithm::Crc64Nvme => {
         347  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
         348  +
        }
         349  +
        #[allow(deprecated)]
         350  +
        ChecksumAlgorithm::Md5 => {
         351  +
            tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
         352  +
        }
         353  +
        ChecksumAlgorithm::Sha1 => {
         354  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
         355  +
        }
         356  +
        ChecksumAlgorithm::Sha256 => {
         357  +
            cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
         358  +
        }
         359  +
        unsupported => tracing::warn!(
         360  +
                more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
         361  +
                unsupported = ?unsupported),
  356    362   
    }
  357         -
  358         -
    let original_body_size = request
  359         -
        .body()
  360         -
        .size_hint()
  361         -
        .exact()
  362         -
        .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
  363         -
  364         -
    let mut body = {
  365         -
        let body = mem::replace(request.body_mut(), SdkBody::taken());
  366         -
  367         -
        body.map(move |body| {
  368         -
            let checksum = checksum_algorithm.into_impl();
  369         -
            let trailer_len = HttpChecksum::size(checksum.as_ref());
  370         -
            let body = calculate::ChecksumBody::new(body, checksum).with_cache(checksum_cache.clone());
  371         -
            let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
  372         -
  373         -
            let body = AwsChunkedBody::new(body, aws_chunked_body_options);
  374         -
  375         -
            SdkBody::from_body_0_4(body)
  376         -
        })
  377         -
    };
  378         -
  379         -
    let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
  380         -
  381         -
    let headers = request.headers_mut();
  382         -
  383         -
    headers.insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
  384         -
  385         -
    headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
  386         -
    headers.insert(
  387         -
        http::header::HeaderName::from_static("x-amz-decoded-content-length"),
  388         -
        HeaderValue::from(original_body_size),
  389         -
    );
  390         -
    // The target service does not depend on where `aws-chunked` appears in the `Content-Encoding` header,
  391         -
    // as it will ultimately be stripped.
  392         -
    headers.append(
  393         -
        http::header::CONTENT_ENCODING,
  394         -
        HeaderValue::from_str(AWS_CHUNKED)
  395         -
            .map_err(BuildError::other)
  396         -
            .expect("\"aws-chunked\" will always be a valid HeaderValue"),
  397         -
    );
  398         -
  399         -
    mem::swap(request.body_mut(), &mut body);
  400         -
  401         -
    Ok(())
  402    363   
}
  403    364   
  404    365   
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::base64;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Builds an interceptor with fixed providers.
    ///
    /// `modify_before_transmit` reads its state from the config bag rather than from these
    /// providers, so their return values do not affect the test below.
    fn create_test_interceptor() -> RequestChecksumInterceptor<
        impl Fn(&Input) -> (Option<String>, bool) + Send + Sync,
        impl Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
    > {
        fn algo(_: &Input) -> (Option<String>, bool) {
            (Some("crc32".to_string()), false)
        }
        fn mutator(_: &mut Request, _: &ConfigBag) -> Result<bool, BoxError> {
            Ok(false)
        }
        RequestChecksumInterceptor::new(algo, mutator)
    }

    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        use std::io::Write;
        let mut file = NamedTempFile::new().unwrap();
        let algorithm_str = "crc32c";
        let checksum_algorithm: ChecksumAlgorithm = algorithm_str.parse().unwrap();

        let mut crc32c_checksum = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", i);
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            crc32c_checksum.update(line.as_bytes());
        }
        let crc32c_checksum = crc32c_checksum.finalize();

        let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());

        // ensure original SdkBody is retryable
        assert!(request.body().try_clone().is_some());

        let interceptor = create_test_interceptor();
        let mut cfg = ConfigBag::base();
        cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
            checksum_algorithm: Some(algorithm_str.to_string()),
            calculate_checksum: Arc::new(AtomicBool::new(true)),
            ..Default::default()
        });
        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        let _ = ctx.take_input();
        ctx.set_request(request);
        ctx.enter_before_transmit_phase();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
        interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();

        // ensure wrapped SdkBody is retryable
        let mut body = ctx.request().body().try_clone().expect("body is retryable");

        let mut body_data = BytesMut::new();
        while let Some(data) = body.data().await {
            body_data.extend_from_slice(&data.unwrap())
        }
        let body_str = std::str::from_utf8(&body_data).unwrap();
        let expected = "This is a large file created for testing purposes 9999";
        assert!(body_str.ends_with(expected), "expected '{body_str}' to end with '{expected}'");

        // The checksum trailer must actually be emitted: a missing trailer or header fails the test
        // instead of being skipped silently.
        let expected_checksum = base64::encode(&crc32c_checksum);
        let trailers = body
            .trailers()
            .await
            .expect("polling trailers should not fail")
            .expect("the checksum body should emit trailers");
        let header_value = trailers
            .get("x-amz-checksum-crc32c")
            .expect("trailers should contain the crc32c checksum")
            .to_str()
            .unwrap();
        assert_eq!(
            header_value, expected_checksum,
            "expected checksum '{header_value}' to match '{expected_checksum}'"
        );
    }
}

tmp-codegen-diff/aws-sdk/sdk/s3/src/lib.rs

@@ -172,172 +231,233 @@
  192    192   
  193    193   
/// All operations that this crate can perform.
  194    194   
pub mod operation;
  195    195   
  196    196   
/// Primitives such as `Blob` or `DateTime` used by other types.
  197    197   
pub mod primitives;
  198    198   
  199    199   
/// Data structures used by operation inputs/outputs.
  200    200   
pub mod types;
  201    201   
         202  +
pub(crate) mod aws_chunked;
         203  +
  202    204   
pub(crate) mod client_idempotency_token;
  203    205   
  204    206   
mod event_receiver;
  205    207   
  206    208   
pub(crate) mod http_request_checksum;
  207    209   
  208    210   
pub(crate) mod http_response_checksum;
  209    211   
  210    212   
mod idempotency_token;
  211    213   

tmp-codegen-diff/aws-sdk/sdk/s3/src/operation/put_object.rs

@@ -157,157 +216,217 @@
  177    177   
                            request.headers_mut().insert("x-amz-sdk-checksum-algorithm", "CRC32");
  178    178   
                        }
  179    179   
                        _ => {}
  180    180   
                    }
  181    181   
  182    182   
                    // We return a bool indicating if the user did set the checksum value, if they did
  183    183   
                    // we can short circuit and exit the interceptor early.
  184    184   
                    Ok(user_set_checksum_value)
  185    185   
                },
  186    186   
            ))
         187  +
            .with_interceptor(crate::aws_chunked::AwsChunkedContentEncodingInterceptor)
  187    188   
            .with_retry_classifier(::aws_smithy_runtime::client::retries::classifiers::TransientErrorClassifier::<
  188    189   
                crate::operation::put_object::PutObjectError,
  189    190   
            >::new())
  190    191   
            .with_retry_classifier(::aws_smithy_runtime::client::retries::classifiers::ModeledAsRetryableClassifier::<
  191    192   
                crate::operation::put_object::PutObjectError,
  192    193   
            >::new())
  193    194   
            .with_retry_classifier(
  194    195   
                ::aws_runtime::retries::classifiers::AwsErrorCodeClassifier::<crate::operation::put_object::PutObjectError>::builder()
  195    196   
                    .transient_errors({
  196    197   
                        let mut transient_errors: Vec<&'static str> = ::aws_runtime::retries::classifiers::TRANSIENT_ERRORS.into();

tmp-codegen-diff/aws-sdk/sdk/s3/src/operation/upload_part.rs

@@ -157,157 +216,217 @@
  177    177   
                            request.headers_mut().insert("x-amz-sdk-checksum-algorithm", "CRC32");
  178    178   
                        }
  179    179   
                        _ => {}
  180    180   
                    }
  181    181   
  182    182   
                    // We return a bool indicating if the user did set the checksum value, if they did
  183    183   
                    // we can short circuit and exit the interceptor early.
  184    184   
                    Ok(user_set_checksum_value)
  185    185   
                },
  186    186   
            ))
         187  +
            .with_interceptor(crate::aws_chunked::AwsChunkedContentEncodingInterceptor)
  187    188   
            .with_retry_classifier(::aws_smithy_runtime::client::retries::classifiers::TransientErrorClassifier::<
  188    189   
                crate::operation::upload_part::UploadPartError,
  189    190   
            >::new())
  190    191   
            .with_retry_classifier(::aws_smithy_runtime::client::retries::classifiers::ModeledAsRetryableClassifier::<
  191    192   
                crate::operation::upload_part::UploadPartError,
  192    193   
            >::new())
  193    194   
            .with_retry_classifier(
  194    195   
                ::aws_runtime::retries::classifiers::AwsErrorCodeClassifier::<crate::operation::upload_part::UploadPartError>::builder()
  195    196   
                    .transient_errors({
  196    197   
                        let mut transient_errors: Vec<&'static str> = ::aws_runtime::retries::classifiers::TRANSIENT_ERRORS.into();