368 368 | use aws_smithy_checksums::ChecksumAlgorithm;
|
369 369 | use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, InterceptorContext};
|
370 370 | use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
|
371 371 | use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
|
372 372 | use aws_smithy_types::base64;
|
373 373 | use aws_smithy_types::byte_stream::ByteStream;
|
374 374 | use bytes::BytesMut;
|
375 375 | use http_body::Body;
|
376 376 | use tempfile::NamedTempFile;
|
377 377 |
|
378 + | #[allow(clippy::type_complexity)]
|
378 379 | fn create_test_interceptor() -> RequestChecksumInterceptor<
|
379 380 | impl Fn(&Input) -> (Option<String>, bool) + Send + Sync,
|
380 381 | impl Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
|
381 382 | > {
|
382 383 | fn algo(_: &Input) -> (Option<String>, bool) {
|
383 384 | (Some("crc32".to_string()), false)
|
384 385 | }
|
385 386 | fn mutator(_: &mut Request, _: &ConfigBag) -> Result<bool, BoxError> {
|
386 387 | Ok(false)
|
387 388 | }
|
388 389 | RequestChecksumInterceptor::new(algo, mutator)
|
389 390 | }
|
390 391 |
|
391 392 | #[tokio::test]
|
392 393 | async fn test_checksum_body_is_retryable() {
|
393 394 | use std::io::Write;
|
394 395 | let mut file = NamedTempFile::new().unwrap();
|
395 396 | let algorithm_str = "crc32c";
|
396 397 | let checksum_algorithm: ChecksumAlgorithm = algorithm_str.parse().unwrap();
|
397 398 |
|
398 399 | let mut crc32c_checksum = checksum_algorithm.into_impl();
|
399 400 | for i in 0..10000 {
|
400 - | let line = format!("This is a large file created for testing purposes {}", i);
|
401 + | let line = format!("This is a large file created for testing purposes {i}");
|
401 402 | file.as_file_mut().write_all(line.as_bytes()).unwrap();
|
402 403 | crc32c_checksum.update(line.as_bytes());
|
403 404 | }
|
404 405 | let crc32c_checksum = crc32c_checksum.finalize();
|
405 406 |
|
406 407 | let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());
|
407 408 |
|
408 409 | // ensure original SdkBody is retryable
|
409 410 | assert!(request.body().try_clone().is_some());
|
410 411 |
|
411 412 | let interceptor = create_test_interceptor();
|
412 413 | let mut cfg = ConfigBag::base();
|
413 414 | cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
|
414 415 | checksum_algorithm: Some(algorithm_str.to_string()),
|
415 416 | calculate_checksum: Arc::new(AtomicBool::new(true)),
|
416 417 | ..Default::default()
|
417 418 | });
|
418 419 | let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
|
419 420 | let mut ctx = InterceptorContext::new(Input::doesnt_matter());
|
420 421 | ctx.enter_serialization_phase();
|
421 422 | let _ = ctx.take_input();
|
422 423 | ctx.set_request(request);
|
423 424 | ctx.enter_before_transmit_phase();
|
424 425 | let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
|
425 426 | interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();
|
426 427 |
|
427 428 | // ensure wrapped SdkBody is retryable
|
428 429 | let mut body = ctx.request().body().try_clone().expect("body is retryable");
|
429 430 |
|
430 431 | let mut body_data = BytesMut::new();
|
431 432 | while let Some(data) = body.data().await {
|
432 433 | body_data.extend_from_slice(&data.unwrap())
|
433 434 | }
|
434 435 | let body_str = std::str::from_utf8(&body_data).unwrap();
|
435 - | let expected = format!("This is a large file created for testing purposes 9999");
|
436 + | let expected = "This is a large file created for testing purposes 9999".to_string();
|
436 437 | assert!(body_str.ends_with(&expected), "expected '{body_str}' to end with '{expected}'");
|
437 438 | let expected_checksum = base64::encode(&crc32c_checksum);
|
438 439 | while let Ok(Some(trailer)) = body.trailers().await {
|
439 440 | if let Some(header_value) = trailer.get("x-amz-checksum-crc32c") {
|
440 441 | let header_value = header_value.to_str().unwrap();
|
441 442 | assert_eq!(
|
442 443 | header_value, expected_checksum,
|
443 444 | "expected checksum '{header_value}' to match '{expected_checksum}'"
|
444 445 | );
|
445 446 | }
|