1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use crate::client::retries::classifiers::run_classifiers_on_ctx;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
use aws_smithy_runtime_api::client::interceptors::context::{
    AfterDeserializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::retries::classifiers::RetryAction;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::retry::{ReconnectMode, RetryConfig};
use std::fmt;
use std::sync::{Arc, Mutex};
use tracing::{debug, error};

/// An interceptor for poisoning connections in response to certain events.
///
/// When paired with a compatible connection, this interceptor poisons the connection that served
/// a request whenever that request ends in a transient error. Poisoning lets users avoid sending
/// further requests to a server that isn't responding. The trade-off is that it can increase the
/// load on a server, because more connections will be made overall.
///
/// How it works, as implemented by this type's [`Intercept`] impl:
///
/// 1. Before transmit, a fresh [`CaptureSmithyConnection`] is added to the HTTP request's
///    extensions and a clone of it is stored in the config bag.
/// 2. After deserialization, the retry classifiers are run. If they report a transient error and
///    the reconnect mode is [`ReconnectMode::ReconnectOnTransientError`], the captured connection
///    is retrieved and its `.poison` method is called.
///
/// **In order for this interceptor to work,** the configured connection must interact with the
/// "connection retriever" stored in an HTTP request's `extensions` map. For an example of this,
/// see [`HyperConnector`]. When a connection is made available to the retriever, this interceptor
/// will call a `.poison` method on it, signalling that the connection should be dropped. It is
/// up to the connection implementer to handle this.
///
/// [`HyperConnector`]: https://github.com/smithy-lang/smithy-rs/blob/26a914ece072bba2dd9b5b49003204b70e7666ac/rust-runtime/aws-smithy-runtime/src/client/http/hyper_014.rs#L347
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct ConnectionPoisoningInterceptor {}

impl ConnectionPoisoningInterceptor {
    /// Create a new `ConnectionPoisoningInterceptor`.
    ///
    /// The interceptor carries no state of its own, so this is equivalent to
    /// `ConnectionPoisoningInterceptor::default()`.
    pub fn new() -> Self {
        Self {}
    }
}

impl Intercept for ConnectionPoisoningInterceptor {
    /// The name this interceptor reports for itself.
    fn name(&self) -> &'static str {
        "ConnectionPoisoningInterceptor"
    }

    /// Installs a fresh [`CaptureSmithyConnection`] for the request about to be sent.
    ///
    /// One handle goes into the request's extensions, where the HTTP connector can find it and
    /// register a connection retriever. A clone of that handle goes into the config bag's
    /// interceptor state so that `read_after_deserialization` can look it up later.
    fn modify_before_transmit(
        &self,
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let capture = CaptureSmithyConnection::new();
        let request_side_handle = capture.clone();

        context.request_mut().add_extension(request_side_handle);
        cfg.interceptor_state().store_put(capture);

        Ok(())
    }

    /// Poisons the captured connection when the response is classified as a transient error.
    ///
    /// Poisoning only happens when the reconnect mode is
    /// [`ReconnectMode::ReconnectOnTransientError`]; that mode is also assumed when no
    /// [`RetryConfig`] is present in the config bag. If no connection can be retrieved, an error
    /// is logged and the call still succeeds.
    fn read_after_deserialization(
        &self,
        context: &AfterDeserializationInterceptorContextRef<'_>,
        runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        // A missing retry config falls back to reconnecting on transient errors.
        let mode = match cfg.load::<RetryConfig>() {
            Some(retry_config) => retry_config.reconnect_mode(),
            None => ReconnectMode::ReconnectOnTransientError,
        };
        let connection_handle = cfg.load::<CaptureSmithyConnection>().cloned();
        let classification =
            run_classifiers_on_ctx(runtime_components.retry_classifiers(), context.inner());

        let is_transient = classification == RetryAction::transient_error();
        let poisoning_enabled = mode == ReconnectMode::ReconnectOnTransientError;
        if !(is_transient && poisoning_enabled) {
            return Ok(());
        }

        debug!("received a transient error, marking the connection for closure...");

        // Both the handle and the connection behind it are optional: the handle is absent if it
        // was never stored in the bag, and the connector may never have registered a retriever.
        match connection_handle.and_then(|handle| handle.get()) {
            Some(connection) => {
                connection.poison();
                debug!("the connection was marked for closure");
            }
            None => {
                error!(
                    "unable to mark the connection for closure because no connection was found! The underlying HTTP connector never set a connection."
                );
            }
        }

        Ok(())
    }
}

/// Type-erased callback that produces [`ConnectionMetadata`] on demand, or `None` when it has
/// none to offer. The `Send + Sync` bounds let it be stored behind a shared, lockable handle.
type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;

/// State for a middleware that will monitor and manage connections.
///
/// This is a cheaply cloneable handle: the loader lives behind an `Arc`, so every clone observes
/// a retriever installed through any other clone.
#[derive(Clone, Default)]
pub struct CaptureSmithyConnection {
    // Shared slot for the connection retriever; `None` until `set_connection_retriever` is called.
    loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
}

impl CaptureSmithyConnection {
    /// Create a new connection monitor.
    pub fn new() -> Self {
        Self {
            loader: Default::default(),
        }
    }

    /// Set the retriever that will capture the `hyper` connection.
    pub fn set_connection_retriever<F>(&self, f: F)
    where
        F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
    {
        *self.loader.lock().unwrap() = Some(Box::new(f));
    }

    /// Get the associated connection metadata.
    pub fn get(&self) -> Option<ConnectionMetadata> {
        match self.loader.lock().unwrap().as_ref() {
            Some(loader) => loader(),
            None => {
                tracing::debug!("no loader was set on the CaptureSmithyConnection");
                None
            }
        }
    }
}

impl fmt::Debug for CaptureSmithyConnection {
    /// Writes only the type name; the stored loader is an opaque closure with no `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CaptureSmithyConnection")
    }
}

// Allows a `CaptureSmithyConnection` to be placed in and loaded from a `ConfigBag`.
impl Storable for CaptureSmithyConnection {
    // NOTE(review): presumably `StoreReplace` means a later `store_put` supersedes an earlier
    // value rather than accumulating — confirm against the `aws_smithy_types::config_bag` docs.
    type Storer = StoreReplace<Self>;
}

#[cfg(test)]
mod test {
    use super::*;

    /// A retriever installed through one handle must be usable through that handle and through
    /// a clone taken before the retriever was installed.
    #[test]
    // The clone is deliberate: the final assertion exercises it.
    #[allow(clippy::redundant_clone)]
    fn retrieve_connection_metadata() {
        let original = CaptureSmithyConnection::new();
        let cloned = original.clone();

        // Nothing has been installed yet, so there is nothing to retrieve.
        assert!(original.get().is_none());

        let proxied_metadata = || {
            let metadata = ConnectionMetadata::builder()
                .proxied(true)
                .poison_fn(|| {})
                .build();
            Some(metadata)
        };
        original.set_connection_retriever(proxied_metadata);

        assert!(original.get().is_some());
        assert!(cloned.get().is_some());
    }
}