3 3 | * SPDX-License-Identifier: Apache-2.0
|
4 4 | */
|
5 5 |
|
6 6 | use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
|
7 7 | use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
|
8 8 | use aws_smithy_runtime_api::client::retries::classifiers::{
|
9 9 | ClassifyRetry, RetryAction, RetryClassifierPriority, RetryReason,
|
10 10 | };
|
11 11 | use aws_smithy_types::error::metadata::ProvideErrorMetadata;
|
12 12 | use aws_smithy_types::retry::ErrorKind;
|
13 + | use std::borrow::Cow;
|
13 14 | use std::error::Error as StdError;
|
14 15 | use std::marker::PhantomData;
|
15 16 |
|
/// AWS error codes that represent throttling errors.
///
/// A default-constructed [`AwsErrorCodeClassifier`] reports an error as a throttling
/// error when its error code is exactly equal to one of these entries.
pub const THROTTLING_ERRORS: &[&str] = &[
    // Generic throttling codes.
    "Throttling",
    "ThrottlingException",
    "ThrottledException",
    "RequestThrottledException",
    "TooManyRequestsException",
    // Capacity / limit codes.
    "ProvisionedThroughputExceededException",
    "TransactionInProgressException",
    "RequestLimitExceeded",
    "BandwidthLimitExceeded",
    "LimitExceededException",
    // Additional spellings.
    "RequestThrottled",
    "SlowDown",
    "PriorRequestNotComplete",
    "EC2ThrottledException",
];
|
33 34 |
|
/// AWS error codes that represent transient errors.
///
/// A default-constructed [`AwsErrorCodeClassifier`] reports an error as a transient
/// error when its error code is exactly equal to one of these entries.
pub const TRANSIENT_ERRORS: &[&str] = &[
    "RequestTimeout",
    "RequestTimeoutException",
];
|
36 37 |
|
/// A retry classifier for determining if the response sent by an AWS service requires a retry.
///
/// The classifier downcasts an operation error to `E`, reads its error code, and
/// compares that code against two lists: throttling codes and transient codes.
/// [`AwsErrorCodeClassifier::new`] uses [`THROTTLING_ERRORS`] and [`TRANSIENT_ERRORS`];
/// [`AwsErrorCodeClassifier::builder`] allows supplying different lists.
#[derive(Debug)]
pub struct AwsErrorCodeClassifier<E> {
    /// Codes that are classified as [`ErrorKind::ThrottlingError`].
    throttling_errors: Cow<'static, [&'static str]>,
    /// Codes that are classified as [`ErrorKind::TransientError`].
    transient_errors: Cow<'static, [&'static str]>,
    /// Binds the classifier to the error type `E` without storing a value of it.
    _inner: PhantomData<E>,
}
|
42 45 |
|
46 + | impl<E> Default for AwsErrorCodeClassifier<E> {
|
47 + | fn default() -> Self {
|
48 + | Self {
|
49 + | throttling_errors: THROTTLING_ERRORS.into(),
|
50 + | transient_errors: TRANSIENT_ERRORS.into(),
|
51 + | _inner: PhantomData,
|
52 + | }
|
53 + | }
|
54 + | }
|
55 + |
|
/// Builder for [`AwsErrorCodeClassifier`].
///
/// Any list left unset (`None`) falls back to its default ([`THROTTLING_ERRORS`] or
/// [`TRANSIENT_ERRORS`]) when the classifier is built.
#[derive(Debug)]
pub struct AwsErrorCodeClassifierBuilder<E> {
    /// Replacement throttling error codes, if provided.
    throttling_errors: Option<Cow<'static, [&'static str]>>,
    /// Replacement transient error codes, if provided.
    transient_errors: Option<Cow<'static, [&'static str]>>,
    /// Carries the error type `E` through to the built classifier.
    _inner: PhantomData<E>,
}
|
63 + |
|
64 + | impl<E> AwsErrorCodeClassifierBuilder<E> {
|
65 + | /// Set `transient_errors` for the builder
|
66 + | pub fn transient_errors(
|
67 + | mut self,
|
68 + | transient_errors: impl Into<Cow<'static, [&'static str]>>,
|
69 + | ) -> Self {
|
70 + | self.transient_errors = Some(transient_errors.into());
|
71 + | self
|
72 + | }
|
73 + |
|
74 + | /// Build a new [`AwsErrorCodeClassifier`]
|
75 + | pub fn build(self) -> AwsErrorCodeClassifier<E> {
|
76 + | AwsErrorCodeClassifier {
|
77 + | throttling_errors: self.throttling_errors.unwrap_or(THROTTLING_ERRORS.into()),
|
78 + | transient_errors: self.transient_errors.unwrap_or(TRANSIENT_ERRORS.into()),
|
79 + | _inner: self._inner,
|
80 + | }
|
81 + | }
|
82 + | }
|
83 + |
|
43 84 | impl<E> AwsErrorCodeClassifier<E> {
|
44 - | /// Create a new AwsErrorCodeClassifier
|
85 + | /// Create a new [`AwsErrorCodeClassifier`]
|
45 86 | pub fn new() -> Self {
|
46 - | Self {
|
87 + | Self::default()
|
88 + | }
|
89 + |
|
90 + | /// Return a builder that can create a new [`AwsErrorCodeClassifier`]
|
91 + | pub fn builder() -> AwsErrorCodeClassifierBuilder<E> {
|
92 + | AwsErrorCodeClassifierBuilder {
|
93 + | throttling_errors: None,
|
94 + | transient_errors: None,
|
47 95 | _inner: PhantomData,
|
48 96 | }
|
49 97 | }
|
50 98 | }
|
51 99 |
|
52 100 | impl<E> ClassifyRetry for AwsErrorCodeClassifier<E>
|
53 101 | where
|
54 102 | E: StdError + ProvideErrorMetadata + Send + Sync + 'static,
|
55 103 | {
|
56 104 | fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
|
57 105 | // Check for a result
|
58 106 | let output_or_error = ctx.output_or_error();
|
59 107 | // Check for an error
|
60 108 | let error = match output_or_error {
|
61 109 | Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
|
62 110 | Some(Err(err)) => err,
|
63 111 | };
|
64 112 |
|
65 113 | let retry_after = ctx
|
66 114 | .response()
|
67 115 | .and_then(|res| res.headers().get("x-amz-retry-after"))
|
68 116 | .and_then(|header| header.parse::<u64>().ok())
|
69 117 | .map(std::time::Duration::from_millis);
|
70 118 |
|
71 119 | let error_code = OrchestratorError::as_operation_error(error)
|
72 120 | .and_then(|err| err.downcast_ref::<E>())
|
73 121 | .and_then(|err| err.code());
|
74 122 |
|
75 123 | if let Some(error_code) = error_code {
|
76 - | if THROTTLING_ERRORS.contains(&error_code) {
|
124 + | if self.throttling_errors.contains(&error_code) {
|
77 125 | return RetryAction::RetryIndicated(RetryReason::RetryableError {
|
78 126 | kind: ErrorKind::ThrottlingError,
|
79 127 | retry_after,
|
80 128 | });
|
81 129 | }
|
82 - | if TRANSIENT_ERRORS.contains(&error_code) {
|
130 + | if self.transient_errors.contains(&error_code) {
|
83 131 | return RetryAction::RetryIndicated(RetryReason::RetryableError {
|
84 132 | kind: ErrorKind::TransientError,
|
85 133 | retry_after,
|
86 134 | });
|
87 135 | }
|
88 136 | };
|
89 137 |
|
90 138 | debug_assert!(
|
91 139 | retry_after.is_none(),
|
92 140 | "retry_after should be None if the error wasn't an identifiable AWS error"
|