use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
use crate::client::result::SdkError;
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::type_erasure::{TypeErasedBox, TypeErasedError};
use phase::Phase;
use std::fmt::Debug;
use std::{fmt, mem};
use tracing::{debug, error, trace};
// Generates a public newtype `$name` wrapping a type-erased container (`TypeErasedBox`
// unless another `$underlying` type is given), plus helpers to erase a concrete value into
// it and to downcast back out. Each `$additional_bound` is an extra trait bound required of
// the concrete type `T`, on top of `Send + Sync + fmt::Debug + 'static`.
macro_rules! new_type_box {
    ($name:ident, $doc:literal) => {
        new_type_box!($name, TypeErasedBox, $doc, Send, Sync, fmt::Debug,);
    };
    ($name:ident, $underlying:ident, $doc:literal, $($additional_bound:path,)*) => {
        #[doc = $doc]
        #[derive(Debug)]
        pub struct $name($underlying);

        impl $name {
            #[doc = concat!("Creates a new `", stringify!($name), "` by type-erasing the provided concrete value.")]
            pub fn erase<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(input: T) -> Self {
                Self($underlying::new(input))
            }

            #[doc = concat!("Attempts to downcast the value inside this `", stringify!($name), "` to a `&T`.")]
            pub fn downcast_ref<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&self) -> Option<&T> {
                self.0.downcast_ref()
            }

            #[doc = concat!("Attempts to downcast the value inside this `", stringify!($name), "` to a `&mut T`.")]
            pub fn downcast_mut<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&mut self) -> Option<&mut T> {
                self.0.downcast_mut()
            }

            #[doc = concat!(
                "Consumes this `", stringify!($name), "` and attempts to downcast it into an owned `T`. ",
                "On failure, the erased value is handed back as a `", stringify!($name), "` in `Err`."
            )]
            pub fn downcast<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(self) -> Result<T, Self> {
                self.0.downcast::<T>().map(|v| *v).map_err(Self)
            }

            #[doc = concat!("Returns a `", stringify!($name), "` with a fake/test value with the expectation that it won't be downcast in the test.")]
            #[cfg(feature = "test-util")]
            pub fn doesnt_matter() -> Self {
                Self($underlying::doesnt_matter())
            }
        }
    };
}
new_type_box!(Input, "Type-erased operation input.");
new_type_box!(Output, "Type-erased operation output.");
new_type_box!(
Error,
TypeErasedError,
"Type-erased operation error.",
std::error::Error,
);
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.0.source()
}
}
pub type OutputOrError = Result<Output, OrchestratorError<Error>>;
type Request = HttpRequest;
type Response = HttpResponse;
pub use wrappers::{
AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut,
BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut,
BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
FinalizerInterceptorContextRef,
};
mod wrappers;
pub(crate) mod phase;
/// State tracked across the phases of a single operation: the input, the HTTP request and
/// response, and the final output or error.
///
/// The `enter_*_phase` methods move the context between phases and (in debug builds)
/// assert which of these values must be present or absent at each transition.
#[derive(Debug)]
pub struct InterceptorContext<I = Input, O = Output, E = Error> {
    /// The operation input; must be taken before entering the 'before transmit' phase.
    pub(crate) input: Option<I>,
    /// The operation's result, once set by deserialization or by a failure.
    pub(crate) output_or_error: Option<Result<O, OrchestratorError<E>>>,
    /// The HTTP request; must be taken before entering the 'before deserialization' phase.
    pub(crate) request: Option<Request>,
    /// The HTTP response, once one has been set.
    pub(crate) response: Option<Response>,
    /// The phase the context is currently in.
    phase: Phase,
    /// Set by the first call to `rewind`; only later calls actually restore the checkpoint.
    tainted: bool,
    /// A clone of the request saved for `rewind`; `None` if no checkpoint was saved or the
    /// request could not be cloned.
    request_checkpoint: Option<HttpRequest>,
}
impl InterceptorContext<Input, Output, Error> {
    /// Creates a context for a new operation invocation.
    ///
    /// The context starts in the 'before serialization' phase and owns `input`. It has no
    /// request, response, result, or request checkpoint yet, and it is not tainted.
    pub fn new(input: Input) -> Self {
        Self {
            phase: Phase::BeforeSerialization,
            input: Some(input),
            request: None,
            request_checkpoint: None,
            response: None,
            output_or_error: None,
            tainted: false,
        }
    }
}
impl<I, O, E> InterceptorContext<I, O, E> {
    /// Returns a reference to the operation input, or `None` if it has been taken.
    pub fn input(&self) -> Option<&I> {
        self.input.as_ref()
    }

    /// Returns a mutable reference to the operation input, or `None` if it has been taken.
    pub fn input_mut(&mut self) -> Option<&mut I> {
        self.input.as_mut()
    }

    /// Takes the operation input out of the context, leaving `None` behind.
    pub fn take_input(&mut self) -> Option<I> {
        self.input.take()
    }

    /// Sets the HTTP request, replacing any request that was already set.
    pub fn set_request(&mut self, request: Request) {
        self.request = Some(request);
    }

    /// Returns a reference to the HTTP request, if one is set.
    pub fn request(&self) -> Option<&Request> {
        self.request.as_ref()
    }

    /// Returns a mutable reference to the HTTP request, if one is set.
    pub fn request_mut(&mut self) -> Option<&mut Request> {
        self.request.as_mut()
    }

    /// Takes the HTTP request out of the context, leaving `None` behind.
    pub fn take_request(&mut self) -> Option<Request> {
        self.request.take()
    }

    /// Sets the HTTP response, replacing any response that was already set.
    pub fn set_response(&mut self, response: Response) {
        self.response = Some(response);
    }

    /// Returns a reference to the HTTP response, if one is set.
    pub fn response(&self) -> Option<&Response> {
        self.response.as_ref()
    }

    /// Returns a mutable reference to the HTTP response, if one is set.
    pub fn response_mut(&mut self) -> Option<&mut Response> {
        self.response.as_mut()
    }

    /// Sets the operation result, replacing any result that was already set.
    pub fn set_output_or_error(&mut self, output: Result<O, OrchestratorError<E>>) {
        self.output_or_error = Some(output);
    }

    /// Returns the operation result by reference, if one is set.
    pub fn output_or_error(&self) -> Option<Result<&O, &OrchestratorError<E>>> {
        self.output_or_error.as_ref().map(Result::as_ref)
    }

    /// Returns a mutable reference to the operation result, if one is set.
    pub fn output_or_error_mut(&mut self) -> Option<&mut Result<O, OrchestratorError<E>>> {
        self.output_or_error.as_mut()
    }

    /// Takes the operation result out of the context, leaving `None` behind.
    pub fn take_output_or_error(&mut self) -> Option<Result<O, OrchestratorError<E>>> {
        self.output_or_error.take()
    }

    /// Returns `true` if the context currently holds an error result.
    pub fn is_failed(&self) -> bool {
        matches!(self.output_or_error, Some(Err(_)))
    }

    /// Transitions from the 'before serialization' phase to the 'serialization' phase.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the current phase is not 'before serialization'.
    pub fn enter_serialization_phase(&mut self) {
        debug!("entering 'serialization' phase");
        debug_assert!(
            self.phase.is_before_serialization(),
            "called enter_serialization_phase but phase is not before 'serialization'"
        );
        self.phase = Phase::Serialization;
    }

    /// Transitions from the 'serialization' phase to the 'before transmit' phase.
    ///
    /// On entry, a clone of the request is saved as the checkpoint that
    /// [`rewind`](Self::rewind) restores. If the request can't be cloned, no checkpoint
    /// is stored.
    ///
    /// # Panics
    ///
    /// Panics if no request has been set. In debug builds, it also panics if the current
    /// phase is not 'serialization' or if the input has not been taken.
    pub fn enter_before_transmit_phase(&mut self) {
        debug!("entering 'before transmit' phase");
        debug_assert!(
            self.phase.is_serialization(),
            "called enter_before_transmit_phase but phase is not 'serialization'"
        );
        debug_assert!(
            self.input.is_none(),
            "input must be taken before calling enter_before_transmit_phase"
        );
        debug_assert!(
            self.request.is_some(),
            "request must be set before calling enter_before_transmit_phase"
        );
        // `debug_assert!` is compiled out of release builds, so this `expect` is the only
        // check there; its message states the invariant instead of "checked above".
        self.request_checkpoint = self
            .request()
            .expect("request must be set before calling enter_before_transmit_phase")
            .try_clone();
        self.phase = Phase::BeforeTransmit;
    }

    /// Transitions from the 'before transmit' phase to the 'transmit' phase.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the current phase is not 'before transmit'.
    pub fn enter_transmit_phase(&mut self) {
        debug!("entering 'transmit' phase");
        debug_assert!(
            self.phase.is_before_transmit(),
            "called enter_transmit_phase but phase is not before transmit"
        );
        self.phase = Phase::Transmit;
    }

    /// Transitions from the 'transmit' phase to the 'before deserialization' phase.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the current phase is not 'transmit', if the request
    /// has not been taken, or if no response has been set.
    pub fn enter_before_deserialization_phase(&mut self) {
        debug!("entering 'before deserialization' phase");
        debug_assert!(
            self.phase.is_transmit(),
            "called enter_before_deserialization_phase but phase is not 'transmit'"
        );
        debug_assert!(
            self.request.is_none(),
            "request must be taken before entering the 'before deserialization' phase"
        );
        debug_assert!(
            self.response.is_some(),
            "response must be set before entering the 'before deserialization' phase"
        );
        self.phase = Phase::BeforeDeserialization;
    }

    /// Transitions from the 'before deserialization' phase to the 'deserialization' phase.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the current phase is not 'before deserialization'.
    pub fn enter_deserialization_phase(&mut self) {
        debug!("entering 'deserialization' phase");
        debug_assert!(
            self.phase.is_before_deserialization(),
            "called enter_deserialization_phase but phase is not 'before deserialization'"
        );
        self.phase = Phase::Deserialization;
    }

    /// Transitions from the 'deserialization' phase to the 'after deserialization' phase.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the current phase is not 'deserialization' or if no
    /// output or error has been set.
    pub fn enter_after_deserialization_phase(&mut self) {
        debug!("entering 'after deserialization' phase");
        debug_assert!(
            self.phase.is_deserialization(),
            "called enter_after_deserialization_phase but phase is not 'deserialization'"
        );
        debug_assert!(
            self.output_or_error.is_some(),
            "output must be set before entering the 'after deserialization' phase"
        );
        self.phase = Phase::AfterDeserialization;
    }

    /// Replaces the rewind checkpoint with a clone of the current request.
    ///
    /// If there is no request, or it can't be cloned, the checkpoint is cleared.
    pub fn save_checkpoint(&mut self) {
        trace!("saving request checkpoint...");
        self.request_checkpoint = self.request().and_then(|r| r.try_clone());
        match self.request_checkpoint.as_ref() {
            Some(_) => trace!("successfully saved request checkpoint"),
            None => trace!("failed to save request checkpoint: request body could not be cloned"),
        }
    }

    /// Restores the context to the saved request checkpoint so the request can be sent again.
    ///
    /// - The first call only marks the context as tainted and returns
    ///   [`RewindResult::Unnecessary`].
    /// - Later calls return [`RewindResult::Impossible`] when no checkpoint was saved or the
    ///   checkpoint can't be cloned. The context is left untouched in that case.
    /// - Otherwise, the phase is reset to 'before transmit' and the request is replaced with a
    ///   clone of the checkpoint. The response and result are cleared, and the call returns
    ///   [`RewindResult::Occurred`].
    ///
    /// `_cfg` is currently unused.
    pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> RewindResult {
        // NOTE(review): the first call presumably precedes the initial transmission (as in the
        // tests), so there is nothing to rewind yet; it only arms later calls.
        if !self.tainted {
            self.tainted = true;
            return RewindResult::Unnecessary;
        }

        // Report a failed clone as `Impossible` instead of panicking. Nothing is mutated
        // until a restorable request is in hand.
        let request = match self
            .request_checkpoint
            .as_ref()
            .and_then(|checkpoint| checkpoint.try_clone())
        {
            Some(request) => request,
            None => return RewindResult::Impossible,
        };

        self.phase = Phase::BeforeTransmit;
        self.request = Some(request);
        self.response = None;
        self.output_or_error = None;
        RewindResult::Occurred
    }
}
impl<I, O, E> InterceptorContext<I, O, E>
where
    E: Debug,
{
    /// Splits the context into its input, result, request, and response.
    ///
    /// The phase, taint flag, and request checkpoint are discarded.
    #[allow(clippy::type_complexity)]
    pub fn into_parts(
        self,
    ) -> (
        Option<I>,
        Option<Result<O, OrchestratorError<E>>>,
        Option<Request>,
        Option<Response>,
    ) {
        let Self {
            input,
            output_or_error,
            request,
            response,
            ..
        } = self;
        (input, output_or_error, request, response)
    }

    /// Consumes the context and converts its stored result into the value returned to callers.
    ///
    /// # Panics
    ///
    /// Panics if no output or error has been set.
    pub fn finalize(mut self) -> Result<O, SdkError<E, HttpResponse>> {
        let stored = self
            .take_output_or_error()
            .expect("output_or_error must always be set before finalize is called.");
        self.finalize_result(stored)
    }

    /// Converts `result` into the caller-facing result.
    ///
    /// Errors are turned into an `SdkError` using the current phase and the response. The
    /// response is taken out of the context in both the success and the error case.
    pub fn finalize_result(
        &mut self,
        result: Result<O, OrchestratorError<E>>,
    ) -> Result<O, SdkError<E, HttpResponse>> {
        let taken_response = self.response.take();
        let current_phase = &self.phase;
        result.map_err(move |orchestrator_error| {
            OrchestratorError::into_sdk_error(orchestrator_error, current_phase, taken_response)
        })
    }

    /// Records `error` as the operation's result.
    ///
    /// A trace event is emitted when the context was not already failed. Any previously
    /// recorded error is discarded and logged at error level.
    pub fn fail(&mut self, error: OrchestratorError<E>) {
        let already_failed = self.is_failed();
        if !already_failed {
            trace!(
                "orchestrator is transitioning to the 'failure' phase from the '{:?}' phase",
                self.phase
            );
        }
        let previous = mem::replace(&mut self.output_or_error, Some(Err(error)));
        if let Some(Err(existing_err)) = previous {
            error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err);
        }
    }
}
/// The outcome of rewinding an interceptor context so that its request can be re-sent.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RewindResult {
    /// The request couldn't be rewound because it wasn't cloneable.
    Impossible,
    /// The request wasn't rewound because it was unnecessary.
    Unnecessary,
    /// The request was rewound successfully.
    Occurred,
}
impl fmt::Display for RewindResult {
    /// Writes a human-readable sentence describing the rewind outcome.
    ///
    /// Formatter flags such as width and fill are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            RewindResult::Impossible => {
                "The request couldn't be rewound because it wasn't cloneable."
            }
            RewindResult::Unnecessary => "The request wasn't rewound because it was unnecessary.",
            RewindResult::Occurred => "The request was rewound successfully.",
        };
        f.write_str(message)
    }
}
// Unit tests for the phase state machine and request rewinding. They require the
// `test-util` feature (for the `doesnt_matter()` placeholder constructors) and the
// `http-02x` feature (to build requests and responses from `http` 0.2 types).
#[cfg(all(test, feature = "test-util", feature = "http-02x"))]
mod tests {
    use super::*;
    use aws_smithy_types::body::SdkBody;
    use http_02x::header::{AUTHORIZATION, CONTENT_LENGTH};
    use http_02x::{HeaderValue, Uri};

    /// Drives a context through every phase of a successful attempt, calling each accessor
    /// along the way, and checks that the output set during deserialization is preserved.
    #[test]
    fn test_success_transitions() {
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());
        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());
        context.input_mut();
        context.enter_serialization_phase();
        // The input must be taken and a request set before entering 'before transmit'.
        let _ = context.take_input();
        context.set_request(HttpRequest::empty());
        context.enter_before_transmit_phase();
        context.request();
        context.request_mut();
        context.enter_transmit_phase();
        // The request must be taken and a response set before 'before deserialization'.
        let _ = context.take_request();
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_deserialization_phase();
        context.response();
        context.response_mut();
        context.enter_deserialization_phase();
        context.response();
        context.response_mut();
        // An output (or error) must be set before entering 'after deserialization'.
        context.set_output_or_error(Ok(output));
        context.enter_after_deserialization_phase();
        context.response();
        context.response_mut();
        let _ = context.output_or_error();
        let _ = context.output_or_error_mut();
        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Runs a failed attempt followed by a rewind and a successful retry. It checks that the
    /// rewind restores the request as checkpointed, discarding later mutations.
    #[test]
    fn test_rewind_for_retry() {
        let mut cfg = ConfigBag::base();
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());
        let error = Error::doesnt_matter();
        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());
        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(
            http_02x::Request::builder()
                .header("test", "the-original-un-mutated-request")
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_transmit_phase();
        // Save the un-mutated request as the rewind checkpoint.
        context.save_checkpoint();
        // The first rewind only marks the context as tainted.
        assert_eq!(context.rewind(&mut cfg), RewindResult::Unnecessary);
        // Mutate the in-flight request after the checkpoint, so the test can later tell
        // whether the rewind restored the checkpointed version.
        context.request_mut().unwrap().headers_mut().remove("test");
        context.request_mut().unwrap().headers_mut().insert(
            "test",
            HeaderValue::from_static("request-modified-after-signing"),
        );
        context.enter_transmit_phase();
        let request = context.take_request().unwrap();
        assert_eq!(
            "request-modified-after-signing",
            request.headers().get("test").unwrap()
        );
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        // Simulate a failed attempt; rewinding now restores the checkpointed request and
        // resets the phase to 'before transmit'.
        context.set_output_or_error(Err(OrchestratorError::operation(error)));
        assert_eq!(context.rewind(&mut cfg), RewindResult::Occurred);
        assert_eq!(
            "the-original-un-mutated-request",
            context.request().unwrap().headers().get("test").unwrap()
        );
        // The retried attempt proceeds from 'before transmit' and succeeds.
        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(
            http_02x::Response::builder()
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Ok(output));
        context.enter_after_deserialization_phase();
        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    /// Checks that `HttpRequest::try_clone` copies the URI, method, headers, and body.
    #[test]
    fn try_clone_clones_all_data() {
        let request: HttpRequest = http_02x::Request::builder()
            .uri(Uri::from_static("https://www.amazon.com"))
            .method("POST")
            .header(CONTENT_LENGTH, 456)
            .header(AUTHORIZATION, "Token: hello")
            .body(SdkBody::from("hello world!"))
            .expect("valid request")
            .try_into()
            .unwrap();
        let cloned = request.try_clone().expect("request is cloneable");
        assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri());
        assert_eq!("POST", cloned.method());
        assert_eq!(2, cloned.headers().len());
        assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap(),);
        assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap());
        assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
    }
}