aws_smithy_runtime/client/http/body/
content_length_enforcement.rs1use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::{
10 BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextRef,
11};
12use aws_smithy_runtime_api::client::interceptors::SharedInterceptor;
13use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
14use aws_smithy_runtime_api::client::runtime_components::{
15 RuntimeComponents, RuntimeComponentsBuilder,
16};
17use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
18use aws_smithy_runtime_api::http::Response;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
21use bytes::Buf;
22use http_body_1x::{Frame, SizeHint};
23use pin_project_lite::pin_project;
24use std::borrow::Cow;
25use std::error::Error;
26use std::fmt::{Display, Formatter};
27use std::pin::Pin;
28use std::task::{ready, Context, Poll};
pin_project! {
    /// Body wrapper that counts the bytes yielded by an inner body so the total can be
    /// compared against an expected length once the inner body is exhausted.
    struct ContentLengthEnforcingBody<InnerBody> {
        // The wrapped body. Pinned because it is polled in place through the projection.
        #[pin]
        body: InnerBody,
        // Total number of bytes the wrapped body is expected to produce.
        expected_length: u64,
        // Running total of data-frame bytes observed so far.
        bytes_received: u64,
    }
}
38
/// Error reported when the number of bytes read from a body differs from the expected length.
#[derive(Debug)]
pub struct ContentLengthError {
    /// Number of bytes that were expected.
    expected: u64,
    /// Number of bytes that were actually received.
    received: u64,
}

impl Error for ContentLengthError {}

impl Display for ContentLengthError {
    /// Renders a message stating both the expected and the received byte counts.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { expected, received } = self;
        f.write_fmt(format_args!(
            "Invalid Content-Length: Expected {} bytes but {} bytes were received",
            expected, received
        ))
    }
}
57
58impl ContentLengthEnforcingBody<SdkBody> {
59 fn wrap(body: SdkBody, content_length: u64) -> SdkBody {
61 body.map_preserve_contents(move |b| {
62 SdkBody::from_body_1_x(ContentLengthEnforcingBody {
63 body: b,
64 expected_length: content_length,
65 bytes_received: 0,
66 })
67 })
68 }
69}
70
71impl<
72 E: Into<aws_smithy_types::body::Error>,
73 Data: Buf,
74 InnerBody: http_body_1x::Body<Error = E, Data = Data>,
75 > http_body_1x::Body for ContentLengthEnforcingBody<InnerBody>
76{
77 type Data = Data;
78 type Error = aws_smithy_types::body::Error;
79
80 fn poll_frame(
81 mut self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
84 let this = self.as_mut().project();
85 match ready!(this.body.poll_frame(cx)) {
86 None => {
87 if *this.expected_length == *this.bytes_received {
88 Poll::Ready(None)
89 } else {
90 Poll::Ready(Some(Err(ContentLengthError {
91 expected: *this.expected_length,
92 received: *this.bytes_received,
93 }
94 .into())))
95 }
96 }
97 Some(Err(e)) => Poll::Ready(Some(Err(e.into()))),
98 Some(Ok(frame)) => {
99 if let Some(data) = frame.data_ref() {
100 *this.bytes_received += data.remaining() as u64;
101 }
102 Poll::Ready(Some(Ok(frame)))
103 }
104 }
105 }
106
107 fn is_end_stream(&self) -> bool {
108 self.body.is_end_stream()
109 }
110
111 fn size_hint(&self) -> SizeHint {
112 self.body.size_hint()
113 }
114}
115
/// Interceptor type that carries no state of its own; its behavior lives entirely in its
/// `Intercept` implementation.
#[derive(Debug, Default)]
struct EnforceContentLengthInterceptor {}
118
119#[derive(Debug)]
120struct EnableContentLengthEnforcement;
121impl Storable for EnableContentLengthEnforcement {
122 type Storer = StoreReplace<EnableContentLengthEnforcement>;
123}
124
125#[dyn_dispatch_hint]
126impl Intercept for EnforceContentLengthInterceptor {
127 fn name(&self) -> &'static str {
128 "EnforceContentLength"
129 }
130
131 fn read_before_transmit(
132 &self,
133 context: &BeforeTransmitInterceptorContextRef<'_>,
134 _runtime_components: &RuntimeComponents,
135 cfg: &mut ConfigBag,
136 ) -> Result<(), BoxError> {
137 if context.request().method() == "GET" {
138 cfg.interceptor_state()
139 .store_put(EnableContentLengthEnforcement);
140 }
141 Ok(())
142 }
143 fn modify_before_deserialization(
144 &self,
145 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
146 _runtime_components: &RuntimeComponents,
147 cfg: &mut ConfigBag,
148 ) -> Result<(), BoxError> {
149 if cfg.load::<EnableContentLengthEnforcement>().is_none() {
151 return Ok(());
152 }
153 let content_length = match extract_content_length(context.response()) {
154 Err(err) => {
155 tracing::warn!(err = ?err, "could not parse content length from content-length header. This header will be ignored");
156 return Ok(());
157 }
158 Ok(Some(content_length)) => content_length,
159 Ok(None) => return Ok(()),
160 };
161
162 tracing::trace!(
163 expected_length = content_length,
164 "Wrapping response body in content-length enforcement."
165 );
166
167 let body = context.response_mut().take_body();
168 let wrapped = body.map_preserve_contents(move |body| {
169 ContentLengthEnforcingBody::wrap(body, content_length)
170 });
171 *context.response_mut().body_mut() = wrapped;
172 Ok(())
173 }
174}
175
176fn extract_content_length<B>(response: &Response<B>) -> Result<Option<u64>, BoxError> {
177 let Some(content_length) = response.headers().get("content-length") else {
178 tracing::trace!("No content length header was set. Will not validate content length");
179 return Ok(None);
180 };
181 if response.headers().get_all("content-length").count() != 1 {
182 return Err("Found multiple content length headers. This is invalid".into());
183 }
184
185 Ok(Some(content_length.parse::<u64>()?))
186}
187
/// Runtime plugin type with no configuration of its own.
#[derive(Debug, Default)]
pub struct EnforceContentLengthRuntimePlugin {}

impl EnforceContentLengthRuntimePlugin {
    /// Creates a new plugin instance; equivalent to the `Default` value.
    pub fn new() -> Self {
        Self::default()
    }
}
198
199impl RuntimePlugin for EnforceContentLengthRuntimePlugin {
200 fn runtime_components(
201 &self,
202 _current_components: &RuntimeComponentsBuilder,
203 ) -> Cow<'_, RuntimeComponentsBuilder> {
204 Cow::Owned(
205 RuntimeComponentsBuilder::new("EnforceContentLength").with_interceptor(
206 SharedInterceptor::permanent(EnforceContentLengthInterceptor {}),
207 ),
208 )
209 }
210}
211
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod test {
    use super::{extract_content_length, ContentLengthEnforcingBody};
    use crate::assert_str_contains;
    use aws_smithy_runtime_api::http::Response;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use bytes::Bytes;
    use http_1x::header::CONTENT_LENGTH;
    use http_body_1x::Frame;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Test body that yields its payload one byte per data frame.
    struct ManyFrameBody {
        // Remaining bytes, stored reversed so the next byte can be popped off the end.
        data: Vec<u8>,
    }

    impl ManyFrameBody {
        // `new` intentionally returns the type-erased `SdkBody` that the tests consume.
        #[allow(clippy::new_ret_no_self)]
        fn new(input: impl Into<String>) -> SdkBody {
            let data: Vec<u8> = input.into().into_bytes().into_iter().rev().collect();
            SdkBody::from_body_1_x(Self { data })
        }
    }

    impl http_body_1x::Body for ManyFrameBody {
        type Data = Bytes;
        type Error = <SdkBody as http_body_1x::Body>::Error;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // Always ready: one single-byte frame per poll until the data runs out.
            let next_frame = self
                .data
                .pop()
                .map(|byte| Ok(Frame::data(Bytes::from(vec![byte]))));
            Poll::Ready(next_frame)
        }
    }

    /// Drains `body`, which is expected to fail, and returns the rendered error chain.
    async fn body_error_message(body: SdkBody) -> String {
        let err = ByteStream::new(body)
            .collect()
            .await
            .expect_err("body should have failed");
        DisplayErrorContext(err).to_string()
    }

    #[tokio::test]
    async fn stream_too_short() {
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("123"), 10);
        let message = body_error_message(enforced).await;
        assert_str_contains!(message, "Expected 10 bytes but 3 bytes were received");
    }

    #[tokio::test]
    async fn stream_too_long() {
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("abcdefghijk"), 5);
        let message = body_error_message(enforced).await;
        assert_str_contains!(message, "Expected 5 bytes but 11 bytes were received");
    }

    #[tokio::test]
    async fn stream_just_right() {
        use http_body_util::BodyExt;
        let enforced = ContentLengthEnforcingBody::wrap(ManyFrameBody::new("abcdefghijk"), 11);
        let collected = enforced.collect().await.unwrap();
        let data = collected.to_bytes();
        assert_eq!(b"abcdefghijk", &data[..]);
    }

    #[test]
    fn extract_header() {
        let mut response = Response::new(200.try_into().unwrap(), ());

        // A single well-formed header parses.
        response.headers_mut().insert(CONTENT_LENGTH, "123");
        assert_eq!(extract_content_length(&response).unwrap(), Some(123));

        // A second value for the same header is rejected.
        response.headers_mut().append(CONTENT_LENGTH, "124");
        extract_content_length(&response).expect_err("duplicate headers");

        // `insert` replaces the existing values, so each case below sees exactly one header.
        for (value, reason) in [("-123.5", "not an integer"), ("", "empty")] {
            response.headers_mut().insert(CONTENT_LENGTH, value);
            extract_content_length(&response).expect_err(reason);
        }

        // With the header removed there is nothing to validate.
        response.headers_mut().remove(CONTENT_LENGTH);
        assert_eq!(extract_content_length(&response).unwrap(), None);
    }
}