use crate::client::retries::classifiers::run_classifiers_on_ctx;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::connection::ConnectionMetadata;
use aws_smithy_runtime_api::client::interceptors::context::{
AfterDeserializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::retries::classifiers::RetryAction;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::retry::{ReconnectMode, RetryConfig};
use std::fmt;
use std::sync::{Arc, Mutex};
use tracing::{debug, error};
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct ConnectionPoisoningInterceptor {}
impl ConnectionPoisoningInterceptor {
pub fn new() -> Self {
Self::default()
}
}
impl Intercept for ConnectionPoisoningInterceptor {
fn name(&self) -> &'static str {
"ConnectionPoisoningInterceptor"
}
fn modify_before_transmit(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let capture_smithy_connection = CaptureSmithyConnection::new();
context
.request_mut()
.add_extension(capture_smithy_connection.clone());
cfg.interceptor_state().store_put(capture_smithy_connection);
Ok(())
}
fn read_after_deserialization(
&self,
context: &AfterDeserializationInterceptorContextRef<'_>,
runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let reconnect_mode = cfg
.load::<RetryConfig>()
.map(RetryConfig::reconnect_mode)
.unwrap_or(ReconnectMode::ReconnectOnTransientError);
let captured_connection = cfg.load::<CaptureSmithyConnection>().cloned();
let retry_classifier_result =
run_classifiers_on_ctx(runtime_components.retry_classifiers(), context.inner());
let error_is_transient = retry_classifier_result == RetryAction::transient_error();
let connection_poisoning_is_enabled =
reconnect_mode == ReconnectMode::ReconnectOnTransientError;
if error_is_transient && connection_poisoning_is_enabled {
debug!("received a transient error, marking the connection for closure...");
if let Some(captured_connection) = captured_connection.and_then(|conn| conn.get()) {
captured_connection.poison();
debug!("the connection was marked for closure")
} else {
error!(
"unable to mark the connection for closure because no connection was found! The underlying HTTP connector never set a connection."
);
}
}
Ok(())
}
}
type LoaderFn = dyn Fn() -> Option<ConnectionMetadata> + Send + Sync;
#[derive(Clone, Default)]
pub struct CaptureSmithyConnection {
loader: Arc<Mutex<Option<Box<LoaderFn>>>>,
}
impl CaptureSmithyConnection {
pub fn new() -> Self {
Self {
loader: Default::default(),
}
}
pub fn set_connection_retriever<F>(&self, f: F)
where
F: Fn() -> Option<ConnectionMetadata> + Send + Sync + 'static,
{
*self.loader.lock().unwrap() = Some(Box::new(f));
}
pub fn get(&self) -> Option<ConnectionMetadata> {
match self.loader.lock().unwrap().as_ref() {
Some(loader) => loader(),
None => {
tracing::debug!("no loader was set on the CaptureSmithyConnection");
None
}
}
}
}
impl fmt::Debug for CaptureSmithyConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "CaptureSmithyConnection")
}
}
impl Storable for CaptureSmithyConnection {
type Storer = StoreReplace<Self>;
}
#[cfg(test)]
mod test {
use super::*;
#[test]
#[allow(clippy::redundant_clone)]
fn retrieve_connection_metadata() {
let retriever = CaptureSmithyConnection::new();
let retriever_clone = retriever.clone();
assert!(retriever.get().is_none());
retriever.set_connection_retriever(|| {
Some(
ConnectionMetadata::builder()
.proxied(true)
.poison_fn(|| {})
.build(),
)
});
assert!(retriever.get().is_some());
assert!(retriever_clone.get().is_some());
}
}