127 133 | }
|
128 134 |
|
129 135 | fn modify_before_serialization(
|
130 136 | &self,
|
131 137 | context: &mut BeforeSerializationInterceptorContextMut<'_>,
|
132 138 | _runtime_components: &RuntimeComponents,
|
133 139 | cfg: &mut ConfigBag,
|
134 140 | ) -> Result<(), BoxError> {
|
135 141 | let (checksum_algorithm, request_checksum_required) = (self.algorithm_provider)(context.input());
|
136 142 |
|
137 - | let mut layer = Layer::new("RequestChecksumInterceptor");
|
138 - | layer.store_put(RequestChecksumInterceptorState {
|
143 + | cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
|
139 144 | checksum_algorithm,
|
140 145 | request_checksum_required,
|
141 146 | checksum_cache: ChecksumCache::new(),
|
142 147 | calculate_checksum: Arc::new(AtomicBool::new(false)),
|
143 148 | });
|
144 - | cfg.push_layer(layer);
|
145 149 |
|
146 150 | Ok(())
|
147 151 | }
|
148 152 |
|
149 153 | /// Set up state for calculating checksum and setting UA features
|
150 154 | fn modify_before_retry_loop(
|
151 155 | &self,
|
152 156 | context: &mut BeforeTransmitInterceptorContextMut<'_>,
|
153 157 | _runtime_components: &RuntimeComponents,
|
154 158 | cfg: &mut ConfigBag,
|
155 159 | ) -> Result<(), BoxError> {
|
156 - | let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
|
157 - |
|
158 160 | let user_set_checksum_value = (self.checksum_mutator)(context.request_mut(), cfg).expect("Checksum header mutation should not fail");
|
161 + | let is_presigned = cfg.load::<PresigningMarker>().is_some();
|
159 162 |
|
160 - | // If the user manually set a checksum header we short circuit
|
161 - | if user_set_checksum_value {
|
163 + | // If the user manually set a checksum header or if this is a presigned request, we short circuit
|
164 + | if user_set_checksum_value || is_presigned {
|
165 + | // Disable aws-chunked encoding since either the user has set a custom checksum or this is a presigned request
|
166 + | cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
|
162 167 | return Ok(());
|
163 168 | }
|
164 169 |
|
165 - | // This value is from the trait, but is needed for runtime logic
|
166 - | let request_checksum_required = state.request_checksum_required;
|
170 + | let state = cfg
|
171 + | .get_mut_from_interceptor_state::<RequestChecksumInterceptorState>()
|
172 + | .expect("set in `read_before_serialization`");
|
167 173 |
|
168 174 | // If the algorithm fails to parse it is not one we support and we error
|
169 175 | let checksum_algorithm = state
|
170 176 | .checksum_algorithm
|
171 177 | .clone()
|
172 178 | .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
|
173 179 | .transpose()?;
|
174 180 |
|
175 - | // This value is set by the user on the SdkConfig to indicate their preference
|
176 - | // We provide a default here for users that use a client config instead of the SdkConfig
|
177 - | let request_checksum_calculation = cfg
|
178 - | .load::<RequestChecksumCalculation>()
|
179 - | .unwrap_or(&RequestChecksumCalculation::WhenSupported);
|
180 - |
|
181 - | // Need to know if this is a presigned req because we do not calculate checksums for those.
|
182 - | let is_presigned_req = cfg.load::<PresigningMarker>().is_some();
|
181 + | let mut state = std::mem::take(state);
|
183 182 |
|
184 - | // Determine if we actually calculate the checksum. If this is a presigned request we do not
|
185 - | // If the user setting is WhenSupported (the default) we always calculate it (because this interceptor
|
186 - | // isn't added if it isn't supported). If it is WhenRequired we only calculate it if the checksum
|
187 - | // is marked required on the trait.
|
188 - | let calculate_checksum = match (request_checksum_calculation, is_presigned_req) {
|
189 - | (_, true) => false,
|
190 - | (RequestChecksumCalculation::WhenRequired, false) => request_checksum_required,
|
191 - | (RequestChecksumCalculation::WhenSupported, false) => true,
|
192 - | _ => true,
|
193 - | };
|
183 + | if calculate_checksum(cfg, &state) {
|
184 + | state.calculate_checksum.store(true, Ordering::Release);
|
194 185 |
|
195 186 | // If a checksum override is set in the ConfigBag we use that instead (currently only used by S3Express)
|
196 187 | // If we have made it this far without a checksum being set we set the default (currently Crc32)
|
197 188 | let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
|
189 + | state.checksum_algorithm = Some(checksum_algorithm.as_str().to_owned());
|
198 190 |
|
199 - | if calculate_checksum {
|
200 - | state.calculate_checksum.store(true, Ordering::Release);
|
201 - |
|
202 - | // Set the user-agent metric for the selected checksum algorithm
|
203 191 | // NOTE: We have to do this in modify_before_retry_loop since UA interceptor also runs
|
204 192 | // in modify_before_signing but is registered before this interceptor (client level vs operation level).
|
205 - | match checksum_algorithm {
|
206 - | ChecksumAlgorithm::Crc32 => {
|
207 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
|
208 - | }
|
209 - | ChecksumAlgorithm::Crc32c => {
|
210 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
|
211 - | }
|
212 - | ChecksumAlgorithm::Crc64Nvme => {
|
213 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
|
214 - | }
|
215 - | #[allow(deprecated)]
|
216 - | ChecksumAlgorithm::Md5 => {
|
217 - | tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
|
218 - | }
|
219 - | ChecksumAlgorithm::Sha1 => {
|
220 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
|
221 - | }
|
222 - | ChecksumAlgorithm::Sha256 => {
|
223 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
|
224 - | }
|
225 - | unsupported => tracing::warn!(
|
226 - | more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
|
227 - | unsupported = ?unsupported),
|
228 - | }
|
193 + | track_metric_for_selected_checksum_algorithm(cfg, &checksum_algorithm);
|
194 + | } else {
|
195 + | // No checksum calculation needed so disable aws-chunked encoding
|
196 + | cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
|
229 197 | }
|
230 198 |
|
199 + | cfg.interceptor_state().store_put(state);
|
200 + |
|
231 201 | Ok(())
|
232 202 | }
|
233 203 |
|
234 - | /// Calculate a checksum and modify the request to include the checksum as a header
|
235 - | /// (for in-memory request bodies) or a trailer (for streaming request bodies).
|
236 - | /// Streaming bodies must be sized or this will return an error.
|
204 + | /// Calculate a checksum and modify the request to do either of the following:
|
205 + | /// - include the checksum as a header for signing with in-memory request bodies.
|
206 + | /// - include the checksum as a trailer for streaming request bodies.
|
237 207 | fn modify_before_signing(
|
238 208 | &self,
|
239 209 | context: &mut BeforeTransmitInterceptorContextMut<'_>,
|
240 210 | _runtime_components: &RuntimeComponents,
|
241 211 | cfg: &mut ConfigBag,
|
242 212 | ) -> Result<(), BoxError> {
|
243 213 | let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
|
244 214 |
|
245 - | let checksum_cache = state.checksum_cache.clone();
|
215 + | if !state.calculate_checksum() {
|
216 + | return Ok(());
|
217 + | }
|
246 218 |
|
247 - | let checksum_algorithm = state
|
248 - | .checksum_algorithm
|
249 - | .clone()
|
250 - | .map(|s| ChecksumAlgorithm::from_str(s.as_str()))
|
251 - | .transpose()?;
|
219 + | let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
|
220 + | let mut checksum = checksum_algorithm.into_impl();
|
252 221 |
|
253 - | let calculate_checksum = state.calculate_checksum.load(Ordering::SeqCst);
|
222 + | match context.request().body().bytes() {
|
223 + | Some(data) => {
|
224 + | tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
|
225 + | checksum.update(data);
|
254 226 |
|
255 - | // Calculate the checksum if necessary
|
256 - | if calculate_checksum {
|
257 - | // If a checksum override is set in the ConfigBag we use that instead (currently only used by S3Express)
|
258 - | // If we have made it this far without a checksum being set we set the default (currently Crc32)
|
259 - | let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
|
227 + | for (hdr_name, hdr_value) in get_or_cache_headers(checksum.headers(), &state.checksum_cache).iter() {
|
228 + | context.request_mut().headers_mut().insert(hdr_name.clone(), hdr_value.clone());
|
229 + | }
|
230 + | }
|
231 + | None => {
|
232 + | tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
|
233 + | context
|
234 + | .request_mut()
|
235 + | .headers_mut()
|
236 + | .insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
|
260 237 |
|
261 - | let request = context.request_mut();
|
262 - | add_checksum_for_request_body(request, checksum_algorithm, checksum_cache, cfg)?;
|
238 + | // Take checksum header into account for `AwsChunkedBodyOptions`'s trailer length
|
239 + | let trailer_len = HttpChecksum::size(checksum.as_ref());
|
240 + | let chunked_body_options = AwsChunkedBodyOptions::default().with_trailer_len(trailer_len);
|
241 + | cfg.interceptor_state().store_put(chunked_body_options);
|
242 + | }
|
263 243 | }
|
264 244 |
|
265 245 | Ok(())
|
266 246 | }
|
267 247 |
|
268 - | /// Set the user-agent metrics for `RequestChecksumCalculation` here to avoid ownership issues
|
269 - | /// with the mutable borrow of cfg in `modify_before_signing`
|
270 - | fn read_after_serialization(
|
248 + | fn modify_before_transmit(
|
271 249 | &self,
|
272 - | _context: &aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef<'_>,
|
250 + | ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
|
273 251 | _runtime_components: &RuntimeComponents,
|
274 252 | cfg: &mut ConfigBag,
|
275 253 | ) -> Result<(), BoxError> {
|
276 - | let request_checksum_calculation = cfg
|
277 - | .load::<RequestChecksumCalculation>()
|
278 - | .unwrap_or(&RequestChecksumCalculation::WhenSupported);
|
279 - |
|
280 - | match request_checksum_calculation {
|
281 - | RequestChecksumCalculation::WhenSupported => {
|
282 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
|
254 + | if ctx.request().body().bytes().is_some() {
|
255 + | // Nothing to do for non-streaming bodies since the checksum was added to the header
|
256 + | // in `modify_before_signing` and signing has already been done by the time this hook is called.
|
257 + | return Ok(());
|
283 258 | }
|
284 - | RequestChecksumCalculation::WhenRequired => {
|
285 - | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
|
259 + |
|
260 + | let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
|
261 + |
|
262 + | if !state.calculate_checksum() {
|
263 + | return Ok(());
|
286 264 | }
|
287 - | unsupported => tracing::warn!(
|
288 - | more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
|
289 - | unsupported = ?unsupported),
|
265 + |
|
266 + | let request = ctx.request_mut();
|
267 + |
|
268 + | let mut body = {
|
269 + | let body = mem::replace(request.body_mut(), SdkBody::taken());
|
270 + |
|
271 + | let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
|
272 + | let checksum_cache = state.checksum_cache.clone();
|
273 + |
|
274 + | body.map(move |body| {
|
275 + | let checksum = checksum_algorithm.into_impl();
|
276 + | let body = calculate::ChecksumBody::new(body, checksum).with_cache(checksum_cache.clone());
|
277 + |
|
278 + | SdkBody::from_body_0_4(body)
|
279 + | })
|
290 280 | };
|
291 281 |
|
282 + | mem::swap(request.body_mut(), &mut body);
|
283 + |
|
292 284 | Ok(())
|
293 285 | }
|
294 286 | }
|
295 287 |
|
296 288 | fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
|
297 289 | match cfg.load::<DefaultRequestChecksumOverride>() {
|
298 290 | Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
|
299 291 | None => checksum,
|
300 292 | }
|
301 293 | }
|
302 294 |
|
303 - | fn add_checksum_for_request_body(
|
304 - | request: &mut HttpRequest,
|
305 - | checksum_algorithm: ChecksumAlgorithm,
|
306 - | checksum_cache: ChecksumCache,
|
307 - | cfg: &mut ConfigBag,
|
308 - | ) -> Result<(), BoxError> {
|
309 - | match request.body().bytes() {
|
310 - | // Body is in-memory: read it and insert the checksum as a header.
|
311 - | Some(data) => {
|
312 - | let mut checksum = checksum_algorithm.into_impl();
|
313 - |
|
314 - | // If the header has not already been set we set it. If it was already set by the user
|
315 - | // we do nothing and maintain their set value.
|
316 - | if request.headers().get(checksum.header_name()).is_none() {
|
317 - | tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
|
318 - | checksum.update(data);
|
319 - |
|
320 - | let calculated_headers = checksum.headers();
|
321 - | let checksum_headers = if let Some(cached_headers) = checksum_cache.get() {
|
295 + | fn get_or_cache_headers(calculated_headers: HeaderMap, checksum_cache: &ChecksumCache) -> HeaderMap {
|
296 + | if let Some(cached_headers) = checksum_cache.get() {
|
322 297 | if cached_headers != calculated_headers {
|
323 298 | tracing::warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
|
324 299 | }
|
325 300 | cached_headers
|
326 301 | } else {
|
327 302 | checksum_cache.set(calculated_headers.clone());
|
328 303 | calculated_headers
|
329 - | };
|
330 - |
|
331 - | for (hdr_name, hdr_value) in checksum_headers.iter() {
|
332 - | request.headers_mut().insert(hdr_name.clone(), hdr_value.clone());
|
333 304 | }
|
305 + | }
|
306 + |
|
307 + | // Determine if we actually calculate the checksum
|
308 + | fn calculate_checksum(cfg: &mut ConfigBag, state: &RequestChecksumInterceptorState) -> bool {
|
309 + | // This value is set by the user on the SdkConfig to indicate their preference
|
310 + | // We provide a default here for users that use a client config instead of the SdkConfig
|
311 + | let request_checksum_calculation = cfg
|
312 + | .load::<RequestChecksumCalculation>()
|
313 + | .unwrap_or(&RequestChecksumCalculation::WhenSupported);
|
314 + |
|
315 + | // If the user setting is WhenSupported (the default) we always calculate it (because this interceptor
|
316 + | // isn't added if it isn't supported). If it is WhenRequired we only calculate it if the checksum
|
317 + | // is marked required on the trait.
|
318 + | match request_checksum_calculation {
|
319 + | RequestChecksumCalculation::WhenRequired => {
|
320 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenRequired);
|
321 + | state.request_checksum_required
|
334 322 | }
|
323 + | RequestChecksumCalculation::WhenSupported => {
|
324 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqWhenSupported);
|
325 + | true
|
335 326 | }
|
336 - | // Body is streaming: wrap the body so it will emit a checksum as a trailer.
|
337 - | None => {
|
338 - | tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
|
339 - | cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
|
340 - | wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm, checksum_cache.clone())?;
|
327 + | unsupported => {
|
328 + | tracing::warn!(
|
329 + | more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
|
330 + | unsupported = ?unsupported
|
331 + | );
|
332 + | true
|
341 333 | }
|
342 334 | }
|
343 - | Ok(())
|
344 335 | }
|
345 336 |
|
346 - | fn wrap_streaming_request_body_in_checksum_calculating_body(
|
347 - | request: &mut HttpRequest,
|
348 - | checksum_algorithm: ChecksumAlgorithm,
|
349 - | checksum_cache: ChecksumCache,
|
350 - | ) -> Result<(), BuildError> {
|
351 - | let checksum = checksum_algorithm.into_impl();
|
352 - |
|
353 - | // If the user already set the header value then do nothing and return early
|
354 - | if request.headers().get(checksum.header_name()).is_some() {
|
355 - | return Ok(());
|
337 + | // Set the user-agent metric for the selected checksum algorithm
|
338 + | fn track_metric_for_selected_checksum_algorithm(cfg: &mut ConfigBag, checksum_algorithm: &ChecksumAlgorithm) {
|
339 + | match checksum_algorithm {
|
340 + | ChecksumAlgorithm::Crc32 => {
|
341 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32);
|
342 + | }
|
343 + | ChecksumAlgorithm::Crc32c => {
|
344 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc32c);
|
345 + | }
|
346 + | ChecksumAlgorithm::Crc64Nvme => {
|
347 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqCrc64);
|
348 + | }
|
349 + | #[allow(deprecated)]
|
350 + | ChecksumAlgorithm::Md5 => {
|
351 + | tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
|
352 + | }
|
353 + | ChecksumAlgorithm::Sha1 => {
|
354 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha1);
|
355 + | }
|
356 + | ChecksumAlgorithm::Sha256 => {
|
357 + | cfg.interceptor_state().store_append(SmithySdkFeature::FlexibleChecksumsReqSha256);
|
358 + | }
|
359 + | unsupported => tracing::warn!(
|
360 + | more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
|
361 + | unsupported = ?unsupported),
|
356 362 | }
|
357 - |
|
358 - | let original_body_size = request
|
359 - | .body()
|
360 - | .size_hint()
|
361 - | .exact()
|
362 - | .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
|
363 - |
|
364 - | let mut body = {
|
365 - | let body = mem::replace(request.body_mut(), SdkBody::taken());
|
366 - |
|
367 - | body.map(move |body| {
|
368 - | let checksum = checksum_algorithm.into_impl();
|
369 - | let trailer_len = HttpChecksum::size(checksum.as_ref());
|
370 - | let body = calculate::ChecksumBody::new(body, checksum).with_cache(checksum_cache.clone());
|
371 - | let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
|
372 - |
|
373 - | let body = AwsChunkedBody::new(body, aws_chunked_body_options);
|
374 - |
|
375 - | SdkBody::from_body_0_4(body)
|
376 - | })
|
377 - | };
|
378 - |
|
379 - | let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
|
380 - |
|
381 - | let headers = request.headers_mut();
|
382 - |
|
383 - | headers.insert(http::header::HeaderName::from_static("x-amz-trailer"), checksum.header_name());
|
384 - |
|
385 - | headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
|
386 - | headers.insert(
|
387 - | http::header::HeaderName::from_static("x-amz-decoded-content-length"),
|
388 - | HeaderValue::from(original_body_size),
|
389 - | );
|
390 - | // The target service does not depend on where `aws-chunked` appears in the `Content-Encoding` header,
|
391 - | // as it will ultimately be stripped.
|
392 - | headers.append(
|
393 - | http::header::CONTENT_ENCODING,
|
394 - | HeaderValue::from_str(AWS_CHUNKED)
|
395 - | .map_err(BuildError::other)
|
396 - | .expect("\"aws-chunked\" will always be a valid HeaderValue"),
|
397 - | );
|
398 - |
|
399 - | mem::swap(request.body_mut(), &mut body);
|
400 - |
|
401 - | Ok(())
|
402 363 | }
|
403 364 |
|
404 365 | #[cfg(test)]
|
405 366 | mod tests {
|
406 - | use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
|
407 - | use aws_smithy_checksums::body::ChecksumCache;
|
367 + | use super::*;
|
408 368 | use aws_smithy_checksums::ChecksumAlgorithm;
|
369 + | use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, InterceptorContext};
|
409 370 | use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
|
371 + | use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
|
410 372 | use aws_smithy_types::base64;
|
411 - | use aws_smithy_types::body::SdkBody;
|
412 373 | use aws_smithy_types::byte_stream::ByteStream;
|
413 374 | use bytes::BytesMut;
|
414 375 | use http_body::Body;
|
415 376 | use tempfile::NamedTempFile;
|
416 377 |
|
417 - | #[tokio::test]
|
418 - | async fn test_checksum_body_is_retryable() {
|
419 - | let input_text = "Hello world";
|
420 - | let chunk_len_hex = format!("{:X}", input_text.len());
|
421 - | let mut request: HttpRequest = http::Request::builder()
|
422 - | .body(SdkBody::retryable(move || SdkBody::from(input_text)))
|
423 - | .unwrap()
|
424 - | .try_into()
|
425 - | .unwrap();
|
426 - |
|
427 - | // ensure original SdkBody is retryable
|
428 - | assert!(request.body().try_clone().is_some());
|
429 - |
|
430 - | let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
|
431 - | let checksum_cache = ChecksumCache::new();
|
432 - | wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm, checksum_cache).unwrap();
|
433 - |
|
434 - | // ensure wrapped SdkBody is retryable
|
435 - | let mut body = request.body().try_clone().expect("body is retryable");
|
436 - |
|
437 - | let mut body_data = BytesMut::new();
|
438 - | while let Some(data) = body.data().await {
|
439 - | body_data.extend_from_slice(&data.unwrap())
|
378 + | fn create_test_interceptor() -> RequestChecksumInterceptor<
|
379 + | impl Fn(&Input) -> (Option<String>, bool) + Send + Sync,
|
380 + | impl Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
|
381 + | > {
|
382 + | fn algo(_: &Input) -> (Option<String>, bool) {
|
383 + | (Some("crc32".to_string()), false)
|
440 384 | }
|
441 - | let body = std::str::from_utf8(&body_data).unwrap();
|
442 - | assert_eq!(
|
443 - | format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n"),
|
444 - | body
|
445 - | );
|
385 + | fn mutator(_: &mut Request, _: &ConfigBag) -> Result<bool, BoxError> {
|
386 + | Ok(false)
|
387 + | }
|
388 + | RequestChecksumInterceptor::new(algo, mutator)
|
446 389 | }
|
447 390 |
|
448 391 | #[tokio::test]
|
449 - | async fn test_checksum_body_from_file_is_retryable() {
|
392 + | async fn test_checksum_body_is_retryable() {
|
450 393 | use std::io::Write;
|
451 394 | let mut file = NamedTempFile::new().unwrap();
|
452 - | let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();
|
395 + | let algorithm_str = "crc32c";
|
396 + | let checksum_algorithm: ChecksumAlgorithm = algorithm_str.parse().unwrap();
|
453 397 |
|
454 398 | let mut crc32c_checksum = checksum_algorithm.into_impl();
|
455 399 | for i in 0..10000 {
|
456 400 | let line = format!("This is a large file created for testing purposes {}", i);
|
457 401 | file.as_file_mut().write_all(line.as_bytes()).unwrap();
|
458 402 | crc32c_checksum.update(line.as_bytes());
|
459 403 | }
|
460 404 | let crc32c_checksum = crc32c_checksum.finalize();
|
461 405 |
|
462 - | let mut request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());
|
406 + | let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());
|
463 407 |
|
464 408 | // ensure original SdkBody is retryable
|
465 409 | assert!(request.body().try_clone().is_some());
|
466 410 |
|
467 - | let checksum_cache = ChecksumCache::new();
|
468 - | wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm, checksum_cache).unwrap();
|
411 + | let interceptor = create_test_interceptor();
|
412 + | let mut cfg = ConfigBag::base();
|
413 + | cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
|
414 + | checksum_algorithm: Some(algorithm_str.to_string()),
|
415 + | calculate_checksum: Arc::new(AtomicBool::new(true)),
|
416 + | ..Default::default()
|
417 + | });
|
418 + | let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
|
419 + | let mut ctx = InterceptorContext::new(Input::doesnt_matter());
|
420 + | ctx.enter_serialization_phase();
|
421 + | let _ = ctx.take_input();
|
422 + | ctx.set_request(request);
|
423 + | ctx.enter_before_transmit_phase();
|
424 + | let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
|
425 + | interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();
|
469 426 |
|
470 427 | // ensure wrapped SdkBody is retryable
|
471 - | let mut body = request.body().try_clone().expect("body is retryable");
|
428 + | let mut body = ctx.request().body().try_clone().expect("body is retryable");
|
472 429 |
|
473 430 | let mut body_data = BytesMut::new();
|
474 431 | while let Some(data) = body.data().await {
|
475 432 | body_data.extend_from_slice(&data.unwrap())
|
476 433 | }
|
477 - | let body = std::str::from_utf8(&body_data).unwrap();
|
434 + | let body_str = std::str::from_utf8(&body_data).unwrap();
|
435 + | let expected = format!("This is a large file created for testing purposes 9999");
|
436 + | assert!(body_str.ends_with(&expected), "expected '{body_str}' to end with '{expected}'");
|
478 437 | let expected_checksum = base64::encode(&crc32c_checksum);
|
479 - | let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
|
480 - | assert!(body.ends_with(&expected), "expected {body} to end with '{expected}'");
|
438 + | while let Ok(Some(trailer)) = body.trailers().await {
|
439 + | if let Some(header_value) = trailer.get("x-amz-checksum-crc32c") {
|
440 + | let header_value = header_value.to_str().unwrap();
|
441 + | assert_eq!(
|
442 + | header_value, expected_checksum,
|
443 + | "expected checksum '{header_value}' to match '{expected_checksum}'"
|
444 + | );
|
445 + | }
|
446 + | }
|
481 447 | }
|
482 448 | }
|