1use crate::{MockResponse, Rule, RuleMode};
7use aws_smithy_http_client::test_util::infallible_client_fn;
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::http::SharedHttpClient;
10use aws_smithy_runtime_api::client::interceptors::context::{
11 BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Error,
12 FinalizerInterceptorContextMut, Output,
13};
14use aws_smithy_runtime_api::client::interceptors::Intercept;
15use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, OrchestratorError};
16use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
19use std::collections::VecDeque;
20use std::fmt;
21use std::sync::{Arc, Mutex};
22
23#[derive(Debug, Clone)]
25struct ActiveRule(Rule);
26
27impl Storable for ActiveRule {
28 type Storer = StoreReplace<ActiveRule>;
29}
30
31pub struct MockResponseInterceptor {
33 rules: Arc<Mutex<VecDeque<Rule>>>,
34 rule_mode: RuleMode,
35 must_match: bool,
36 active_response: Arc<Mutex<Option<MockResponse<Output, Error>>>>,
37}
38
39impl fmt::Debug for MockResponseInterceptor {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 write!(f, "{} rules", self.rules.lock().unwrap().len())
42 }
43}
44
45impl Default for MockResponseInterceptor {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl MockResponseInterceptor {
52 pub fn new() -> Self {
56 Self {
57 rules: Default::default(),
58 rule_mode: RuleMode::MatchAny,
59 must_match: true,
60 active_response: Default::default(),
61 }
62 }
63 pub fn with_rule(self, rule: &Rule) -> Self {
67 self.rules.lock().unwrap().push_back(rule.clone());
68 self
69 }
70
71 pub fn rule_mode(mut self, rule_mode: RuleMode) -> Self {
75 self.rule_mode = rule_mode;
76 self
77 }
78
79 pub fn allow_passthrough(mut self) -> Self {
84 self.must_match = false;
85 self
86 }
87}
88
89impl Intercept for MockResponseInterceptor {
90 fn name(&self) -> &'static str {
91 "MockResponseInterceptor"
92 }
93
94 fn modify_before_serialization(
95 &self,
96 context: &mut BeforeSerializationInterceptorContextMut<'_>,
97 _runtime_components: &RuntimeComponents,
98 cfg: &mut ConfigBag,
99 ) -> Result<(), BoxError> {
100 let mut rules = self.rules.lock().unwrap();
101 let input = context.inner().input().expect("input set");
102
103 let mut matching_rule = None;
105 let mut matching_response = None;
106
107 match self.rule_mode {
108 RuleMode::Sequential => {
109 let i = 0;
111 while i < rules.len() && matching_response.is_none() {
112 let rule = &rules[i];
113
114 if rule.is_exhausted() {
116 rules.remove(i);
118 continue; }
120
121 if !(rule.matcher)(input) {
123 panic!(
125 "In order matching was enforced but rule did not match {:?}",
126 input
127 );
128 }
129
130 if let Some(response) = rule.next_response() {
132 matching_rule = Some(rule.clone());
133 matching_response = Some(response);
134 } else {
135 rules.remove(i);
137 continue; }
139
140 break;
142 }
143 }
144 RuleMode::MatchAny => {
145 for rule in rules.iter() {
147 if rule.is_exhausted() {
149 continue;
150 }
151
152 if (rule.matcher)(input) {
153 if let Some(response) = rule.next_response() {
154 matching_rule = Some(rule.clone());
155 matching_response = Some(response);
156 break;
157 }
158 }
159 }
160 }
161 };
162
163 match (matching_rule, matching_response) {
164 (Some(rule), Some(response)) => {
165 cfg.interceptor_state().store_put(ActiveRule(rule));
167 let mut active_resp = self.active_response.lock().unwrap();
170 let _ = std::mem::replace(&mut *active_resp, Some(response));
171 }
172 _ => {
173 if self.must_match {
175 panic!(
176 "must_match was enabled but no rules matched or all rules were exhausted for {:?}",
177 input
178 );
179 }
180 }
181 }
182
183 Ok(())
184 }
185
186 fn modify_before_transmit(
187 &self,
188 context: &mut BeforeTransmitInterceptorContextMut<'_>,
189 _runtime_components: &RuntimeComponents,
190 cfg: &mut ConfigBag,
191 ) -> Result<(), BoxError> {
192 let mut state = self.active_response.lock().unwrap();
193 let mut active_response = (*state).take();
194 if active_response.is_none() {
195 if let Some(active_rule) = cfg.load::<ActiveRule>() {
197 let next_resp = active_rule.0.next_response();
198 active_response = next_resp;
199 }
200 }
201
202 if let Some(resp) = active_response {
203 match resp {
204 MockResponse::Http(http_resp) => {
206 context
207 .request_mut()
208 .add_extension(MockHttpResponse(Arc::new(http_resp)));
209 }
210 _ => {
211 let _ = std::mem::replace(&mut *state, Some(resp));
213 }
214 }
215 }
216
217 Ok(())
218 }
219
220 fn modify_before_attempt_completion(
221 &self,
222 context: &mut FinalizerInterceptorContextMut<'_>,
223 _runtime_components: &RuntimeComponents,
224 _cfg: &mut ConfigBag,
225 ) -> Result<(), BoxError> {
226 let mut state = self.active_response.lock().unwrap();
228 let active_response = (*state).take();
229 if let Some(resp) = active_response {
230 match resp {
231 MockResponse::Output(output) => {
232 context.inner_mut().set_output_or_error(Ok(output));
233 }
234 MockResponse::Error(error) => {
235 context
236 .inner_mut()
237 .set_output_or_error(Err(OrchestratorError::operation(error)));
238 }
239 MockResponse::Http(_) => {
240 }
242 }
243 }
244
245 Ok(())
246 }
247}
248
/// HTTP mock response carried to the mock HTTP client as a request extension.
///
/// `MockResponseInterceptor::modify_before_transmit` attaches it, and the client built by
/// `create_mock_http_client` removes and returns it.
// NOTE(review): the `Arc` (and the `Clone` derive) presumably exist because request
// extensions must be `Clone` while `HttpResponse` is not. Confirm.
#[derive(Clone)]
struct MockHttpResponse(Arc<HttpResponse>);
252
253pub fn create_mock_http_client() -> SharedHttpClient {
255 infallible_client_fn(|mut req| {
256 if let Some(mock_response) = req.extensions_mut().remove::<MockHttpResponse>() {
258 let http_resp =
259 Arc::try_unwrap(mock_response.0).expect("mock HTTP response has single reference");
260 return http_resp.try_into_http1x().unwrap();
261 }
262
263 http::Response::builder()
265 .status(418)
266 .body(SdkBody::from("Mock HTTP client dummy response"))
267 .unwrap()
268 })
269}
270
/// Integration-style tests: each test drives a real `Operation` through the orchestrator
/// with `MockResponseInterceptor` installed, using the mock HTTP client.
#[cfg(test)]
mod tests {
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_runtime::client::orchestrator::operation::Operation;
    use aws_smithy_runtime::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_runtime_api::client::orchestrator::{
        HttpRequest, HttpResponse, OrchestratorError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_runtime_api::http::StatusCode;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::retry::RetryConfig;
    use aws_smithy_types::timeout::TimeoutConfig;

    use crate::{create_mock_http_client, MockResponseInterceptor, RuleBuilder, RuleMode};
    use std::time::Duration;

    /// Operation input; rules match on `bucket` and `key`, and both form the request URI.
    #[derive(Debug)]
    struct TestInput {
        bucket: String,
        key: String,
    }
    impl TestInput {
        fn new(bucket: &str, key: &str) -> Self {
            Self {
                bucket: bucket.to_string(),
                key: key.to_string(),
            }
        }
    }

    /// Operation output; `content` holds either a mocked value or a successful HTTP body.
    #[derive(Debug, PartialEq)]
    struct TestOutput {
        content: String,
    }

    impl TestOutput {
        fn new(content: &str) -> Self {
            Self {
                content: content.to_string(),
            }
        }
    }

    /// Operation (service) error type.
    #[derive(Debug)]
    struct TestError {
        message: String,
    }

    impl TestError {
        fn new(message: &str) -> Self {
            Self {
                message: message.to_string(),
            }
        }
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.message)
        }
    }

    impl std::error::Error for TestError {}

    /// Returns a fresh `RuleBuilder` for the test input, output and error types.
    // NOTE(review): both closures appear to exist only to pin the builder's generic types.
    // Confirm how `RuleBuilder::new_from_mock` uses them.
    fn create_rule_builder() -> RuleBuilder<TestInput, TestOutput, TestError> {
        RuleBuilder::new_from_mock(
            || TestInput {
                bucket: "".to_string(),
                key: "".to_string(),
            },
            || {
                let fut: std::future::Ready<Result<TestOutput, SdkError<TestError, HttpResponse>>> =
                    std::future::ready(Ok(TestOutput {
                        content: "".to_string(),
                    }));
                fut
            },
        )
    }

    /// Builds a test `Operation` that uses `interceptor` and the mock HTTP client.
    ///
    /// - The serializer turns the input into a `/{bucket}/{key}` request with an empty body.
    /// - The deserializer maps a 2xx body to `TestOutput` and any other status to
    ///   `TestError`.
    /// - With `enable_retries`, standard retries (5 attempts, 1-5 ms backoff) classified by
    ///   HTTP status are enabled. Otherwise retries are disabled.
    fn create_test_operation(
        interceptor: MockResponseInterceptor,
        enable_retries: bool,
    ) -> Operation<TestInput, TestOutput, TestError> {
        let builder = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(create_mock_http_client())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .timeout_config(TimeoutConfig::disabled())
            .interceptor(interceptor)
            .serializer(|input: TestInput| {
                let mut request = HttpRequest::new(SdkBody::empty());
                request
                    .set_uri(format!("/{}/{}", input.bucket, input.key))
                    .expect("valid URI");
                Ok(request)
            })
            .deserializer::<TestOutput, TestError>(|response| {
                if response.status().is_success() {
                    let body = std::str::from_utf8(response.body().bytes().unwrap())
                        .unwrap_or("empty body")
                        .to_string();
                    Ok(TestOutput { content: body })
                } else {
                    Err(OrchestratorError::operation(TestError {
                        message: format!("Error: {}", response.status()),
                    }))
                }
            });

        if enable_retries {
            let retry_config = RetryConfig::standard()
                .with_max_attempts(5)
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_backoff(Duration::from_millis(5));

            builder
                .retry_classifier(HttpStatusCodeClassifier::default())
                .standard_retry(&retry_config)
                .build()
        } else {
            builder.no_retry().build()
        }
    }

    /// Two mocked 503s followed by a modeled output.
    ///
    /// Retries draw successive responses from the same rule, so the operation succeeds on
    /// the third attempt.
    #[tokio::test]
    async fn test_retry_sequence() {
        let rule = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key == "test-key")
            .sequence()
            .http_status(503, None)
            .times(2)
            .output(|| TestOutput::new("success after retries"))
            .build();

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, true);

        let result = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;

        assert!(
            result.is_ok(),
            "Expected success but got error: {:?}",
            result.err()
        );
        assert_eq!(
            result.unwrap(),
            TestOutput {
                content: "success after retries".to_string()
            }
        );

        // One call per attempt: 503, 503, output.
        assert_eq!(rule.num_calls(), 3);
    }

    /// A single-response rule serves the first call.
    ///
    /// The second call finds no usable rule, so the default `must_match` panics.
    #[should_panic(
        expected = "must_match was enabled but no rules matched or all rules were exhausted for"
    )]
    #[tokio::test]
    async fn test_exhausted_rules() {
        let rule = create_rule_builder().then_output(|| TestOutput::new("only response"));

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, false);

        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());

        let _result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
    }

    /// In `MatchAny` mode each request is answered by the rule whose matcher accepts it,
    /// regardless of registration order.
    #[tokio::test]
    async fn test_rule_mode_match_any() {
        let rule1 = create_rule_builder()
            .match_requests(|input| input.bucket == "bucket1")
            .then_output(|| TestOutput::new("response1"));

        let rule2 = create_rule_builder()
            .match_requests(|input| input.bucket == "bucket2")
            .then_output(|| TestOutput::new("response2"));

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::MatchAny)
            .with_rule(&rule1)
            .with_rule(&rule2);

        let operation = create_test_operation(interceptor, false);

        let result1 = operation
            .invoke(TestInput::new("bucket1", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("response1"));

        let result2 = operation
            .invoke(TestInput::new("bucket2", "test-key"))
            .await;
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("response2"));

        assert_eq!(rule1.num_calls(), 1);
        assert_eq!(rule2.num_calls(), 1);
    }

    /// One rule yields a modeled output, then a modeled error, then a raw HTTP response.
    ///
    /// The raw response goes through the mock HTTP client and the deserializer.
    #[tokio::test]
    async fn test_mixed_response_types() {
        let rule = create_rule_builder()
            .sequence()
            .output(|| TestOutput::new("first output"))
            .error(|| TestError::new("expected error"))
            .http_response(|| {
                HttpResponse::new(
                    StatusCode::try_from(200).unwrap(),
                    SdkBody::from("http response"),
                )
            })
            .build();

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, false);

        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("first output"));

        let result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result2.is_err());
        let sdk_err = result2.unwrap_err();
        let err = sdk_err.as_service_error().expect("expected service error");
        assert_eq!(err.to_string(), "expected error");

        let result3 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result3.is_ok());
        assert_eq!(result3.unwrap(), TestOutput::new("http response"));

        assert_eq!(rule.num_calls(), 3);
    }

    /// In sequential mode, once the first rule's two responses are used up, the next
    /// request falls through to the following (fallback) rule.
    #[tokio::test]
    async fn test_exhausted_sequence() {
        let rule = create_rule_builder()
            .sequence()
            .output(|| TestOutput::new("response 1"))
            .output(|| TestOutput::new("response 2"))
            .build();

        let fallback_rule =
            create_rule_builder().then_output(|| TestOutput::new("fallback response"));

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule)
            .with_rule(&fallback_rule);

        let operation = create_test_operation(interceptor, false);

        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("response 1"));

        let result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("response 2"));

        let result3 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result3.is_ok());
        assert_eq!(result3.unwrap(), TestOutput::new("fallback response"));

        assert_eq!(rule.num_calls(), 2);
        assert_eq!(fallback_rule.num_calls(), 1);
    }

    /// Three spawned invocations share one operation and interceptor; each is expected to
    /// receive a distinct response.
    #[tokio::test]
    async fn test_concurrent_usage() {
        use std::sync::Arc;
        use tokio::task;

        let rule = Arc::new(
            create_rule_builder()
                .sequence()
                .output(|| TestOutput::new("response 1"))
                .output(|| TestOutput::new("response 2"))
                .output(|| TestOutput::new("response 3"))
                .build(),
        );

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = Arc::new(create_test_operation(interceptor, false));

        let mut handles = vec![];
        for i in 0..3 {
            let op = operation.clone();
            let handle = task::spawn(async move {
                let result = op
                    .invoke(TestInput::new(&format!("bucket-{}", i), "test-key"))
                    .await;
                result.unwrap()
            });
            handles.push(handle);
        }

        let mut results = vec![];
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        // Task completion order is not fixed, so compare the sorted set of responses.
        results.sort_by(|a, b| a.content.cmp(&b.content));

        assert_eq!(results.len(), 3);
        assert_eq!(results[0], TestOutput::new("response 1"));
        assert_eq!(results[1], TestOutput::new("response 2"));
        assert_eq!(results[2], TestOutput::new("response 3"));

        assert_eq!(rule.num_calls(), 3);
    }

    /// In sequential mode with retries enabled:
    ///
    /// - The mocked 404 fails the first call after a single rule call (see the
    ///   `num_calls` assertion).
    /// - The exhausted first rule is then skipped, and the second call is answered by the
    ///   next rule.
    #[tokio::test]
    async fn test_sequential_rule_removal() {
        let rule1 = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key != "correct-key")
            .then_http_response(|| {
                HttpResponse::new(
                    StatusCode::try_from(404).unwrap(),
                    SdkBody::from("not found"),
                )
            });

        let rule2 = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key == "correct-key")
            .then_output(|| TestOutput::new("success"));

        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule1)
            .with_rule(&rule2);

        let operation = create_test_operation(interceptor, true);

        let result1 = operation.invoke(TestInput::new("test-bucket", "foo")).await;
        assert!(result1.is_err());
        assert_eq!(rule1.num_calls(), 1);

        let result2 = operation
            .invoke(TestInput::new("test-bucket", "correct-key"))
            .await;

        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("success"));
        assert_eq!(rule2.num_calls(), 1);
    }
}