use crate::{MockResponse, Rule, RuleMode};
use aws_smithy_http_client::test_util::infallible_client_fn;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::http::SharedHttpClient;
use aws_smithy_runtime_api::client::interceptors::context::{
BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Error,
FinalizerInterceptorContextMut, Output,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, OrchestratorError};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use std::collections::VecDeque;
use std::fmt;
use std::sync::{Arc, Mutex};
/// Config-bag entry wrapping the rule that was selected for the current operation.
///
/// It is written by `modify_before_serialization` and read back by
/// `modify_before_transmit` when no response is pending.
#[derive(Debug, Clone)]
struct ActiveRule(Rule);
impl Storable for ActiveRule {
    // Stored in the config bag using the `StoreReplace` strategy.
    type Storer = StoreReplace<ActiveRule>;
}
/// Interceptor that answers operations with mock responses chosen from a list of rules.
pub struct MockResponseInterceptor {
    /// Rules in the order they were added. Kept behind a mutex because the
    /// interceptor hooks receive `&self` yet remove exhausted rules.
    rules: Arc<Mutex<VecDeque<Rule>>>,
    /// Strategy used to pick a rule for an incoming input.
    rule_mode: RuleMode,
    /// When `true`, failing to select a response for an input causes a panic.
    must_match: bool,
    /// Response selected for the in-flight operation, handed from one hook to the next.
    active_response: Arc<Mutex<Option<MockResponse<Output, Error>>>>,
}
impl fmt::Debug for MockResponseInterceptor {
    /// Writes `"<n> rules"`, where `n` is the number of rules currently held.
    ///
    /// Panics if the rules mutex is poisoned.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rule_count = self.rules.lock().unwrap().len();
        write!(f, "{} rules", rule_count)
    }
}
impl Default for MockResponseInterceptor {
fn default() -> Self {
Self::new()
}
}
impl MockResponseInterceptor {
pub fn new() -> Self {
Self {
rules: Default::default(),
rule_mode: RuleMode::MatchAny,
must_match: true,
active_response: Default::default(),
}
}
pub fn with_rule(self, rule: &Rule) -> Self {
self.rules.lock().unwrap().push_back(rule.clone());
self
}
pub fn rule_mode(mut self, rule_mode: RuleMode) -> Self {
self.rule_mode = rule_mode;
self
}
pub fn allow_passthrough(mut self) -> Self {
self.must_match = false;
self
}
}
impl Intercept for MockResponseInterceptor {
fn name(&self) -> &'static str {
"MockResponseInterceptor"
}
fn modify_before_serialization(
&self,
context: &mut BeforeSerializationInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let mut rules = self.rules.lock().unwrap();
let input = context.inner().input().expect("input set");
let mut matching_rule = None;
let mut matching_response = None;
match self.rule_mode {
RuleMode::Sequential => {
let i = 0;
while i < rules.len() && matching_response.is_none() {
let rule = &rules[i];
if rule.is_exhausted() {
rules.remove(i);
continue; }
if !(rule.matcher)(input) {
panic!(
"In order matching was enforced but rule did not match {:?}",
input
);
}
if let Some(response) = rule.next_response() {
matching_rule = Some(rule.clone());
matching_response = Some(response);
} else {
rules.remove(i);
continue; }
break;
}
}
RuleMode::MatchAny => {
for rule in rules.iter() {
if rule.is_exhausted() {
continue;
}
if (rule.matcher)(input) {
if let Some(response) = rule.next_response() {
matching_rule = Some(rule.clone());
matching_response = Some(response);
break;
}
}
}
}
};
match (matching_rule, matching_response) {
(Some(rule), Some(response)) => {
cfg.interceptor_state().store_put(ActiveRule(rule));
let mut active_resp = self.active_response.lock().unwrap();
let _ = std::mem::replace(&mut *active_resp, Some(response));
}
_ => {
if self.must_match {
panic!(
"must_match was enabled but no rules matched or all rules were exhausted for {:?}",
input
);
}
}
}
Ok(())
}
fn modify_before_transmit(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let mut state = self.active_response.lock().unwrap();
let mut active_response = (*state).take();
if active_response.is_none() {
if let Some(active_rule) = cfg.load::<ActiveRule>() {
let next_resp = active_rule.0.next_response();
active_response = next_resp;
}
}
if let Some(resp) = active_response {
match resp {
MockResponse::Http(http_resp) => {
context
.request_mut()
.add_extension(MockHttpResponse(Arc::new(http_resp)));
}
_ => {
let _ = std::mem::replace(&mut *state, Some(resp));
}
}
}
Ok(())
}
fn modify_before_attempt_completion(
&self,
context: &mut FinalizerInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let mut state = self.active_response.lock().unwrap();
let active_response = (*state).take();
if let Some(resp) = active_response {
match resp {
MockResponse::Output(output) => {
context.inner_mut().set_output_or_error(Ok(output));
}
MockResponse::Error(error) => {
context
.inner_mut()
.set_output_or_error(Err(OrchestratorError::operation(error)));
}
MockResponse::Http(_) => {
}
}
}
Ok(())
}
}
/// Request extension carrying the mock HTTP response that the mock HTTP client
/// should return for that request.
#[derive(Clone)]
struct MockHttpResponse(Arc<HttpResponse>);
pub fn create_mock_http_client() -> SharedHttpClient {
infallible_client_fn(|mut req| {
if let Some(mock_response) = req.extensions_mut().remove::<MockHttpResponse>() {
let http_resp =
Arc::try_unwrap(mock_response.0).expect("mock HTTP response has single reference");
return http_resp.try_into_http1x().unwrap();
}
http::Response::builder()
.status(418)
.body(SdkBody::from("Mock HTTP client dummy response"))
.unwrap()
})
}
#[cfg(test)]
mod tests {
use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::client::retries::classifiers::HttpStatusCodeClassifier;
use aws_smithy_runtime_api::client::orchestrator::{
HttpRequest, HttpResponse, OrchestratorError,
};
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_runtime_api::http::StatusCode;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::retry::RetryConfig;
use aws_smithy_types::timeout::TimeoutConfig;
use crate::{create_mock_http_client, MockResponseInterceptor, RuleBuilder, RuleMode};
use std::time::Duration;
/// Minimal operation input used by the tests: a bucket name and a key.
#[derive(Debug)]
struct TestInput {
    bucket: String,
    key: String,
}
impl TestInput {
    /// Builds an input by copying the given bucket and key.
    fn new(bucket: &str, key: &str) -> Self {
        let bucket = bucket.to_owned();
        let key = key.to_owned();
        TestInput { bucket, key }
    }
}
/// Minimal operation output used by the tests; compared by its `content`.
#[derive(Debug, PartialEq)]
struct TestOutput {
    content: String,
}
impl TestOutput {
    /// Builds an output holding a copy of `content`.
    fn new(content: &str) -> Self {
        let content = String::from(content);
        TestOutput { content }
    }
}
/// Minimal operation error used by the tests; displays as its message.
#[derive(Debug)]
struct TestError {
    message: String,
}
impl TestError {
    /// Builds an error holding a copy of `message`.
    fn new(message: &str) -> Self {
        let message = message.to_owned();
        TestError { message }
    }
}
impl std::fmt::Display for TestError {
    /// Writes the message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
impl std::error::Error for TestError {}
fn create_rule_builder() -> RuleBuilder<TestInput, TestOutput, TestError> {
RuleBuilder::new_from_mock(
|| TestInput {
bucket: "".to_string(),
key: "".to_string(),
},
|| {
let fut: std::future::Ready<Result<TestOutput, SdkError<TestError, HttpResponse>>> =
std::future::ready(Ok(TestOutput {
content: "".to_string(),
}));
fut
},
)
}
/// Builds the operation that the tests invoke.
///
/// The operation uses the mock HTTP client and the given `interceptor`. Its
/// serializer puts bucket and key into the request URI; its deserializer turns
/// a successful response body into a `TestOutput` and any other status into a
/// `TestError`. With `enable_retries`, standard retries (up to 5 attempts,
/// millisecond backoff) and the HTTP status classifier are configured;
/// otherwise retries are disabled.
fn create_test_operation(
    interceptor: MockResponseInterceptor,
    enable_retries: bool,
) -> Operation<TestInput, TestOutput, TestError> {
    let builder = Operation::builder()
        .service_name("test")
        .operation_name("test")
        .http_client(create_mock_http_client())
        .endpoint_url("http://localhost:1234")
        .no_auth()
        .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
        .timeout_config(TimeoutConfig::disabled())
        .interceptor(interceptor)
        .serializer(|input: TestInput| {
            let uri = format!("/{}/{}", input.bucket, input.key);
            let mut request = HttpRequest::new(SdkBody::empty());
            request.set_uri(uri).expect("valid URI");
            Ok(request)
        })
        .deserializer::<TestOutput, TestError>(|response| {
            if !response.status().is_success() {
                return Err(OrchestratorError::operation(TestError {
                    message: format!("Error: {}", response.status()),
                }));
            }
            let content = std::str::from_utf8(response.body().bytes().unwrap())
                .unwrap_or("empty body")
                .to_string();
            Ok(TestOutput { content })
        });

    if !enable_retries {
        return builder.no_retry().build();
    }

    let retry_config = RetryConfig::standard()
        .with_max_attempts(5)
        .with_initial_backoff(Duration::from_millis(1))
        .with_max_backoff(Duration::from_millis(5));
    builder
        .retry_classifier(HttpStatusCodeClassifier::default())
        .standard_retry(&retry_config)
        .build()
}
#[tokio::test]
async fn test_retry_sequence() {
let rule = create_rule_builder()
.match_requests(|input| input.bucket == "test-bucket" && input.key == "test-key")
.sequence()
.http_status(503, None)
.times(2)
.output(|| TestOutput::new("success after retries"))
.build();
let interceptor = MockResponseInterceptor::new()
.rule_mode(RuleMode::Sequential)
.with_rule(&rule);
let operation = create_test_operation(interceptor, true);
let result = operation
.invoke(TestInput::new("test-bucket", "test-key"))
.await;
assert!(
result.is_ok(),
"Expected success but got error: {:?}",
result.err()
);
assert_eq!(
result.unwrap(),
TestOutput {
content: "success after retries".to_string()
}
);
assert_eq!(rule.num_calls(), 3);
}
/// Once the only rule has produced its single response, a further invocation
/// panics with the `must_match` message.
#[should_panic(
    expected = "must_match was enabled but no rules matched or all rules were exhausted for"
)]
#[tokio::test]
async fn test_exhausted_rules() {
    let rule = create_rule_builder().then_output(|| TestOutput::new("only response"));
    let operation = create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule),
        false,
    );

    let first = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(first.is_ok());

    // The rule is exhausted now, so this invocation is expected to panic.
    let _second = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
}
/// In `MatchAny` mode each input is answered by the rule whose matcher accepts
/// it, and each rule is called exactly once.
#[tokio::test]
async fn test_rule_mode_match_any() {
    let bucket1_rule = create_rule_builder()
        .match_requests(|input| input.bucket == "bucket1")
        .then_output(|| TestOutput::new("response1"));
    let bucket2_rule = create_rule_builder()
        .match_requests(|input| input.bucket == "bucket2")
        .then_output(|| TestOutput::new("response2"));
    let operation = create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::MatchAny)
            .with_rule(&bucket1_rule)
            .with_rule(&bucket2_rule),
        false,
    );

    let first = operation
        .invoke(TestInput::new("bucket1", "test-key"))
        .await;
    assert!(first.is_ok());
    assert_eq!(first.unwrap(), TestOutput::new("response1"));

    let second = operation
        .invoke(TestInput::new("bucket2", "test-key"))
        .await;
    assert!(second.is_ok());
    assert_eq!(second.unwrap(), TestOutput::new("response2"));

    assert_eq!(bucket1_rule.num_calls(), 1);
    assert_eq!(bucket2_rule.num_calls(), 1);
}
/// One sequence rule can yield an output, then a modeled error, then a raw
/// HTTP response across three consecutive invocations.
#[tokio::test]
async fn test_mixed_response_types() {
    let rule = create_rule_builder()
        .sequence()
        .output(|| TestOutput::new("first output"))
        .error(|| TestError::new("expected error"))
        .http_response(|| {
            HttpResponse::new(
                StatusCode::try_from(200).unwrap(),
                SdkBody::from("http response"),
            )
        })
        .build();
    let operation = create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule),
        false,
    );

    // First response: modeled output.
    let output_result = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(output_result.is_ok());
    assert_eq!(output_result.unwrap(), TestOutput::new("first output"));

    // Second response: modeled error surfaced as a service error.
    let error_result = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(error_result.is_err());
    let sdk_err = error_result.unwrap_err();
    let service_err = sdk_err.as_service_error().expect("expected service error");
    assert_eq!(service_err.to_string(), "expected error");

    // Third response: HTTP response run through the test deserializer.
    let http_result = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(http_result.is_ok());
    assert_eq!(http_result.unwrap(), TestOutput::new("http response"));

    assert_eq!(rule.num_calls(), 3);
}
/// In sequential mode, after the first rule's two responses are used up, the
/// next rule in the list answers the third invocation.
#[tokio::test]
async fn test_exhausted_sequence() {
    let sequence_rule = create_rule_builder()
        .sequence()
        .output(|| TestOutput::new("response 1"))
        .output(|| TestOutput::new("response 2"))
        .build();
    let fallback_rule =
        create_rule_builder().then_output(|| TestOutput::new("fallback response"));
    let operation = create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&sequence_rule)
            .with_rule(&fallback_rule),
        false,
    );

    let first = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(first.is_ok());
    assert_eq!(first.unwrap(), TestOutput::new("response 1"));

    let second = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(second.is_ok());
    assert_eq!(second.unwrap(), TestOutput::new("response 2"));

    let third = operation
        .invoke(TestInput::new("test-bucket", "test-key"))
        .await;
    assert!(third.is_ok());
    assert_eq!(third.unwrap(), TestOutput::new("fallback response"));

    assert_eq!(sequence_rule.num_calls(), 2);
    assert_eq!(fallback_rule.num_calls(), 1);
}
/// Three tasks invoking the same operation together receive, between them, all
/// three responses of the sequence rule.
#[tokio::test]
async fn test_concurrent_usage() {
    use std::sync::Arc;
    use tokio::task;

    let rule = Arc::new(
        create_rule_builder()
            .sequence()
            .output(|| TestOutput::new("response 1"))
            .output(|| TestOutput::new("response 2"))
            .output(|| TestOutput::new("response 3"))
            .build(),
    );
    let operation = Arc::new(create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule),
        false,
    ));

    let handles: Vec<_> = (0..3)
        .map(|i| {
            let op = operation.clone();
            task::spawn(async move {
                let bucket = format!("bucket-{}", i);
                op.invoke(TestInput::new(&bucket, "test-key"))
                    .await
                    .unwrap()
            })
        })
        .collect();

    let mut results = Vec::with_capacity(handles.len());
    for handle in handles {
        results.push(handle.await.unwrap());
    }

    // Task completion order is not fixed, so compare in sorted order.
    results.sort_by(|lhs, rhs| lhs.content.cmp(&rhs.content));
    assert_eq!(results.len(), 3);
    let expected = ["response 1", "response 2", "response 3"];
    for (actual, content) in results.iter().zip(expected.iter()) {
        assert_eq!(*actual, TestOutput::new(content));
    }
    assert_eq!(rule.num_calls(), 3);
}
/// In sequential mode, a single-use rule that answered with a 404 no longer
/// blocks the queue: the following invocation is served by the second rule.
#[tokio::test]
async fn test_sequential_rule_removal() {
    let not_found_rule = create_rule_builder()
        .match_requests(|input| input.bucket == "test-bucket" && input.key != "correct-key")
        .then_http_response(|| {
            HttpResponse::new(
                StatusCode::try_from(404).unwrap(),
                SdkBody::from("not found"),
            )
        });
    let success_rule = create_rule_builder()
        .match_requests(|input| input.bucket == "test-bucket" && input.key == "correct-key")
        .then_output(|| TestOutput::new("success"));
    let operation = create_test_operation(
        MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&not_found_rule)
            .with_rule(&success_rule),
        true,
    );

    let failed = operation.invoke(TestInput::new("test-bucket", "foo")).await;
    assert!(failed.is_err());
    assert_eq!(not_found_rule.num_calls(), 1);

    let succeeded = operation
        .invoke(TestInput::new("test-bucket", "correct-key"))
        .await;
    assert!(succeeded.is_ok());
    assert_eq!(succeeded.unwrap(), TestOutput::new("success"));
    assert_eq!(success_rule.num_calls(), 1);
}
}