aws_smithy_runtime_api/client/interceptors/
context.rs1use crate::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
31use crate::client::result::SdkError;
32use aws_smithy_types::config_bag::ConfigBag;
33use aws_smithy_types::type_erasure::{TypeErasedBox, TypeErasedError};
34use phase::Phase;
35use std::fmt;
36use std::fmt::Debug;
37use tracing::{debug, error, trace};
38
/// Defines a `pub` newtype around a type-erased container, plus typed accessors.
///
/// The generated type exposes `erase` (wrap a concrete value), `downcast_ref`,
/// `downcast_mut` and `downcast` (recover the concrete value), and, with the
/// `test-util` feature, `doesnt_matter` (a placeholder value for tests).
macro_rules! new_type_box {
    // Default arm: wrap a `TypeErasedBox`. Every generated method below already requires
    // `T: Send + Sync + fmt::Debug + 'static`, so there are no additional bounds to forward.
    ($name:ident, $doc:literal) => {
        new_type_box!($name, TypeErasedBox, $doc,);
    };
    // General arm: wrap `$underlying`, adding each `$additional_bound` to the bounds that the
    // generated generic methods place on the concrete type `T`.
    ($name:ident, $underlying:ident, $doc:literal, $($additional_bound:path,)*) => {
        #[doc = $doc]
        #[derive(Debug)]
        pub struct $name($underlying);

        impl $name {
            #[doc = concat!("Creates a new `", stringify!($name), "` by type-erasing the provided concrete value.")]
            pub fn erase<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(input: T) -> Self {
                Self($underlying::new(input))
            }

            #[doc = concat!("Returns a reference to the concrete value inside this `", stringify!($name), "`, or `None` if it is not a `T`.")]
            pub fn downcast_ref<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&self) -> Option<&T> {
                self.0.downcast_ref()
            }

            #[doc = concat!("Returns a mutable reference to the concrete value inside this `", stringify!($name), "`, or `None` if it is not a `T`.")]
            pub fn downcast_mut<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(&mut self) -> Option<&mut T> {
                self.0.downcast_mut()
            }

            #[doc = concat!("Consumes this `", stringify!($name), "` and returns the concrete value. If it is not a `T`, returns `Err` wrapping whatever the underlying container's `downcast` handed back.")]
            pub fn downcast<T: $($additional_bound +)* Send + Sync + fmt::Debug + 'static>(self) -> Result<T, Self> {
                self.0.downcast::<T>().map(|v| *v).map_err(Self)
            }

            #[doc = concat!("Returns a `", stringify!($name), "` with a fake/test value with the expectation that it won't be downcast in the test.")]
            #[cfg(feature = "test-util")]
            pub fn doesnt_matter() -> Self {
                Self($underlying::doesnt_matter())
            }
        }
    };
}
77
78new_type_box!(Input, "Type-erased operation input.");
79new_type_box!(Output, "Type-erased operation output.");
80new_type_box!(
81 Error,
82 TypeErasedError,
83 "Type-erased operation error.",
84 std::error::Error,
85);
86
87impl fmt::Display for Error {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 fmt::Display::fmt(&self.0, f)
90 }
91}
92impl std::error::Error for Error {
93 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
94 self.0.source()
95 }
96}
97
98pub type OutputOrError = Result<Output, OrchestratorError<Error>>;
100
101type Request = HttpRequest;
102type Response = HttpResponse;
103
104pub use wrappers::{
105 AfterDeserializationInterceptorContextRef, BeforeDeserializationInterceptorContextMut,
106 BeforeDeserializationInterceptorContextRef, BeforeSerializationInterceptorContextMut,
107 BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
108 BeforeTransmitInterceptorContextRef, FinalizerInterceptorContextMut,
109 FinalizerInterceptorContextRef,
110};
111
112mod wrappers;
113
114pub(crate) mod phase;
116
117#[derive(Debug)]
123pub struct InterceptorContext<I = Input, O = Output, E = Error> {
124 pub(crate) input: Option<I>,
125 pub(crate) output_or_error: Option<Result<O, OrchestratorError<E>>>,
126 pub(crate) request: Option<Request>,
127 pub(crate) response: Option<Response>,
128 phase: Phase,
129 tainted: bool,
130 request_checkpoint: Option<HttpRequest>,
131}
132
133impl InterceptorContext<Input, Output, Error> {
134 pub fn new(input: Input) -> InterceptorContext<Input, Output, Error> {
136 InterceptorContext {
137 input: Some(input),
138 output_or_error: None,
139 request: None,
140 response: None,
141 phase: Phase::BeforeSerialization,
142 tainted: false,
143 request_checkpoint: None,
144 }
145 }
146}
147
148impl<I, O, E> InterceptorContext<I, O, E> {
149 pub fn input(&self) -> Option<&I> {
153 self.input.as_ref()
154 }
155
156 pub fn input_mut(&mut self) -> Option<&mut I> {
160 self.input.as_mut()
161 }
162
163 pub fn take_input(&mut self) -> Option<I> {
167 self.input.take()
168 }
169
170 pub fn set_request(&mut self, request: Request) {
174 self.request = Some(request);
175 }
176
177 pub fn request(&self) -> Option<&Request> {
182 self.request.as_ref()
183 }
184
185 pub fn request_mut(&mut self) -> Option<&mut Request> {
190 self.request.as_mut()
191 }
192
193 pub fn take_request(&mut self) -> Option<Request> {
197 self.request.take()
198 }
199
200 pub fn set_response(&mut self, response: Response) {
204 self.response = Some(response);
205 }
206
207 pub fn response(&self) -> Option<&Response> {
211 self.response.as_ref()
212 }
213
214 pub fn response_mut(&mut self) -> Option<&mut Response> {
218 self.response.as_mut()
219 }
220
221 pub fn set_output_or_error(&mut self, output: Result<O, OrchestratorError<E>>) {
225 self.output_or_error = Some(output);
226 }
227
228 pub fn output_or_error(&self) -> Option<Result<&O, &OrchestratorError<E>>> {
232 self.output_or_error.as_ref().map(Result::as_ref)
233 }
234
235 pub fn output_or_error_mut(&mut self) -> Option<&mut Result<O, OrchestratorError<E>>> {
239 self.output_or_error.as_mut()
240 }
241
242 pub fn take_output_or_error(&mut self) -> Option<Result<O, OrchestratorError<E>>> {
246 self.output_or_error.take()
247 }
248
249 pub fn is_failed(&self) -> bool {
253 self.output_or_error
254 .as_ref()
255 .map(Result::is_err)
256 .unwrap_or_default()
257 }
258
259 pub fn enter_serialization_phase(&mut self) {
263 debug!("entering \'serialization\' phase");
264 debug_assert!(
265 self.phase.is_before_serialization(),
266 "called enter_serialization_phase but phase is not before 'serialization'"
267 );
268 self.phase = Phase::Serialization;
269 }
270
271 pub fn enter_before_transmit_phase(&mut self) {
275 debug!("entering \'before transmit\' phase");
276 debug_assert!(
277 self.phase.is_serialization(),
278 "called enter_before_transmit_phase but phase is not 'serialization'"
279 );
280 debug_assert!(
281 self.input.is_none(),
282 "input must be taken before calling enter_before_transmit_phase"
283 );
284 debug_assert!(
285 self.request.is_some(),
286 "request must be set before calling enter_before_transmit_phase"
287 );
288 self.request_checkpoint = self.request().expect("checked above").try_clone();
289 self.phase = Phase::BeforeTransmit;
290 }
291
292 pub fn enter_transmit_phase(&mut self) {
296 debug!("entering \'transmit\' phase");
297 debug_assert!(
298 self.phase.is_before_transmit(),
299 "called enter_transmit_phase but phase is not before transmit"
300 );
301 self.phase = Phase::Transmit;
302 }
303
304 pub fn enter_before_deserialization_phase(&mut self) {
308 debug!("entering \'before deserialization\' phase");
309 debug_assert!(
310 self.phase.is_transmit(),
311 "called enter_before_deserialization_phase but phase is not 'transmit'"
312 );
313 debug_assert!(
314 self.request.is_none(),
315 "request must be taken before entering the 'before deserialization' phase"
316 );
317 debug_assert!(
318 self.response.is_some(),
319 "response must be set to before entering the 'before deserialization' phase"
320 );
321 self.phase = Phase::BeforeDeserialization;
322 }
323
324 pub fn enter_deserialization_phase(&mut self) {
328 debug!("entering \'deserialization\' phase");
329 debug_assert!(
330 self.phase.is_before_deserialization(),
331 "called enter_deserialization_phase but phase is not 'before deserialization'"
332 );
333 self.phase = Phase::Deserialization;
334 }
335
336 pub fn enter_after_deserialization_phase(&mut self) {
340 debug!("entering \'after deserialization\' phase");
341 debug_assert!(
342 self.phase.is_deserialization(),
343 "called enter_after_deserialization_phase but phase is not 'deserialization'"
344 );
345 debug_assert!(
346 self.output_or_error.is_some(),
347 "output must be set to before entering the 'after deserialization' phase"
348 );
349 self.phase = Phase::AfterDeserialization;
350 }
351
352 pub fn save_checkpoint(&mut self) {
356 trace!("saving request checkpoint...");
357 self.request_checkpoint = self.request().and_then(|r| r.try_clone());
358 match self.request_checkpoint.as_ref() {
359 Some(_) => trace!("successfully saved request checkpoint"),
360 None => trace!("failed to save request checkpoint: request body could not be cloned"),
361 }
362 }
363
364 pub fn rewind(&mut self, _cfg: &mut ConfigBag) -> RewindResult {
368 let request_checkpoint = match (self.request_checkpoint.as_ref(), self.tainted) {
371 (None, true) => return RewindResult::Impossible,
372 (_, false) => {
373 self.tainted = true;
374 return RewindResult::Unnecessary;
375 }
376 (Some(req), _) => req.try_clone(),
377 };
378
379 self.phase = Phase::BeforeTransmit;
381 self.request = request_checkpoint;
382 assert!(
383 self.request.is_some(),
384 "if the request wasn't cloneable, then we should have already returned from this method."
385 );
386 self.response = None;
387 self.output_or_error = None;
388 RewindResult::Occurred
389 }
390}
391
392impl<I, O, E> InterceptorContext<I, O, E>
393where
394 E: Debug,
395{
396 #[allow(clippy::type_complexity)]
400 pub fn into_parts(
401 self,
402 ) -> (
403 Option<I>,
404 Option<Result<O, OrchestratorError<E>>>,
405 Option<Request>,
406 Option<Response>,
407 ) {
408 (
409 self.input,
410 self.output_or_error,
411 self.request,
412 self.response,
413 )
414 }
415
416 #[allow(clippy::result_large_err)]
420 pub fn finalize(mut self) -> Result<O, SdkError<E, HttpResponse>> {
421 let output_or_error = self
422 .output_or_error
423 .take()
424 .expect("output_or_error must always be set before finalize is called.");
425 self.finalize_result(output_or_error)
426 }
427
428 #[allow(clippy::result_large_err)]
432 pub fn finalize_result(
433 &mut self,
434 result: Result<O, OrchestratorError<E>>,
435 ) -> Result<O, SdkError<E, HttpResponse>> {
436 let response = self.response.take();
437 result.map_err(|error| OrchestratorError::into_sdk_error(error, &self.phase, response))
438 }
439
440 pub fn fail(&mut self, error: OrchestratorError<E>) {
445 if !self.is_failed() {
446 trace!(
447 "orchestrator is transitioning to the 'failure' phase from the '{:?}' phase",
448 self.phase
449 );
450 }
451 if let Some(Err(existing_err)) = self.output_or_error.replace(Err(error)) {
452 error!("orchestrator context received an error but one was already present; Throwing away previous error: {:?}", existing_err);
453 }
454 }
455}
456
/// The outcome of `InterceptorContext::rewind`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RewindResult {
    /// The request couldn't be restored because no cloneable checkpoint was available.
    Impossible,
    /// No rewind was needed, so nothing was restored.
    Unnecessary,
    /// The request checkpoint was restored and the context was reset for another attempt.
    Occurred,
}
470
471impl fmt::Display for RewindResult {
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 match self {
474 RewindResult::Impossible => write!(
475 f,
476 "The request couldn't be rewound because it wasn't cloneable."
477 ),
478 RewindResult::Unnecessary => {
479 write!(f, "The request wasn't rewound because it was unnecessary.")
480 }
481 RewindResult::Occurred => write!(f, "The request was rewound successfully."),
482 }
483 }
484}
485
#[cfg(all(test, feature = "test-util", feature = "http-02x"))]
mod tests {
    use super::*;
    use aws_smithy_types::body::SdkBody;
    use http_02x::header::{AUTHORIZATION, CONTENT_LENGTH};
    use http_02x::{HeaderValue, Uri};

    /// Builds a response with an empty body; the status is left at the `http` builder's default.
    fn empty_response() -> HttpResponse {
        http_02x::Response::builder()
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    #[test]
    fn test_success_transitions() {
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());

        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());
        context.input_mut();

        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(HttpRequest::empty());

        context.enter_before_transmit_phase();
        context.request();
        context.request_mut();

        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.response();
        context.response_mut();

        context.enter_deserialization_phase();
        context.response();
        context.response_mut();
        context.set_output_or_error(Ok(output));

        context.enter_after_deserialization_phase();
        context.response();
        context.response_mut();
        let _ = context.output_or_error();
        let _ = context.output_or_error_mut();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    #[test]
    fn test_rewind_for_retry() {
        let mut cfg = ConfigBag::base();
        let input = Input::doesnt_matter();
        let output = Output::erase("output".to_string());
        let error = Error::doesnt_matter();

        let mut context = InterceptorContext::new(input);
        assert!(context.input().is_some());

        context.enter_serialization_phase();
        let _ = context.take_input();
        context.set_request(
            http_02x::Request::builder()
                .header("test", "the-original-un-mutated-request")
                .body(SdkBody::empty())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_transmit_phase();
        context.save_checkpoint();
        assert_eq!(context.rewind(&mut cfg), RewindResult::Unnecessary);
        // Modify the request after the checkpoint was saved (as signing would).
        context.request_mut().unwrap().headers_mut().remove("test");
        context.request_mut().unwrap().headers_mut().insert(
            "test",
            HeaderValue::from_static("request-modified-after-signing"),
        );

        context.enter_transmit_phase();
        let request = context.take_request().unwrap();
        assert_eq!(
            "request-modified-after-signing",
            request.headers().get("test").unwrap()
        );
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Err(OrchestratorError::operation(error)));

        assert_eq!(context.rewind(&mut cfg), RewindResult::Occurred);

        // Rewinding restores the request exactly as it was when the checkpoint was saved.
        assert_eq!(
            "the-original-un-mutated-request",
            context.request().unwrap().headers().get("test").unwrap()
        );

        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Ok(output));

        context.enter_after_deserialization_phase();

        let output = context.output_or_error.unwrap().expect("success");
        assert_eq!("output", output.downcast_ref::<String>().unwrap());
    }

    #[test]
    fn test_rewind_impossible_without_checkpoint() {
        let mut cfg = ConfigBag::base();
        let mut context = InterceptorContext::new(Input::doesnt_matter());

        context.enter_serialization_phase();
        let _ = context.take_input();
        // A "taken" body can't be cloned, so no checkpoint can be saved for this request.
        context.set_request(
            http_02x::Request::builder()
                .body(SdkBody::taken())
                .unwrap()
                .try_into()
                .unwrap(),
        );
        context.enter_before_transmit_phase();
        context.save_checkpoint();
        assert_eq!(context.rewind(&mut cfg), RewindResult::Unnecessary);

        context.enter_transmit_phase();
        let _ = context.take_request();
        context.set_response(empty_response());

        context.enter_before_deserialization_phase();
        context.enter_deserialization_phase();
        context.set_output_or_error(Err(OrchestratorError::operation(Error::doesnt_matter())));

        assert_eq!(context.rewind(&mut cfg), RewindResult::Impossible);
        // A failed rewind leaves the context untouched.
        assert!(context.request().is_none());
        assert!(context.is_failed());
    }

    #[test]
    fn try_clone_clones_all_data() {
        let request: HttpRequest = http_02x::Request::builder()
            .uri(Uri::from_static("https://www.amazon.com"))
            .method("POST")
            .header(CONTENT_LENGTH, 456)
            .header(AUTHORIZATION, "Token: hello")
            .body(SdkBody::from("hello world!"))
            .expect("valid request")
            .try_into()
            .unwrap();
        let cloned = request.try_clone().expect("request is cloneable");

        assert_eq!(&Uri::from_static("https://www.amazon.com"), cloned.uri());
        assert_eq!("POST", cloned.method());
        assert_eq!(2, cloned.headers().len());
        assert_eq!("Token: hello", cloned.headers().get(AUTHORIZATION).unwrap());
        assert_eq!("456", cloned.headers().get(CONTENT_LENGTH).unwrap());
        assert_eq!("hello world!".as_bytes(), cloned.body().bytes().unwrap());
    }
}