aws_smithy_mocks/
interceptor.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::{MockResponse, Rule, RuleMode};
7use aws_smithy_http_client::test_util::infallible_client_fn;
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::http::SharedHttpClient;
10use aws_smithy_runtime_api::client::interceptors::context::{
11    BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Error,
12    FinalizerInterceptorContextMut, Output,
13};
14use aws_smithy_runtime_api::client::interceptors::Intercept;
15use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, OrchestratorError};
16use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
17use aws_smithy_types::body::SdkBody;
18use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
19use std::collections::VecDeque;
20use std::fmt;
21use std::sync::{Arc, Mutex};
22
23// Store active rule in config bag
24#[derive(Debug, Clone)]
25struct ActiveRule(Rule);
26
27impl Storable for ActiveRule {
28    type Storer = StoreReplace<ActiveRule>;
29}
30
31/// Interceptor which produces mock responses based on a list of rules
32pub struct MockResponseInterceptor {
33    rules: Arc<Mutex<VecDeque<Rule>>>,
34    rule_mode: RuleMode,
35    must_match: bool,
36    active_response: Arc<Mutex<Option<MockResponse<Output, Error>>>>,
37}
38
39impl fmt::Debug for MockResponseInterceptor {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        write!(f, "{} rules", self.rules.lock().unwrap().len())
42    }
43}
44
45impl Default for MockResponseInterceptor {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl MockResponseInterceptor {
52    /// Create a new [MockResponseInterceptor]
53    ///
54    /// This is normally created and registered on a client through the [`mock_client`](crate::mock_client) macro.
55    pub fn new() -> Self {
56        Self {
57            rules: Default::default(),
58            rule_mode: RuleMode::MatchAny,
59            must_match: true,
60            active_response: Default::default(),
61        }
62    }
63    /// Add a rule to the Interceptor
64    ///
65    /// Rules are matched in order—this rule will only apply if all previous rules do not match.
66    pub fn with_rule(self, rule: &Rule) -> Self {
67        self.rules.lock().unwrap().push_back(rule.clone());
68        self
69    }
70
71    /// Set the RuleMode to use when evaluating rules.
72    ///
73    /// See `RuleMode` enum for modes and how they are applied.
74    pub fn rule_mode(mut self, rule_mode: RuleMode) -> Self {
75        self.rule_mode = rule_mode;
76        self
77    }
78
79    /// Allow passthrough for unmatched requests.
80    ///
81    /// By default, if a request doesn't match any rule, the interceptor will panic.
82    /// This method allows unmatched requests to pass through.
83    pub fn allow_passthrough(mut self) -> Self {
84        self.must_match = false;
85        self
86    }
87}
88
89impl Intercept for MockResponseInterceptor {
90    fn name(&self) -> &'static str {
91        "MockResponseInterceptor"
92    }
93
94    fn modify_before_serialization(
95        &self,
96        context: &mut BeforeSerializationInterceptorContextMut<'_>,
97        _runtime_components: &RuntimeComponents,
98        cfg: &mut ConfigBag,
99    ) -> Result<(), BoxError> {
100        let mut rules = self.rules.lock().unwrap();
101        let input = context.inner().input().expect("input set");
102
103        // Find a matching rule and get its response
104        let mut matching_rule = None;
105        let mut matching_response = None;
106
107        match self.rule_mode {
108            RuleMode::Sequential => {
109                // Sequential mode requires rules match in-order
110                let i = 0;
111                while i < rules.len() && matching_response.is_none() {
112                    let rule = &rules[i];
113
114                    // Check if the rule is already exhausted
115                    if rule.is_exhausted() {
116                        // Rule is exhausted, remove it and try the next one
117                        rules.remove(i);
118                        continue; // Don't increment i since we removed an element
119                    }
120
121                    // Check if the rule matches
122                    if !(rule.matcher)(input) {
123                        // Rule doesn't match, this is an error in sequential mode
124                        panic!(
125                            "In order matching was enforced but rule did not match {:?}",
126                            input
127                        );
128                    }
129
130                    // Rule matches and is not exhausted, get the response
131                    if let Some(response) = rule.next_response() {
132                        matching_rule = Some(rule.clone());
133                        matching_response = Some(response);
134                    } else {
135                        // Rule is exhausted, remove it and try the next one
136                        rules.remove(i);
137                        continue; // Don't increment i since we removed an element
138                    }
139
140                    // We found a matching rule and got a response, so we're done
141                    break;
142                }
143            }
144            RuleMode::MatchAny => {
145                // Find any matching rule with a response
146                for rule in rules.iter() {
147                    // Skip exhausted rules
148                    if rule.is_exhausted() {
149                        continue;
150                    }
151
152                    if (rule.matcher)(input) {
153                        if let Some(response) = rule.next_response() {
154                            matching_rule = Some(rule.clone());
155                            matching_response = Some(response);
156                            break;
157                        }
158                    }
159                }
160            }
161        };
162
163        match (matching_rule, matching_response) {
164            (Some(rule), Some(response)) => {
165                // Store the rule in the config bag
166                cfg.interceptor_state().store_put(ActiveRule(rule));
167                // store the response on the interceptor (because going
168                // through interceptor context requires the type to impl Clone)
169                let mut active_resp = self.active_response.lock().unwrap();
170                let _ = std::mem::replace(&mut *active_resp, Some(response));
171            }
172            _ => {
173                // No matching rule or no response
174                if self.must_match {
175                    panic!(
176                        "must_match was enabled but no rules matched or all rules were exhausted for {:?}",
177                        input
178                    );
179                }
180            }
181        }
182
183        Ok(())
184    }
185
186    fn modify_before_transmit(
187        &self,
188        context: &mut BeforeTransmitInterceptorContextMut<'_>,
189        _runtime_components: &RuntimeComponents,
190        cfg: &mut ConfigBag,
191    ) -> Result<(), BoxError> {
192        let mut state = self.active_response.lock().unwrap();
193        let mut active_response = (*state).take();
194        if active_response.is_none() {
195            // in the case of retries we try to get the next response if it has been consumed
196            if let Some(active_rule) = cfg.load::<ActiveRule>() {
197                let next_resp = active_rule.0.next_response();
198                active_response = next_resp;
199            }
200        }
201
202        if let Some(resp) = active_response {
203            match resp {
204                // place the http response into the extensions and let the HTTP client return it
205                MockResponse::Http(http_resp) => {
206                    context
207                        .request_mut()
208                        .add_extension(MockHttpResponse(Arc::new(http_resp)));
209                }
210                _ => {
211                    // put it back for modeled output/errors
212                    let _ = std::mem::replace(&mut *state, Some(resp));
213                }
214            }
215        }
216
217        Ok(())
218    }
219
220    fn modify_before_attempt_completion(
221        &self,
222        context: &mut FinalizerInterceptorContextMut<'_>,
223        _runtime_components: &RuntimeComponents,
224        _cfg: &mut ConfigBag,
225    ) -> Result<(), BoxError> {
226        // Handle modeled responses
227        let mut state = self.active_response.lock().unwrap();
228        let active_response = (*state).take();
229        if let Some(resp) = active_response {
230            match resp {
231                MockResponse::Output(output) => {
232                    context.inner_mut().set_output_or_error(Ok(output));
233                }
234                MockResponse::Error(error) => {
235                    context
236                        .inner_mut()
237                        .set_output_or_error(Err(OrchestratorError::operation(error)));
238                }
239                MockResponse::Http(_) => {
240                    // HTTP responses are handled by the mock HTTP client
241                }
242            }
243        }
244
245        Ok(())
246    }
247}
248
/// Extension for storing mock HTTP responses in request extensions.
///
/// `MockResponseInterceptor::modify_before_transmit` attaches it to the outgoing request, and the
/// client built by [`create_mock_http_client`] removes it and returns the wrapped response.
/// NOTE(review): the `Arc` is presumably there because request extensions must be `Clone` while
/// `HttpResponse` is not. The client unwraps it with `Arc::try_unwrap`, so it must never
/// actually be shared. Confirm against `add_extension`'s bounds.
#[derive(Clone)]
struct MockHttpResponse(Arc<HttpResponse>);
252
253/// Create a mock HTTP client that works with the interceptor using existing utilities
254pub fn create_mock_http_client() -> SharedHttpClient {
255    infallible_client_fn(|mut req| {
256        // Try to get the mock HTTP response generator from the extensions
257        if let Some(mock_response) = req.extensions_mut().remove::<MockHttpResponse>() {
258            let http_resp =
259                Arc::try_unwrap(mock_response.0).expect("mock HTTP response has single reference");
260            return http_resp.try_into_http1x().unwrap();
261        }
262
263        // Default dummy response if no mock response is defined
264        http::Response::builder()
265            .status(418)
266            .body(SdkBody::from("Mock HTTP client dummy response"))
267            .unwrap()
268    })
269}
270
// Unit tests that drive MockResponseInterceptor through a real `Operation` wired to the mock
// HTTP client.
#[cfg(test)]
mod tests {
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_runtime::client::orchestrator::operation::Operation;
    use aws_smithy_runtime::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_runtime_api::client::orchestrator::{
        HttpRequest, HttpResponse, OrchestratorError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_runtime_api::http::StatusCode;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::retry::RetryConfig;
    use aws_smithy_types::timeout::TimeoutConfig;

    use crate::{create_mock_http_client, MockResponseInterceptor, RuleBuilder, RuleMode};
    use std::time::Duration;

    // Simple test input and output types
    #[derive(Debug)]
    struct TestInput {
        bucket: String,
        key: String,
    }
    impl TestInput {
        fn new(bucket: &str, key: &str) -> Self {
            Self {
                bucket: bucket.to_string(),
                key: key.to_string(),
            }
        }
    }

    #[derive(Debug, PartialEq)]
    struct TestOutput {
        content: String,
    }

    impl TestOutput {
        fn new(content: &str) -> Self {
            Self {
                content: content.to_string(),
            }
        }
    }

    #[derive(Debug)]
    struct TestError {
        message: String,
    }

    impl TestError {
        fn new(message: &str) -> Self {
            Self {
                message: message.to_string(),
            }
        }
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.message)
        }
    }

    impl std::error::Error for TestError {}

    // Helper function to create a RuleBuilder with proper type hints
    fn create_rule_builder() -> RuleBuilder<TestInput, TestOutput, TestError> {
        RuleBuilder::new_from_mock(
            || TestInput {
                bucket: "".to_string(),
                key: "".to_string(),
            },
            || {
                let fut: std::future::Ready<Result<TestOutput, SdkError<TestError, HttpResponse>>> =
                    std::future::ready(Ok(TestOutput {
                        content: "".to_string(),
                    }));
                fut
            },
        )
    }

    // Helper function to create an Operation with common configuration.
    // The deserializer maps a 2xx response body to `TestOutput` and any other status to
    // `TestError`, so HTTP-typed mock responses surface as regular outputs or errors.
    fn create_test_operation(
        interceptor: MockResponseInterceptor,
        enable_retries: bool,
    ) -> Operation<TestInput, TestOutput, TestError> {
        let builder = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(create_mock_http_client())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .timeout_config(TimeoutConfig::disabled())
            .interceptor(interceptor)
            .serializer(|input: TestInput| {
                let mut request = HttpRequest::new(SdkBody::empty());
                request
                    .set_uri(format!("/{}/{}", input.bucket, input.key))
                    .expect("valid URI");
                Ok(request)
            })
            .deserializer::<TestOutput, TestError>(|response| {
                if response.status().is_success() {
                    let body = std::str::from_utf8(response.body().bytes().unwrap())
                        .unwrap_or("empty body")
                        .to_string();
                    Ok(TestOutput { content: body })
                } else {
                    Err(OrchestratorError::operation(TestError {
                        message: format!("Error: {}", response.status()),
                    }))
                }
            });

        // Retries use tiny backoffs so retry tests stay fast.
        if enable_retries {
            let retry_config = RetryConfig::standard()
                .with_max_attempts(5)
                .with_initial_backoff(Duration::from_millis(1))
                .with_max_backoff(Duration::from_millis(5));

            builder
                .retry_classifier(HttpStatusCodeClassifier::default())
                .standard_retry(&retry_config)
                .build()
        } else {
            builder.no_retry().build()
        }
    }

    #[tokio::test]
    async fn test_retry_sequence() {
        // Create a rule with repeated error responses followed by success
        let rule = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key == "test-key")
            .sequence()
            .http_status(503, None)
            .times(2)
            .output(|| TestOutput::new("success after retries"))
            .build();

        // Create an interceptor with the rule
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, true);

        // Make a single request - it should automatically retry through the sequence
        let result = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;

        // Should succeed with the final output after retries
        assert!(
            result.is_ok(),
            "Expected success but got error: {:?}",
            result.err()
        );
        assert_eq!(
            result.unwrap(),
            TestOutput {
                content: "success after retries".to_string()
            }
        );

        // Verify the rule was used the expected number of times (all 3 responses: 2 errors + 1 success)
        assert_eq!(rule.num_calls(), 3);
    }

    #[should_panic(
        expected = "must_match was enabled but no rules matched or all rules were exhausted for"
    )]
    #[tokio::test]
    async fn test_exhausted_rules() {
        // Create a rule with a single response
        let rule = create_rule_builder().then_output(|| TestOutput::new("only response"));

        // Create an interceptor with the rule
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, false);

        // First call should succeed
        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());

        // Second call should panic because the rules are exhausted
        let _result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
    }

    #[tokio::test]
    async fn test_rule_mode_match_any() {
        // Create two rules with different matchers
        let rule1 = create_rule_builder()
            .match_requests(|input| input.bucket == "bucket1")
            .then_output(|| TestOutput::new("response1"));

        let rule2 = create_rule_builder()
            .match_requests(|input| input.bucket == "bucket2")
            .then_output(|| TestOutput::new("response2"));

        // Create an interceptor with both rules in MatchAny mode
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::MatchAny)
            .with_rule(&rule1)
            .with_rule(&rule2);

        let operation = create_test_operation(interceptor, false);

        // Call with bucket1 should match rule1
        let result1 = operation
            .invoke(TestInput::new("bucket1", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("response1"));

        // Call with bucket2 should match rule2
        let result2 = operation
            .invoke(TestInput::new("bucket2", "test-key"))
            .await;
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("response2"));

        // Verify the rules were used the expected number of times
        assert_eq!(rule1.num_calls(), 1);
        assert_eq!(rule2.num_calls(), 1);
    }

    #[tokio::test]
    async fn test_mixed_response_types() {
        // Create a rule with all three types of responses
        let rule = create_rule_builder()
            .sequence()
            .output(|| TestOutput::new("first output"))
            .error(|| TestError::new("expected error"))
            .http_response(|| {
                HttpResponse::new(
                    StatusCode::try_from(200).unwrap(),
                    SdkBody::from("http response"),
                )
            })
            .build();

        // Create an interceptor with the rule
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = create_test_operation(interceptor, false);

        // First call should return the modeled output
        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("first output"));

        // Second call should return the modeled error
        let result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result2.is_err());
        let sdk_err = result2.unwrap_err();
        let err = sdk_err.as_service_error().expect("expected service error");
        assert_eq!(err.to_string(), "expected error");

        // Third call should return the HTTP response (the test deserializer turns its
        // 200 body into a TestOutput)
        let result3 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result3.is_ok());
        assert_eq!(result3.unwrap(), TestOutput::new("http response"));

        // Verify the rule was used the expected number of times
        assert_eq!(rule.num_calls(), 3);
    }

    #[tokio::test]
    async fn test_exhausted_sequence() {
        // Create a rule with a sequence that will be exhausted
        let rule = create_rule_builder()
            .sequence()
            .output(|| TestOutput::new("response 1"))
            .output(|| TestOutput::new("response 2"))
            .build();

        // Create another rule to use after the first one is exhausted
        let fallback_rule =
            create_rule_builder().then_output(|| TestOutput::new("fallback response"));

        // Create an interceptor with both rules
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule)
            .with_rule(&fallback_rule);

        let operation = create_test_operation(interceptor, false);

        // First two calls should use the first rule
        let result1 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), TestOutput::new("response 1"));

        let result2 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("response 2"));

        // Third call should use the fallback rule
        let result3 = operation
            .invoke(TestInput::new("test-bucket", "test-key"))
            .await;
        assert!(result3.is_ok());
        assert_eq!(result3.unwrap(), TestOutput::new("fallback response"));

        // Verify the rules were used the expected number of times
        assert_eq!(rule.num_calls(), 2);
        assert_eq!(fallback_rule.num_calls(), 1);
    }

    // NOTE(review): `#[tokio::test]` uses a current-thread runtime by default, so the spawned
    // tasks below may run one after another rather than truly interleaving. Confirm this test
    // actually exercises concurrent access to the interceptor (e.g. with `flavor = "multi_thread"`).
    #[tokio::test]
    async fn test_concurrent_usage() {
        use std::sync::Arc;
        use tokio::task;

        // Create a rule with multiple responses
        let rule = Arc::new(
            create_rule_builder()
                .sequence()
                .output(|| TestOutput::new("response 1"))
                .output(|| TestOutput::new("response 2"))
                .output(|| TestOutput::new("response 3"))
                .build(),
        );

        // Create an interceptor with the rule
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule);

        let operation = Arc::new(create_test_operation(interceptor, false));

        // Spawn multiple tasks that use the operation concurrently
        let mut handles = vec![];
        for i in 0..3 {
            let op = operation.clone();
            let handle = task::spawn(async move {
                let result = op
                    .invoke(TestInput::new(&format!("bucket-{}", i), "test-key"))
                    .await;
                result.unwrap()
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        let mut results = vec![];
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        // Sort the results to make the test deterministic
        results.sort_by(|a, b| a.content.cmp(&b.content));

        // Verify we got all three responses
        assert_eq!(results.len(), 3);
        assert_eq!(results[0], TestOutput::new("response 1"));
        assert_eq!(results[1], TestOutput::new("response 2"));
        assert_eq!(results[2], TestOutput::new("response 3"));

        // Verify the rule was used the expected number of times
        assert_eq!(rule.num_calls(), 3);
    }

    #[tokio::test]
    async fn test_sequential_rule_removal() {
        // Create a rule that matches only when key != "correct-key"
        let rule1 = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key != "correct-key")
            .then_http_response(|| {
                HttpResponse::new(
                    StatusCode::try_from(404).unwrap(),
                    SdkBody::from("not found"),
                )
            });

        // Create a rule that matches only when key == "correct-key"
        let rule2 = create_rule_builder()
            .match_requests(|input| input.bucket == "test-bucket" && input.key == "correct-key")
            .then_output(|| TestOutput::new("success"));

        // Create an interceptor with both rules in Sequential mode
        let interceptor = MockResponseInterceptor::new()
            .rule_mode(RuleMode::Sequential)
            .with_rule(&rule1)
            .with_rule(&rule2);

        let operation = create_test_operation(interceptor, true);

        // First call with key="foo" should match rule1
        let result1 = operation.invoke(TestInput::new("test-bucket", "foo")).await;
        assert!(result1.is_err());
        assert_eq!(rule1.num_calls(), 1);

        // Second call with key="correct-key" should match rule2
        // But this will fail if rule1 is not removed after being used
        let result2 = operation
            .invoke(TestInput::new("test-bucket", "correct-key"))
            .await;

        // This should succeed: rule1 does not match this input, but it was used up by the
        // first call, so it should be removed as exhausted instead of triggering the
        // sequential-mismatch panic
        assert!(result2.is_ok());
        assert_eq!(result2.unwrap(), TestOutput::new("success"));
        assert_eq!(rule2.num_calls(), 1);
    }
}