aws_smithy_mocks/
rule.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24pub(crate) enum MockResponse<O, E> {
25    /// A successful modeled response.
26    Output(O),
27    /// A modeled error.
28    Error(E),
29    /// An HTTP response.
30    Http(HttpResponse),
31}
32
33/// A function that matches requests.
34type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
37/// A rule for matching requests and providing mock responses.
38///
39/// Rules are created using the `mock!` macro or the `RuleBuilder`.
40///
41#[derive(Clone)]
42pub struct Rule {
43    /// Function that determines if this rule matches a request.
44    pub(crate) matcher: MatchFn,
45
46    /// Handler function that generates responses.
47    response_handler: ServeFn,
48
49    /// Number of times this rule has been called.
50    call_count: Arc<AtomicUsize>,
51
52    /// Maximum number of responses this rule will provide.
53    pub(crate) max_responses: usize,
54
55    /// Flag indicating this is a "simple" rule which changes how it is interpreted
56    /// depending on the RuleMode.
57    ///
58    /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
59    is_simple: bool,
60}
61
62impl fmt::Debug for Rule {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        write!(f, "Rule")
65    }
66}
67
68impl Rule {
69    /// Creates a new rule with the given matcher, response handler, and max responses.
70    pub(crate) fn new<O, E>(
71        matcher: MatchFn,
72        response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
73        max_responses: usize,
74        is_simple: bool,
75    ) -> Self
76    where
77        O: fmt::Debug + Send + Sync + 'static,
78        E: fmt::Debug + Send + Sync + std::error::Error + 'static,
79    {
80        Rule {
81            matcher,
82            response_handler: Arc::new(move |idx: usize| {
83                if idx < max_responses {
84                    response_handler(idx).map(|resp| match resp {
85                        MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
86                        MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
87                        MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
88                    })
89                } else {
90                    None
91                }
92            }),
93            call_count: Arc::new(AtomicUsize::new(0)),
94            max_responses,
95            is_simple,
96        }
97    }
98
99    /// Test if this is a "simple" rule (non-sequenced)
100    pub(crate) fn is_simple(&self) -> bool {
101        self.is_simple
102    }
103
104    /// Gets the next response.
105    pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
106        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
107        (self.response_handler)(idx)
108    }
109
110    /// Returns the number of times this rule has been called.
111    pub fn num_calls(&self) -> usize {
112        self.call_count.load(Ordering::SeqCst)
113    }
114
115    /// Checks if this rule is exhausted (has provided all its responses).
116    pub fn is_exhausted(&self) -> bool {
117        self.num_calls() >= self.max_responses
118    }
119}
120
/// RuleMode describes how rules will be interpreted.
///
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will be
///   removed from the list of rules **once it is exhausted**.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will
///   remain unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule will be applied, and
    /// that rule will be removed from the list of rules **once it is exhausted**. Each rule can
    /// have multiple responses, and all responses in a rule will be consumed before moving to
    /// the next rule.
    Sequential,
    /// The first matching rule will be applied, and the rules will remain unchanged.
    MatchAny,
}
134
135/// A builder for creating rules.
136///
137/// This builder provides a fluent API for creating rules with different response types.
138///
139pub struct RuleBuilder<I, O, E> {
140    /// Function that determines if this rule matches a request.
141    pub(crate) input_filter: MatchFn,
142
143    /// Phantom data for the input type.
144    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
145}
146
147impl<I, O, E> RuleBuilder<I, O, E>
148where
149    I: fmt::Debug + Send + Sync + 'static,
150    O: fmt::Debug + Send + Sync + 'static,
151    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
152{
153    /// Creates a new [`RuleBuilder`]
154    #[doc(hidden)]
155    pub fn new() -> Self {
156        RuleBuilder {
157            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
158            _ty: std::marker::PhantomData,
159        }
160    }
161
162    /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
163    #[doc(hidden)]
164    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
165    where
166        F: Future<Output = Result<O, SdkError<E, R>>>,
167    {
168        Self {
169            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
170            _ty: Default::default(),
171        }
172    }
173
174    /// Sets the function that determines if this rule matches a request.
175    pub fn match_requests<F>(mut self, filter: F) -> Self
176    where
177        F: Fn(&I) -> bool + Send + Sync + 'static,
178    {
179        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
180            Some(typed_input) => filter(typed_input),
181            _ => false,
182        });
183        self
184    }
185
186    /// Start building a response sequence
187    ///
188    /// A sequence allows a single rule to generate multiple responses which can
189    /// be used to test retry behavior.
190    ///
191    /// # Examples
192    ///
193    /// With repetition using `times()`:
194    ///
195    /// ```rust,ignore
196    /// let rule = mock!(Client::get_object)
197    ///     .sequence()
198    ///     .http_status(503, None)
199    ///     .times(2)                                        // First two calls return 503
200    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
201    ///     .build();
202    /// ```
203    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
204        ResponseSequenceBuilder::new(self.input_filter)
205    }
206
207    /// Creates a rule that returns a modeled output.
208    pub fn then_output<F>(self, output_fn: F) -> Rule
209    where
210        F: Fn() -> O + Send + Sync + 'static,
211    {
212        self.sequence().output(output_fn).build_simple()
213    }
214
215    /// Creates a rule that returns a modeled error.
216    pub fn then_error<F>(self, error_fn: F) -> Rule
217    where
218        F: Fn() -> E + Send + Sync + 'static,
219    {
220        self.sequence().error(error_fn).build_simple()
221    }
222
223    /// Creates a rule that returns an HTTP response.
224    pub fn then_http_response<F>(self, response_fn: F) -> Rule
225    where
226        F: Fn() -> HttpResponse + Send + Sync + 'static,
227    {
228        self.sequence().http_response(response_fn).build_simple()
229    }
230}
231
232type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
233
234/// A builder for creating response sequences
235pub struct ResponseSequenceBuilder<I, O, E> {
236    /// The response generators in the sequence
237    generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
238
239    /// Function that determines if this rule matches a request
240    input_filter: MatchFn,
241
242    /// flag indicating this is a "simple" rule
243    is_simple: bool,
244
245    /// Marker for the input, output, and error types
246    _marker: std::marker::PhantomData<I>,
247}
248
249/// Final sequence builder state  - can only `build()`
250pub struct FinalizedResponseSequenceBuilder<I, O, E> {
251    inner: ResponseSequenceBuilder<I, O, E>,
252}
253
254impl<I, O, E> ResponseSequenceBuilder<I, O, E>
255where
256    I: fmt::Debug + Send + Sync + 'static,
257    O: fmt::Debug + Send + Sync + 'static,
258    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
259{
260    /// Create a new response sequence builder
261    pub(crate) fn new(input_filter: MatchFn) -> Self {
262        Self {
263            generators: Vec::new(),
264            input_filter,
265            is_simple: false,
266            _marker: std::marker::PhantomData,
267        }
268    }
269
270    /// Add a modeled output response to the sequence
271    ///
272    /// # Examples
273    ///
274    /// ```rust,ignore
275    /// let rule = mock!(Client::get_object)
276    ///     .sequence()
277    ///     .output(|| GetObjectOutput::builder().build())
278    ///     .build();
279    /// ```
280    pub fn output<F>(mut self, output_fn: F) -> Self
281    where
282        F: Fn() -> O + Send + Sync + 'static,
283    {
284        let generator = Arc::new(move || MockResponse::Output(output_fn()));
285        self.generators.push((generator, 1));
286        self
287    }
288
289    /// Add a modeled error response to the sequence
290    pub fn error<F>(mut self, error_fn: F) -> Self
291    where
292        F: Fn() -> E + Send + Sync + 'static,
293    {
294        let generator = Arc::new(move || MockResponse::Error(error_fn()));
295        self.generators.push((generator, 1));
296        self
297    }
298
299    /// Add an HTTP status code response to the sequence
300    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
301        let status_code = StatusCode::try_from(status).unwrap();
302
303        let generator: SequenceGeneratorFn<O, E> = match body {
304            Some(body) => Arc::new(move || {
305                MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
306            }),
307            None => Arc::new(move || {
308                MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
309            }),
310        };
311
312        self.generators.push((generator, 1));
313        self
314    }
315
316    /// Add an HTTP response to the sequence
317    pub fn http_response<F>(mut self, response_fn: F) -> Self
318    where
319        F: Fn() -> HttpResponse + Send + Sync + 'static,
320    {
321        let generator = Arc::new(move || MockResponse::Http(response_fn()));
322        self.generators.push((generator, 1));
323        self
324    }
325
326    /// Repeat the last added response multiple times.
327    ///
328    /// This method sets the number of times the last response in the sequence will be used.
329    /// For example, if you add a response and then call `times(3)`, that response will be
330    /// returned for the next 3 calls to the rule.
331    ///
332    /// # Examples
333    ///
334    /// ```rust,ignore
335    /// // Create a rule that returns 503 twice, then succeeds
336    /// let rule = mock!(Client::get_object)
337    ///     .sequence()
338    ///     .http_status(503, None)
339    ///     .times(2)                                        // First two calls return 503
340    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
341    ///     .build();
342    /// ```
343    ///
344    /// # Panics
345    ///
346    /// Panics if:
347    /// - Called with a count of 0
348    /// - Called before adding any responses to the sequence
349    pub fn times(mut self, count: usize) -> Self {
350        if self.generators.is_empty() {
351            panic!("times(n) called before adding a response to the sequence");
352        }
353        match count {
354            0 => panic!("repeat count must be greater than zero"),
355            1 => {
356                return self;
357            }
358            _ => {}
359        }
360
361        // update the repeat count of the last generator
362        if let Some(last_generator) = self.generators.last_mut() {
363            last_generator.1 = count;
364        }
365        self
366    }
367    /// Make the last response in the sequence repeat indefinitely.
368    ///
369    /// This method causes the last response added to the sequence to be repeated
370    /// forever, making the rule never exhaust. After calling `repeatedly()`,
371    /// no more responses can be added to the sequence.
372    ///
373    /// # Examples
374    ///
375    /// ```rust,ignore
376    /// // Create a rule that returns an error once, then succeeds forever
377    /// let rule = mock!(Client::get_object)
378    ///     .sequence()
379    ///     .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
380    ///     .output(|| GetObjectOutput::builder().build())
381    ///     .repeatedly()
382    ///     .build();
383    ///
384    /// // First call will return NoSuchKey error
385    /// // All subsequent calls will return success
386    /// ```
387    ///
388    /// # Panics
389    ///
390    /// Panics if called before adding any responses to the sequence.
391    pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
392        if self.generators.is_empty() {
393            panic!("repeatedly() called before adding a response to the sequence");
394        }
395        let inner = self.times(usize::MAX);
396        FinalizedResponseSequenceBuilder { inner }
397    }
398
399    /// Build this a "simple" rule (internal detail)
400    pub(crate) fn build_simple(mut self) -> Rule {
401        self.is_simple = true;
402        self.repeatedly().build()
403    }
404
405    /// Build the rule with this response sequence
406    pub fn build(self) -> Rule {
407        let generators = self.generators;
408        let is_simple = self.is_simple;
409
410        // calculate total responses (sum of all repetitions)
411        let total_responses: usize = generators
412            .iter()
413            .map(|(_, count)| *count)
414            .fold(0, |acc, count| acc.saturating_add(count));
415
416        Rule::new(
417            self.input_filter,
418            Arc::new(move |idx| {
419                // find which generator to use
420                let mut current_idx = idx;
421                for (generator, repeat_count) in &generators {
422                    if current_idx < *repeat_count {
423                        return Some(generator());
424                    }
425                    current_idx -= repeat_count;
426                }
427                None
428            }),
429            total_responses,
430            is_simple,
431        )
432    }
433}
434
435impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
436where
437    I: fmt::Debug + Send + Sync + 'static,
438    O: fmt::Debug + Send + Sync + 'static,
439    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
440{
441    /// Build the rule with this response sequence
442    pub fn build(self) -> Rule {
443        self.inner.build()
444    }
445}