aws_smithy_mocks/
rule.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24#[allow(clippy::large_enum_variant)]
25pub(crate) enum MockResponse<O, E> {
26    /// A successful modeled response.
27    Output(O),
28    /// A modeled error.
29    Error(E),
30    /// An HTTP response.
31    Http(HttpResponse),
32}
33
34/// A function that matches requests.
35type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
36type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
37
38/// A rule for matching requests and providing mock responses.
39///
40/// Rules are created using the `mock!` macro or the `RuleBuilder`.
41///
42#[derive(Clone)]
43pub struct Rule {
44    /// Function that determines if this rule matches a request.
45    pub(crate) matcher: MatchFn,
46
47    /// Handler function that generates responses.
48    response_handler: ServeFn,
49
50    /// Number of times this rule has been called.
51    call_count: Arc<AtomicUsize>,
52
53    /// Maximum number of responses this rule will provide.
54    pub(crate) max_responses: usize,
55
56    /// Flag indicating this is a "simple" rule which changes how it is interpreted
57    /// depending on the RuleMode.
58    ///
59    /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
60    is_simple: bool,
61}
62
63impl fmt::Debug for Rule {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(f, "Rule")
66    }
67}
68
69impl Rule {
70    /// Creates a new rule with the given matcher, response handler, and max responses.
71    #[allow(clippy::type_complexity)]
72    pub(crate) fn new<O, E>(
73        matcher: MatchFn,
74        response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
75        max_responses: usize,
76        is_simple: bool,
77    ) -> Self
78    where
79        O: fmt::Debug + Send + Sync + 'static,
80        E: fmt::Debug + Send + Sync + std::error::Error + 'static,
81    {
82        Rule {
83            matcher,
84            response_handler: Arc::new(move |idx: usize, input: &Input| {
85                if idx < max_responses {
86                    response_handler(idx, input).map(|resp| match resp {
87                        MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
88                        MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
89                        MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
90                    })
91                } else {
92                    None
93                }
94            }),
95            call_count: Arc::new(AtomicUsize::new(0)),
96            max_responses,
97            is_simple,
98        }
99    }
100
101    /// Test if this is a "simple" rule (non-sequenced)
102    pub(crate) fn is_simple(&self) -> bool {
103        self.is_simple
104    }
105
106    /// Gets the next response.
107    pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
108        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
109        (self.response_handler)(idx, input)
110    }
111
112    /// Returns the number of times this rule has been called.
113    pub fn num_calls(&self) -> usize {
114        self.call_count.load(Ordering::SeqCst)
115    }
116
117    /// Checks if this rule is exhausted (has provided all its responses).
118    pub fn is_exhausted(&self) -> bool {
119        self.num_calls() >= self.max_responses
120    }
121}
122
/// `RuleMode` describes how the registered rules are interpreted.
///
/// - In [`RuleMode::MatchAny`], the first matching rule is applied and the list of rules is left
///   unchanged.
/// - In [`RuleMode::Sequential`], the first matching rule is applied, and that rule is removed from
///   the list of rules **once it is exhausted**.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Rules are applied in the order they were added. The first matching rule is applied, and
    /// that rule is removed from the list of rules **once it is exhausted**. Each rule can have
    /// multiple responses, and all responses in a rule are consumed before moving on to the next
    /// rule.
    Sequential,
    /// The first matching rule is applied, and the list of rules remains unchanged.
    MatchAny,
}
136
137/// A builder for creating rules.
138///
139/// This builder provides a fluent API for creating rules with different response types.
140///
141pub struct RuleBuilder<I, O, E> {
142    /// Function that determines if this rule matches a request.
143    pub(crate) input_filter: MatchFn,
144
145    /// Phantom data for the input type.
146    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
147}
148
149impl<I, O, E> RuleBuilder<I, O, E>
150where
151    I: fmt::Debug + Send + Sync + 'static,
152    O: fmt::Debug + Send + Sync + 'static,
153    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
154{
155    /// Creates a new [`RuleBuilder`]
156    #[doc(hidden)]
157    pub fn new() -> Self {
158        RuleBuilder {
159            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
160            _ty: std::marker::PhantomData,
161        }
162    }
163
164    /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
165    #[doc(hidden)]
166    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
167    where
168        F: Future<Output = Result<O, SdkError<E, R>>>,
169    {
170        Self {
171            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
172            _ty: Default::default(),
173        }
174    }
175
176    /// Sets the function that determines if this rule matches a request.
177    pub fn match_requests<F>(mut self, filter: F) -> Self
178    where
179        F: Fn(&I) -> bool + Send + Sync + 'static,
180    {
181        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
182            Some(typed_input) => filter(typed_input),
183            _ => false,
184        });
185        self
186    }
187
188    /// Start building a response sequence
189    ///
190    /// A sequence allows a single rule to generate multiple responses which can
191    /// be used to test retry behavior.
192    ///
193    /// # Examples
194    ///
195    /// With repetition using `times()`:
196    ///
197    /// ```rust,ignore
198    /// let rule = mock!(Client::get_object)
199    ///     .sequence()
200    ///     .http_status(503, None)
201    ///     .times(2)                                        // First two calls return 503
202    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
203    ///     .build();
204    /// ```
205    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
206        ResponseSequenceBuilder::new(self.input_filter)
207    }
208
209    /// Creates a rule that returns a modeled output.
210    pub fn then_output<F>(self, output_fn: F) -> Rule
211    where
212        F: Fn() -> O + Send + Sync + 'static,
213    {
214        self.sequence().output(output_fn).build_simple()
215    }
216
217    /// Creates a rule that returns a modeled error.
218    pub fn then_error<F>(self, error_fn: F) -> Rule
219    where
220        F: Fn() -> E + Send + Sync + 'static,
221    {
222        self.sequence().error(error_fn).build_simple()
223    }
224
225    /// Creates a rule that returns an HTTP response.
226    pub fn then_http_response<F>(self, response_fn: F) -> Rule
227    where
228        F: Fn() -> HttpResponse + Send + Sync + 'static,
229    {
230        self.sequence().http_response(response_fn).build_simple()
231    }
232
233    /// Creates a rule that computes an output based on the input.
234    ///
235    /// This allows generating responses based on the input request.
236    ///
237    /// # Examples
238    ///
239    /// ```rust,ignore
240    /// let rule = mock!(Client::get_object)
241    ///     .compute_output(|req| {
242    ///         GetObjectOutput::builder()
243    ///             .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
244    ///             .build()
245    ///     })
246    ///     .build();
247    /// ```
248    pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
249    where
250        F: Fn(&I) -> O + Send + Sync + 'static,
251    {
252        self.sequence().compute_output(compute_fn).build_simple()
253    }
254}
255
256type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
257
258/// A builder for creating response sequences
259pub struct ResponseSequenceBuilder<I, O, E> {
260    /// The response generators in the sequence
261    generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
262
263    /// Function that determines if this rule matches a request
264    input_filter: MatchFn,
265
266    /// flag indicating this is a "simple" rule
267    is_simple: bool,
268
269    /// Marker for the input, output, and error types
270    _marker: std::marker::PhantomData<I>,
271}
272
273/// Final sequence builder state  - can only `build()`
274pub struct FinalizedResponseSequenceBuilder<I, O, E> {
275    inner: ResponseSequenceBuilder<I, O, E>,
276}
277
278impl<I, O, E> ResponseSequenceBuilder<I, O, E>
279where
280    I: fmt::Debug + Send + Sync + 'static,
281    O: fmt::Debug + Send + Sync + 'static,
282    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
283{
284    /// Create a new response sequence builder
285    pub(crate) fn new(input_filter: MatchFn) -> Self {
286        Self {
287            generators: Vec::new(),
288            input_filter,
289            is_simple: false,
290            _marker: std::marker::PhantomData,
291        }
292    }
293
294    /// Add a modeled output response to the sequence
295    ///
296    /// # Examples
297    ///
298    /// ```rust,ignore
299    /// let rule = mock!(Client::get_object)
300    ///     .sequence()
301    ///     .output(|| GetObjectOutput::builder().build())
302    ///     .build();
303    /// ```
304    pub fn output<F>(mut self, output_fn: F) -> Self
305    where
306        F: Fn() -> O + Send + Sync + 'static,
307    {
308        let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
309        self.generators.push((generator, 1));
310        self
311    }
312
313    /// Add a modeled error response to the sequence
314    pub fn error<F>(mut self, error_fn: F) -> Self
315    where
316        F: Fn() -> E + Send + Sync + 'static,
317    {
318        let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
319        self.generators.push((generator, 1));
320        self
321    }
322
323    /// Add an HTTP status code response to the sequence
324    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
325        let status_code = StatusCode::try_from(status).unwrap();
326
327        let generator: SequenceGeneratorFn<O, E> = match body {
328            Some(body) => Arc::new(move |_input: &Input| {
329                MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
330            }),
331            None => Arc::new(move |_input: &Input| {
332                MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
333            }),
334        };
335
336        self.generators.push((generator, 1));
337        self
338    }
339
340    /// Add an HTTP response to the sequence
341    pub fn http_response<F>(mut self, response_fn: F) -> Self
342    where
343        F: Fn() -> HttpResponse + Send + Sync + 'static,
344    {
345        let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
346        self.generators.push((generator, 1));
347        self
348    }
349
350    /// Add a computed output response to the sequence.  Note that this is not `pub`
351    /// because creating computed output rules off of sequenced rules doesn't work,
352    /// as we can't preserve the input across retries.  So we only expose `compute_output`
353    /// on unsequenced rules above.
354    fn compute_output<F>(mut self, compute_fn: F) -> Self
355    where
356        F: Fn(&I) -> O + Send + Sync + 'static,
357    {
358        let generator = Arc::new(move |input: &Input| {
359            if let Some(typed_input) = input.downcast_ref::<I>() {
360                MockResponse::Output(compute_fn(typed_input))
361            } else {
362                panic!("Input type mismatch in compute_output")
363            }
364        });
365        self.generators.push((generator, 1));
366        self
367    }
368
369    /// Repeat the last added response multiple times.
370    ///
371    /// This method sets the number of times the last response in the sequence will be used.
372    /// For example, if you add a response and then call `times(3)`, that response will be
373    /// returned for the next 3 calls to the rule.
374    ///
375    /// # Examples
376    ///
377    /// ```rust,ignore
378    /// // Create a rule that returns 503 twice, then succeeds
379    /// let rule = mock!(Client::get_object)
380    ///     .sequence()
381    ///     .http_status(503, None)
382    ///     .times(2)                                        // First two calls return 503
383    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
384    ///     .build();
385    /// ```
386    ///
387    /// # Panics
388    ///
389    /// Panics if:
390    /// - Called with a count of 0
391    /// - Called before adding any responses to the sequence
392    pub fn times(mut self, count: usize) -> Self {
393        if self.generators.is_empty() {
394            panic!("times(n) called before adding a response to the sequence");
395        }
396        match count {
397            0 => panic!("repeat count must be greater than zero"),
398            1 => {
399                return self;
400            }
401            _ => {}
402        }
403
404        // update the repeat count of the last generator
405        if let Some(last_generator) = self.generators.last_mut() {
406            last_generator.1 = count;
407        }
408        self
409    }
410    /// Make the last response in the sequence repeat indefinitely.
411    ///
412    /// This method causes the last response added to the sequence to be repeated
413    /// forever, making the rule never exhaust. After calling `repeatedly()`,
414    /// no more responses can be added to the sequence.
415    ///
416    /// # Examples
417    ///
418    /// ```rust,ignore
419    /// // Create a rule that returns an error once, then succeeds forever
420    /// let rule = mock!(Client::get_object)
421    ///     .sequence()
422    ///     .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
423    ///     .output(|| GetObjectOutput::builder().build())
424    ///     .repeatedly()
425    ///     .build();
426    ///
427    /// // First call will return NoSuchKey error
428    /// // All subsequent calls will return success
429    /// ```
430    ///
431    /// # Panics
432    ///
433    /// Panics if called before adding any responses to the sequence.
434    pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
435        if self.generators.is_empty() {
436            panic!("repeatedly() called before adding a response to the sequence");
437        }
438        let inner = self.times(usize::MAX);
439        FinalizedResponseSequenceBuilder { inner }
440    }
441
442    /// Build this a "simple" rule (internal detail)
443    pub(crate) fn build_simple(mut self) -> Rule {
444        self.is_simple = true;
445        self.repeatedly().build()
446    }
447
448    /// Build the rule with this response sequence
449    pub fn build(self) -> Rule {
450        let generators = self.generators;
451        let is_simple = self.is_simple;
452
453        // calculate total responses (sum of all repetitions)
454        let total_responses: usize = generators
455            .iter()
456            .map(|(_, count)| *count)
457            .fold(0, |acc, count| acc.saturating_add(count));
458
459        Rule::new(
460            self.input_filter,
461            Arc::new(move |idx, input| {
462                // find which generator to use
463                let mut current_idx = idx;
464                for (generator, repeat_count) in &generators {
465                    if current_idx < *repeat_count {
466                        return Some(generator(input));
467                    }
468                    current_idx -= repeat_count;
469                }
470                None
471            }),
472            total_responses,
473            is_simple,
474        )
475    }
476}
477
478impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
479where
480    I: fmt::Debug + Send + Sync + 'static,
481    O: fmt::Debug + Send + Sync + 'static,
482    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
483{
484    /// Build the rule with this response sequence
485    pub fn build(self) -> Rule {
486        self.inner.build()
487    }
488}