aws_smithy_mocks/
rule.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24pub(crate) enum MockResponse<O, E> {
25    /// A successful modeled response.
26    Output(O),
27    /// A modeled error.
28    Error(E),
29    /// An HTTP response.
30    Http(HttpResponse),
31}
32
33/// A function that matches requests.
34type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
37/// A rule for matching requests and providing mock responses.
38///
39/// Rules are created using the `mock!` macro or the `RuleBuilder`.
40///
41#[derive(Clone)]
42pub struct Rule {
43    /// Function that determines if this rule matches a request.
44    pub(crate) matcher: MatchFn,
45
46    /// Handler function that generates responses.
47    pub(crate) response_handler: ServeFn,
48
49    /// Number of times this rule has been called.
50    pub(crate) call_count: Arc<AtomicUsize>,
51
52    /// Maximum number of responses this rule will provide.
53    pub(crate) max_responses: usize,
54}
55
56impl fmt::Debug for Rule {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        write!(f, "Rule")
59    }
60}
61
62impl Rule {
63    /// Creates a new rule with the given matcher, response handler, and max responses.
64    pub(crate) fn new<O, E>(
65        matcher: MatchFn,
66        response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
67        max_responses: usize,
68    ) -> Self
69    where
70        O: fmt::Debug + Send + Sync + 'static,
71        E: fmt::Debug + Send + Sync + std::error::Error + 'static,
72    {
73        Rule {
74            matcher,
75            response_handler: Arc::new(move |idx: usize| {
76                if idx < max_responses {
77                    response_handler(idx).map(|resp| match resp {
78                        MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
79                        MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
80                        MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
81                    })
82                } else {
83                    None
84                }
85            }),
86            call_count: Arc::new(AtomicUsize::new(0)),
87            max_responses,
88        }
89    }
90
91    /// Gets the next response.
92    pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
93        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
94        (self.response_handler)(idx)
95    }
96
97    /// Returns the number of times this rule has been called.
98    pub fn num_calls(&self) -> usize {
99        self.call_count.load(Ordering::SeqCst)
100    }
101
102    /// Checks if this rule is exhausted (has provided all its responses).
103    pub fn is_exhausted(&self) -> bool {
104        self.num_calls() >= self.max_responses
105    }
106}
107
/// `RuleMode` describes how rules will be interpreted.
///
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will be
///   removed from the list of rules **once it is exhausted**.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will
///   remain unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule will be applied, and
    /// that rule will be removed from the list of rules **once it is exhausted**. Each rule can
    /// have multiple responses, and all responses in a rule will be consumed before moving to
    /// the next rule.
    Sequential,
    /// The first matching rule will be applied, and the rules will remain unchanged.
    MatchAny,
}
121
122/// A builder for creating rules.
123///
124/// This builder provides a fluent API for creating rules with different response types.
125///
126pub struct RuleBuilder<I, O, E> {
127    /// Function that determines if this rule matches a request.
128    pub(crate) input_filter: MatchFn,
129
130    /// Phantom data for the input type.
131    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
132}
133
134impl<I, O, E> RuleBuilder<I, O, E>
135where
136    I: fmt::Debug + Send + Sync + 'static,
137    O: fmt::Debug + Send + Sync + 'static,
138    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
139{
140    /// Creates a new [`RuleBuilder`]
141    #[doc(hidden)]
142    pub fn new() -> Self {
143        RuleBuilder {
144            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
145            _ty: std::marker::PhantomData,
146        }
147    }
148
149    /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
150    #[doc(hidden)]
151    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
152    where
153        F: Future<Output = Result<O, SdkError<E, R>>>,
154    {
155        Self {
156            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
157            _ty: Default::default(),
158        }
159    }
160
161    /// Sets the function that determines if this rule matches a request.
162    pub fn match_requests<F>(mut self, filter: F) -> Self
163    where
164        F: Fn(&I) -> bool + Send + Sync + 'static,
165    {
166        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
167            Some(typed_input) => filter(typed_input),
168            _ => false,
169        });
170        self
171    }
172
173    /// Start building a response sequence
174    ///
175    /// A sequence allows a single rule to generate multiple responses which can
176    /// be used to test retry behavior.
177    ///
178    /// # Examples
179    ///
180    /// With repetition using `times()`:
181    ///
182    /// ```rust,ignore
183    /// let rule = mock!(Client::get_object)
184    ///     .sequence()
185    ///     .http_status(503, None)
186    ///     .times(2)                                        // First two calls return 503
187    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
188    ///     .build();
189    /// ```
190    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
191        ResponseSequenceBuilder::new(self.input_filter)
192    }
193
194    /// Creates a rule that returns a modeled output.
195    pub fn then_output<F>(self, output_fn: F) -> Rule
196    where
197        F: Fn() -> O + Send + Sync + 'static,
198    {
199        self.sequence().output(output_fn).build()
200    }
201
202    /// Creates a rule that returns a modeled error.
203    pub fn then_error<F>(self, error_fn: F) -> Rule
204    where
205        F: Fn() -> E + Send + Sync + 'static,
206    {
207        self.sequence().error(error_fn).build()
208    }
209
210    /// Creates a rule that returns an HTTP response.
211    pub fn then_http_response<F>(self, response_fn: F) -> Rule
212    where
213        F: Fn() -> HttpResponse + Send + Sync + 'static,
214    {
215        self.sequence().http_response(response_fn).build()
216    }
217}
218
219type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
220
221/// A builder for creating response sequences
222pub struct ResponseSequenceBuilder<I, O, E> {
223    /// The response generators in the sequence
224    generators: Vec<SequenceGeneratorFn<O, E>>,
225
226    /// Function that determines if this rule matches a request
227    input_filter: MatchFn,
228
229    /// Marker for the input, output, and error types
230    _marker: std::marker::PhantomData<I>,
231}
232
233impl<I, O, E> ResponseSequenceBuilder<I, O, E>
234where
235    I: fmt::Debug + Send + Sync + 'static,
236    O: fmt::Debug + Send + Sync + 'static,
237    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
238{
239    /// Create a new response sequence builder
240    pub(crate) fn new(input_filter: MatchFn) -> Self {
241        Self {
242            generators: Vec::new(),
243            input_filter,
244            _marker: std::marker::PhantomData,
245        }
246    }
247
248    /// Add a modeled output response to the sequence
249    ///
250    /// # Examples
251    ///
252    /// ```rust,ignore
253    /// let rule = mock!(Client::get_object)
254    ///     .sequence()
255    ///     .output(|| GetObjectOutput::builder().build())
256    ///     .build();
257    /// ```
258    pub fn output<F>(mut self, output_fn: F) -> Self
259    where
260        F: Fn() -> O + Send + Sync + 'static,
261    {
262        let generator = Arc::new(move || MockResponse::Output(output_fn()));
263        self.generators.push(generator);
264        self
265    }
266
267    /// Add a modeled error response to the sequence
268    pub fn error<F>(mut self, error_fn: F) -> Self
269    where
270        F: Fn() -> E + Send + Sync + 'static,
271    {
272        let generator = Arc::new(move || MockResponse::Error(error_fn()));
273        self.generators.push(generator);
274        self
275    }
276
277    /// Add an HTTP status code response to the sequence
278    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
279        let status_code = StatusCode::try_from(status).unwrap();
280
281        let generator: SequenceGeneratorFn<O, E> = match body {
282            Some(body) => Arc::new(move || {
283                MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
284            }),
285            None => Arc::new(move || {
286                MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
287            }),
288        };
289
290        self.generators.push(generator);
291        self
292    }
293
294    /// Add an HTTP response to the sequence
295    pub fn http_response<F>(mut self, response_fn: F) -> Self
296    where
297        F: Fn() -> HttpResponse + Send + Sync + 'static,
298    {
299        let generator = Arc::new(move || MockResponse::Http(response_fn()));
300        self.generators.push(generator);
301        self
302    }
303
304    /// Repeat the last added response multiple times (total count)
305    ///
306    /// NOTE: `times(1)` has no effect and `times(0)` will panic
307    pub fn times(mut self, count: usize) -> Self {
308        match count {
309            0 => panic!("repeat count must be greater than zero"),
310            1 => {
311                return self;
312            }
313            _ => {}
314        }
315
316        if let Some(last_generator) = self.generators.last().cloned() {
317            // Add count-1 more copies (we already have one)
318            for _ in 1..count {
319                self.generators.push(last_generator.clone());
320            }
321        }
322        self
323    }
324
325    /// Build the rule with this response sequence
326    pub fn build(self) -> Rule {
327        let generators = self.generators;
328        let count = generators.len();
329
330        Rule::new(
331            self.input_filter,
332            Arc::new(move |idx| {
333                if idx < count {
334                    Some(generators[idx]())
335                } else {
336                    None
337                }
338            }),
339            count,
340        )
341    }
342}