aws_smithy_runtime/client/http/body/
content_length_enforcement.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! RuntimePlugin to ensure that the amount of data received matches the `Content-Length` header
7
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::{
10    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextRef,
11};
12use aws_smithy_runtime_api::client::interceptors::SharedInterceptor;
13use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
14use aws_smithy_runtime_api::client::runtime_components::{
15    RuntimeComponents, RuntimeComponentsBuilder,
16};
17use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
18use aws_smithy_runtime_api::http::Response;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
21use bytes::Buf;
22use http_body_1x::{Frame, SizeHint};
23use pin_project_lite::pin_project;
24use std::borrow::Cow;
25use std::error::Error;
26use std::fmt::{Display, Formatter};
27use std::pin::Pin;
28use std::task::{ready, Context, Poll};
29pin_project! {
30    /// A body-wrapper that will calculate the `InnerBody`'s checksum and emit it as a trailer.
31    struct ContentLengthEnforcingBody<InnerBody> {
32            #[pin]
33            body: InnerBody,
34            expected_length: u64,
35            bytes_received: u64,
36    }
37}
38
/// An error returned when a body did not have the expected content length
///
/// Yielded by the enforcing body when the inner body ends after producing a number of data
/// bytes different from the value of the response's `Content-Length` header.
#[derive(Debug)]
pub struct ContentLengthError {
    expected: u64,
    received: u64,
}

impl ContentLengthError {
    /// The body length, in bytes, announced by the `Content-Length` header.
    pub fn expected(&self) -> u64 {
        self.expected
    }

    /// The number of body bytes actually received before the body ended.
    pub fn received(&self) -> u64 {
        self.received
    }
}

impl Error for ContentLengthError {}

impl Display for ContentLengthError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Invalid Content-Length: Expected {} bytes but {} bytes were received",
            self.expected, self.received
        )
    }
}
57
58impl ContentLengthEnforcingBody<SdkBody> {
59    /// Wraps an existing [`SdkBody`] in a content-length enforcement layer
60    fn wrap(body: SdkBody, content_length: u64) -> SdkBody {
61        body.map_preserve_contents(move |b| {
62            SdkBody::from_body_1_x(ContentLengthEnforcingBody {
63                body: b,
64                expected_length: content_length,
65                bytes_received: 0,
66            })
67        })
68    }
69}
70
71impl<
72        E: Into<aws_smithy_types::body::Error>,
73        Data: Buf,
74        InnerBody: http_body_1x::Body<Error = E, Data = Data>,
75    > http_body_1x::Body for ContentLengthEnforcingBody<InnerBody>
76{
77    type Data = Data;
78    type Error = aws_smithy_types::body::Error;
79
80    fn poll_frame(
81        mut self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
84        let this = self.as_mut().project();
85        match ready!(this.body.poll_frame(cx)) {
86            None => {
87                if *this.expected_length == *this.bytes_received {
88                    Poll::Ready(None)
89                } else {
90                    Poll::Ready(Some(Err(ContentLengthError {
91                        expected: *this.expected_length,
92                        received: *this.bytes_received,
93                    }
94                    .into())))
95                }
96            }
97            Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
98            Some(Ok(frame)) => {
99                if let Some(data) = frame.data_ref() {
100                    *this.bytes_received += data.remaining() as u64;
101                }
102                Poll::Ready(Some(Ok(frame)))
103            }
104        }
105    }
106
107    fn is_end_stream(&self) -> bool {
108        self.body.is_end_stream()
109    }
110
111    fn size_hint(&self) -> SizeHint {
112        self.body.size_hint()
113    }
114}
115
/// Interceptor that wraps response bodies of opted-in requests in a
/// content-length-enforcing body.
#[derive(Default, Debug)]
struct EnforceContentLengthInterceptor {}
118
119#[derive(Debug)]
120struct EnableContentLengthEnforcement;
121impl Storable for EnableContentLengthEnforcement {
122    type Storer = StoreReplace<EnableContentLengthEnforcement>;
123}
124
125#[dyn_dispatch_hint]
126impl Intercept for EnforceContentLengthInterceptor {
127    fn name(&self) -> &'static str {
128        "EnforceContentLength"
129    }
130
131    fn read_before_transmit(
132        &self,
133        context: &BeforeTransmitInterceptorContextRef<'_>,
134        _runtime_components: &RuntimeComponents,
135        cfg: &mut ConfigBag,
136    ) -> Result<(), BoxError> {
137        if context.request().method() == "GET" {
138            cfg.interceptor_state()
139                .store_put(EnableContentLengthEnforcement);
140        }
141        Ok(())
142    }
143    fn modify_before_deserialization(
144        &self,
145        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
146        _runtime_components: &RuntimeComponents,
147        cfg: &mut ConfigBag,
148    ) -> Result<(), BoxError> {
149        // if we didn't enable it for this request, bail out
150        if cfg.load::<EnableContentLengthEnforcement>().is_none() {
151            return Ok(());
152        }
153        let content_length = match extract_content_length(context.response()) {
154            Err(err) => {
155                tracing::warn!(err = ?err, "could not parse content length from content-length header. This header will be ignored");
156                return Ok(());
157            }
158            Ok(Some(content_length)) => content_length,
159            Ok(None) => return Ok(()),
160        };
161
162        tracing::trace!(
163            expected_length = content_length,
164            "Wrapping response body in content-length enforcement."
165        );
166
167        let body = context.response_mut().take_body();
168        let wrapped = body.map_preserve_contents(move |body| {
169            ContentLengthEnforcingBody::wrap(body, content_length)
170        });
171        *context.response_mut().body_mut() = wrapped;
172        Ok(())
173    }
174}
175
176fn extract_content_length<B>(response: &Response<B>) -> Result<Option<u64>, BoxError> {
177    let Some(content_length) = response.headers().get("content-length") else {
178        tracing::trace!("No content length header was set. Will not validate content length");
179        return Ok(None);
180    };
181    if response.headers().get_all("content-length").count() != 1 {
182        return Err("Found multiple content length headers. This is invalid".into());
183    }
184
185    Ok(Some(content_length.parse::<u64>()?))
186}
187
/// Runtime plugin that enforces response bodies match their expected content length
///
/// Installing it registers an interceptor that validates the size of `GET` response bodies
/// against their `Content-Length` header.
#[derive(Default, Debug)]
pub struct EnforceContentLengthRuntimePlugin {}

impl EnforceContentLengthRuntimePlugin {
    /// Creates a runtime plugin which installs Content-Length enforcement middleware for
    /// response bodies. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
198
199impl RuntimePlugin for EnforceContentLengthRuntimePlugin {
200    fn runtime_components(
201        &self,
202        _current_components: &RuntimeComponentsBuilder,
203    ) -> Cow<'_, RuntimeComponentsBuilder> {
204        Cow::Owned(
205            RuntimeComponentsBuilder::new("EnforceContentLength").with_interceptor(
206                SharedInterceptor::permanent(EnforceContentLengthInterceptor {}),
207            ),
208        )
209    }
210}
211
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod test {
    use crate::assert_str_contains;
    use crate::client::http::body::content_length_enforcement::{
        extract_content_length, ContentLengthEnforcingBody,
    };
    use aws_smithy_runtime_api::http::Response;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;
    use http_1x::header::CONTENT_LENGTH;
    use http_body_1x::Frame;
    use std::error::Error;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Body for tests so we ensure our code works on a body split across multiple frames:
    /// every input byte is yielded as its own data frame.
    struct ManyFrameBody {
        // Stored reversed so that `Vec::pop` yields bytes in their original order.
        data: Vec<u8>,
    }

    impl ManyFrameBody {
        // Deliberately returns `SdkBody` instead of `Self` so tests can hand the result
        // straight to `ContentLengthEnforcingBody::wrap`.
        #[allow(clippy::new_ret_no_self)]
        fn new(input: impl Into<String>) -> SdkBody {
            let mut data = input.into().as_bytes().to_vec();
            data.reverse();
            SdkBody::from_body_1_x(Self { data })
        }
    }

    impl http_body_1x::Body for ManyFrameBody {
        type Data = Bytes;
        type Error = <SdkBody as http_body_1x::Body>::Error;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            match self.data.pop() {
                Some(next) => Poll::Ready(Some(Ok(Frame::data(Bytes::from(vec![next]))))),
                None => Poll::Ready(None),
            }
        }
    }

    #[tokio::test]
    async fn stream_too_short() {
        let body = ManyFrameBody::new("123");
        let enforced = ContentLengthEnforcingBody::wrap(body, 10);
        let err = expect_body_error(enforced).await;
        assert_str_contains!(
            format!("{}", DisplayErrorContext(err)),
            "Expected 10 bytes but 3 bytes were received"
        );
    }

    #[tokio::test]
    async fn stream_too_long() {
        let body = ManyFrameBody::new("abcdefghijk");
        let enforced = ContentLengthEnforcingBody::wrap(body, 5);
        let err = expect_body_error(enforced).await;
        assert_str_contains!(
            format!("{}", DisplayErrorContext(err)),
            "Expected 5 bytes but 11 bytes were received"
        );
    }

    #[tokio::test]
    async fn stream_just_right() {
        use http_body_util::BodyExt;
        let body = ManyFrameBody::new("abcdefghijk");
        let enforced = ContentLengthEnforcingBody::wrap(body, 11);
        let data = enforced.collect().await.unwrap().to_bytes();
        assert_eq!(b"abcdefghijk", data.as_ref());
    }

    #[tokio::test]
    async fn empty_stream_with_zero_length() {
        use http_body_util::BodyExt;
        let enforced = ContentLengthEnforcingBody::wrap(SdkBody::empty(), 0);
        let data = enforced.collect().await.unwrap().to_bytes();
        assert!(data.is_empty());
    }

    async fn expect_body_error(body: SdkBody) -> impl Error {
        ByteStream::new(body)
            .collect()
            .await
            .expect_err("body should have failed")
    }

    #[test]
    fn extract_header() {
        let mut resp1 = Response::new(200.try_into().unwrap(), ());
        resp1.headers_mut().insert(CONTENT_LENGTH, "123");
        assert_eq!(extract_content_length(&resp1).unwrap(), Some(123));
        resp1.headers_mut().append(CONTENT_LENGTH, "124");
        // duplicate content length header
        extract_content_length(&resp1).expect_err("duplicate headers");

        // not an integer
        resp1.headers_mut().insert(CONTENT_LENGTH, "-123.5");
        extract_content_length(&resp1).expect_err("not an integer");

        // empty value
        resp1.headers_mut().insert(CONTENT_LENGTH, "");
        extract_content_length(&resp1).expect_err("empty");

        resp1.headers_mut().remove(CONTENT_LENGTH);
        assert_eq!(extract_content_length(&resp1).unwrap(), None);
    }
}