aws_smithy_mocks/
rule.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24pub(crate) enum MockResponse<O, E> {
25    /// A successful modeled response.
26    Output(O),
27    /// A modeled error.
28    Error(E),
29    /// An HTTP response.
30    Http(HttpResponse),
31}
32
33/// A function that matches requests.
34type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
37/// A rule for matching requests and providing mock responses.
38///
39/// Rules are created using the `mock!` macro or the `RuleBuilder`.
40///
41#[derive(Clone)]
42pub struct Rule {
43    /// Function that determines if this rule matches a request.
44    pub(crate) matcher: MatchFn,
45
46    /// Handler function that generates responses.
47    response_handler: ServeFn,
48
49    /// Number of times this rule has been called.
50    call_count: Arc<AtomicUsize>,
51
52    /// Maximum number of responses this rule will provide.
53    pub(crate) max_responses: usize,
54
55    /// Flag indicating this is a "simple" rule which changes how it is interpreted
56    /// depending on the RuleMode.
57    ///
58    /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
59    is_simple: bool,
60}
61
62impl fmt::Debug for Rule {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        write!(f, "Rule")
65    }
66}
67
68impl Rule {
69    /// Creates a new rule with the given matcher, response handler, and max responses.
70    #[allow(clippy::type_complexity)]
71    pub(crate) fn new<O, E>(
72        matcher: MatchFn,
73        response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
74        max_responses: usize,
75        is_simple: bool,
76    ) -> Self
77    where
78        O: fmt::Debug + Send + Sync + 'static,
79        E: fmt::Debug + Send + Sync + std::error::Error + 'static,
80    {
81        Rule {
82            matcher,
83            response_handler: Arc::new(move |idx: usize, input: &Input| {
84                if idx < max_responses {
85                    response_handler(idx, input).map(|resp| match resp {
86                        MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
87                        MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
88                        MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
89                    })
90                } else {
91                    None
92                }
93            }),
94            call_count: Arc::new(AtomicUsize::new(0)),
95            max_responses,
96            is_simple,
97        }
98    }
99
100    /// Test if this is a "simple" rule (non-sequenced)
101    pub(crate) fn is_simple(&self) -> bool {
102        self.is_simple
103    }
104
105    /// Gets the next response.
106    pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
107        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
108        (self.response_handler)(idx, input)
109    }
110
111    /// Returns the number of times this rule has been called.
112    pub fn num_calls(&self) -> usize {
113        self.call_count.load(Ordering::SeqCst)
114    }
115
116    /// Checks if this rule is exhausted (has provided all its responses).
117    pub fn is_exhausted(&self) -> bool {
118        self.num_calls() >= self.max_responses
119    }
120}
121
/// RuleMode describes how rules will be interpreted.
///
/// - In [`RuleMode::Sequential`], rules are consulted in the order they were added. The first
///   matching rule is applied, and that rule is removed from the list of rules **once it is
///   exhausted**.
/// - In [`RuleMode::MatchAny`], the first matching rule is applied, and the rules remain
///   unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule is applied, and that
    /// rule is removed from the list of rules **once it is exhausted**. Each rule can have
    /// multiple responses, and all responses in a rule are consumed before moving on to the
    /// next rule.
    Sequential,
    /// The first matching rule is applied, and the rules remain unchanged.
    MatchAny,
}
135
136/// A builder for creating rules.
137///
138/// This builder provides a fluent API for creating rules with different response types.
139///
140pub struct RuleBuilder<I, O, E> {
141    /// Function that determines if this rule matches a request.
142    pub(crate) input_filter: MatchFn,
143
144    /// Phantom data for the input type.
145    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
146}
147
148impl<I, O, E> RuleBuilder<I, O, E>
149where
150    I: fmt::Debug + Send + Sync + 'static,
151    O: fmt::Debug + Send + Sync + 'static,
152    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
153{
154    /// Creates a new [`RuleBuilder`]
155    #[doc(hidden)]
156    pub fn new() -> Self {
157        RuleBuilder {
158            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
159            _ty: std::marker::PhantomData,
160        }
161    }
162
163    /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
164    #[doc(hidden)]
165    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
166    where
167        F: Future<Output = Result<O, SdkError<E, R>>>,
168    {
169        Self {
170            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
171            _ty: Default::default(),
172        }
173    }
174
175    /// Sets the function that determines if this rule matches a request.
176    pub fn match_requests<F>(mut self, filter: F) -> Self
177    where
178        F: Fn(&I) -> bool + Send + Sync + 'static,
179    {
180        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
181            Some(typed_input) => filter(typed_input),
182            _ => false,
183        });
184        self
185    }
186
187    /// Start building a response sequence
188    ///
189    /// A sequence allows a single rule to generate multiple responses which can
190    /// be used to test retry behavior.
191    ///
192    /// # Examples
193    ///
194    /// With repetition using `times()`:
195    ///
196    /// ```rust,ignore
197    /// let rule = mock!(Client::get_object)
198    ///     .sequence()
199    ///     .http_status(503, None)
200    ///     .times(2)                                        // First two calls return 503
201    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
202    ///     .build();
203    /// ```
204    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
205        ResponseSequenceBuilder::new(self.input_filter)
206    }
207
208    /// Creates a rule that returns a modeled output.
209    pub fn then_output<F>(self, output_fn: F) -> Rule
210    where
211        F: Fn() -> O + Send + Sync + 'static,
212    {
213        self.sequence().output(output_fn).build_simple()
214    }
215
216    /// Creates a rule that returns a modeled error.
217    pub fn then_error<F>(self, error_fn: F) -> Rule
218    where
219        F: Fn() -> E + Send + Sync + 'static,
220    {
221        self.sequence().error(error_fn).build_simple()
222    }
223
224    /// Creates a rule that returns an HTTP response.
225    pub fn then_http_response<F>(self, response_fn: F) -> Rule
226    where
227        F: Fn() -> HttpResponse + Send + Sync + 'static,
228    {
229        self.sequence().http_response(response_fn).build_simple()
230    }
231
232    /// Creates a rule that computes an output based on the input.
233    ///
234    /// This allows generating responses based on the input request.
235    ///
236    /// # Examples
237    ///
238    /// ```rust,ignore
239    /// let rule = mock!(Client::get_object)
240    ///     .compute_output(|req| {
241    ///         GetObjectOutput::builder()
242    ///             .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
243    ///             .build()
244    ///     })
245    ///     .build();
246    /// ```
247    pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
248    where
249        F: Fn(&I) -> O + Send + Sync + 'static,
250    {
251        self.sequence().compute_output(compute_fn).build_simple()
252    }
253}
254
255type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
256
257/// A builder for creating response sequences
258pub struct ResponseSequenceBuilder<I, O, E> {
259    /// The response generators in the sequence
260    generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
261
262    /// Function that determines if this rule matches a request
263    input_filter: MatchFn,
264
265    /// flag indicating this is a "simple" rule
266    is_simple: bool,
267
268    /// Marker for the input, output, and error types
269    _marker: std::marker::PhantomData<I>,
270}
271
272/// Final sequence builder state  - can only `build()`
273pub struct FinalizedResponseSequenceBuilder<I, O, E> {
274    inner: ResponseSequenceBuilder<I, O, E>,
275}
276
277impl<I, O, E> ResponseSequenceBuilder<I, O, E>
278where
279    I: fmt::Debug + Send + Sync + 'static,
280    O: fmt::Debug + Send + Sync + 'static,
281    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
282{
283    /// Create a new response sequence builder
284    pub(crate) fn new(input_filter: MatchFn) -> Self {
285        Self {
286            generators: Vec::new(),
287            input_filter,
288            is_simple: false,
289            _marker: std::marker::PhantomData,
290        }
291    }
292
293    /// Add a modeled output response to the sequence
294    ///
295    /// # Examples
296    ///
297    /// ```rust,ignore
298    /// let rule = mock!(Client::get_object)
299    ///     .sequence()
300    ///     .output(|| GetObjectOutput::builder().build())
301    ///     .build();
302    /// ```
303    pub fn output<F>(mut self, output_fn: F) -> Self
304    where
305        F: Fn() -> O + Send + Sync + 'static,
306    {
307        let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
308        self.generators.push((generator, 1));
309        self
310    }
311
312    /// Add a modeled error response to the sequence
313    pub fn error<F>(mut self, error_fn: F) -> Self
314    where
315        F: Fn() -> E + Send + Sync + 'static,
316    {
317        let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
318        self.generators.push((generator, 1));
319        self
320    }
321
322    /// Add an HTTP status code response to the sequence
323    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
324        let status_code = StatusCode::try_from(status).unwrap();
325
326        let generator: SequenceGeneratorFn<O, E> = match body {
327            Some(body) => Arc::new(move |_input: &Input| {
328                MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
329            }),
330            None => Arc::new(move |_input: &Input| {
331                MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
332            }),
333        };
334
335        self.generators.push((generator, 1));
336        self
337    }
338
339    /// Add an HTTP response to the sequence
340    pub fn http_response<F>(mut self, response_fn: F) -> Self
341    where
342        F: Fn() -> HttpResponse + Send + Sync + 'static,
343    {
344        let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
345        self.generators.push((generator, 1));
346        self
347    }
348
349    /// Add a computed output response to the sequence.  Note that this is not `pub`
350    /// because creating computed output rules off of sequenced rules doesn't work,
351    /// as we can't preserve the input across retries.  So we only expose `compute_output`
352    /// on unsequenced rules above.
353    fn compute_output<F>(mut self, compute_fn: F) -> Self
354    where
355        F: Fn(&I) -> O + Send + Sync + 'static,
356    {
357        let generator = Arc::new(move |input: &Input| {
358            if let Some(typed_input) = input.downcast_ref::<I>() {
359                MockResponse::Output(compute_fn(typed_input))
360            } else {
361                panic!("Input type mismatch in compute_output")
362            }
363        });
364        self.generators.push((generator, 1));
365        self
366    }
367
368    /// Repeat the last added response multiple times.
369    ///
370    /// This method sets the number of times the last response in the sequence will be used.
371    /// For example, if you add a response and then call `times(3)`, that response will be
372    /// returned for the next 3 calls to the rule.
373    ///
374    /// # Examples
375    ///
376    /// ```rust,ignore
377    /// // Create a rule that returns 503 twice, then succeeds
378    /// let rule = mock!(Client::get_object)
379    ///     .sequence()
380    ///     .http_status(503, None)
381    ///     .times(2)                                        // First two calls return 503
382    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
383    ///     .build();
384    /// ```
385    ///
386    /// # Panics
387    ///
388    /// Panics if:
389    /// - Called with a count of 0
390    /// - Called before adding any responses to the sequence
391    pub fn times(mut self, count: usize) -> Self {
392        if self.generators.is_empty() {
393            panic!("times(n) called before adding a response to the sequence");
394        }
395        match count {
396            0 => panic!("repeat count must be greater than zero"),
397            1 => {
398                return self;
399            }
400            _ => {}
401        }
402
403        // update the repeat count of the last generator
404        if let Some(last_generator) = self.generators.last_mut() {
405            last_generator.1 = count;
406        }
407        self
408    }
409    /// Make the last response in the sequence repeat indefinitely.
410    ///
411    /// This method causes the last response added to the sequence to be repeated
412    /// forever, making the rule never exhaust. After calling `repeatedly()`,
413    /// no more responses can be added to the sequence.
414    ///
415    /// # Examples
416    ///
417    /// ```rust,ignore
418    /// // Create a rule that returns an error once, then succeeds forever
419    /// let rule = mock!(Client::get_object)
420    ///     .sequence()
421    ///     .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
422    ///     .output(|| GetObjectOutput::builder().build())
423    ///     .repeatedly()
424    ///     .build();
425    ///
426    /// // First call will return NoSuchKey error
427    /// // All subsequent calls will return success
428    /// ```
429    ///
430    /// # Panics
431    ///
432    /// Panics if called before adding any responses to the sequence.
433    pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
434        if self.generators.is_empty() {
435            panic!("repeatedly() called before adding a response to the sequence");
436        }
437        let inner = self.times(usize::MAX);
438        FinalizedResponseSequenceBuilder { inner }
439    }
440
441    /// Build this a "simple" rule (internal detail)
442    pub(crate) fn build_simple(mut self) -> Rule {
443        self.is_simple = true;
444        self.repeatedly().build()
445    }
446
447    /// Build the rule with this response sequence
448    pub fn build(self) -> Rule {
449        let generators = self.generators;
450        let is_simple = self.is_simple;
451
452        // calculate total responses (sum of all repetitions)
453        let total_responses: usize = generators
454            .iter()
455            .map(|(_, count)| *count)
456            .fold(0, |acc, count| acc.saturating_add(count));
457
458        Rule::new(
459            self.input_filter,
460            Arc::new(move |idx, input| {
461                // find which generator to use
462                let mut current_idx = idx;
463                for (generator, repeat_count) in &generators {
464                    if current_idx < *repeat_count {
465                        return Some(generator(input));
466                    }
467                    current_idx -= repeat_count;
468                }
469                None
470            }),
471            total_responses,
472            is_simple,
473        )
474    }
475}
476
477impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
478where
479    I: fmt::Debug + Send + Sync + 'static,
480    O: fmt::Debug + Send + Sync + 'static,
481    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
482{
483    /// Build the rule with this response sequence
484    pub fn build(self) -> Rule {
485        self.inner.build()
486    }
487}