aws_smithy_mocks/
rule.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
/// A mock response that can be returned by a rule.
///
/// This enum represents the different types of responses that can be returned by a mock rule:
/// - `Output`: A successful modeled response
/// - `Error`: A modeled error
/// - `Http`: An HTTP response
///
/// `O` is the modeled output type and `E` is the modeled error type. Inside a [`Rule`] both are
/// stored in their type-erased form, i.e. as `MockResponse<Output, Error>`.
#[derive(Debug)]
// NOTE(review): presumably `HttpResponse` is considerably larger than typical `O`/`E` payloads and
// boxing it would change the public variant type, so the lint is silenced instead — confirm.
#[allow(clippy::large_enum_variant)]
pub enum MockResponse<O, E> {
    /// A successful modeled response.
    Output(O),
    /// A modeled error.
    Error(E),
    /// An HTTP response.
    Http(HttpResponse),
}
33
/// A shareable predicate that decides whether a rule applies to a given request input.
type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
/// A shareable response handler.
///
/// It receives the zero-based call index (the value of the rule's call counter before it is
/// incremented) together with the request input, and returns `None` when there is no response
/// for that index.
type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
37
/// A rule for matching requests and providing mock responses.
///
/// Rules are created using the `mock!` macro or the `RuleBuilder`.
///
/// Cloning a `Rule` is shallow: the matcher, the response handler, and the call counter are all
/// held behind `Arc`s, so every clone observes and updates the same call count.
#[derive(Clone)]
pub struct Rule {
    /// Function that determines if this rule matches a request.
    pub(crate) matcher: MatchFn,

    /// Handler function that generates responses. It is invoked with the index of the current
    /// call and the request input.
    response_handler: ServeFn,

    /// Number of times this rule has been called. Shared between all clones of the rule.
    call_count: Arc<AtomicUsize>,

    /// Maximum number of responses this rule will provide. Calls whose index is at or beyond
    /// this limit receive no response.
    pub(crate) max_responses: usize,

    /// Flag indicating this is a "simple" rule which changes how it is interpreted
    /// depending on the RuleMode.
    ///
    /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
    is_simple: bool,
}
62
63impl fmt::Debug for Rule {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(f, "Rule")
66    }
67}
68
69impl Rule {
70    /// Creates a new rule with the given matcher, response handler, and max responses.
71    #[allow(clippy::type_complexity)]
72    pub(crate) fn new<O, E>(
73        matcher: MatchFn,
74        response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
75        max_responses: usize,
76        is_simple: bool,
77    ) -> Self
78    where
79        O: fmt::Debug + Send + Sync + 'static,
80        E: fmt::Debug + Send + Sync + std::error::Error + 'static,
81    {
82        Rule {
83            matcher,
84            response_handler: Arc::new(move |idx: usize, input: &Input| {
85                if idx < max_responses {
86                    response_handler(idx, input).map(|resp| match resp {
87                        MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
88                        MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
89                        MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
90                    })
91                } else {
92                    None
93                }
94            }),
95            call_count: Arc::new(AtomicUsize::new(0)),
96            max_responses,
97            is_simple,
98        }
99    }
100
101    /// Test if this is a "simple" rule (non-sequenced)
102    pub(crate) fn is_simple(&self) -> bool {
103        self.is_simple
104    }
105
106    /// Gets the next response.
107    pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
108        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
109        (self.response_handler)(idx, input)
110    }
111
112    /// Returns the number of times this rule has been called.
113    pub fn num_calls(&self) -> usize {
114        self.call_count.load(Ordering::SeqCst)
115    }
116
117    /// Checks if this rule is exhausted (has provided all its responses).
118    pub fn is_exhausted(&self) -> bool {
119        self.num_calls() >= self.max_responses
120    }
121}
122
/// RuleMode describes how rules will be interpreted.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will remain unchanged.
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will be removed from the list of rules **once it is exhausted**.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// The first matching rule will be applied, and that rule will be removed from the list of rules
    /// **once it is exhausted**. Each rule can have multiple responses, and all responses in a rule
    /// will be consumed before moving to the next rule.
    Sequential,
    /// Match rules in the order they were added. The first matching rule will be applied and the
    /// rules will remain unchanged.
    MatchAny,
}
136
/// A builder for creating rules.
///
/// This builder provides a fluent API for creating rules with different response types.
///
/// `I`, `O`, and `E` are the input, output, and error types the rule is built for.
pub struct RuleBuilder<I, O, E> {
    /// Function that determines if this rule matches a request.
    pub(crate) input_filter: MatchFn,

    /// Zero-sized marker tying the builder to its input, output, and error types.
    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
}
148
149impl<I, O, E> RuleBuilder<I, O, E>
150where
151    I: fmt::Debug + Send + Sync + 'static,
152    O: fmt::Debug + Send + Sync + 'static,
153    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
154{
155    /// Creates a new [`RuleBuilder`]
156    #[doc(hidden)]
157    pub fn new() -> Self {
158        RuleBuilder {
159            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
160            _ty: std::marker::PhantomData,
161        }
162    }
163
164    /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
165    #[doc(hidden)]
166    pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
167    where
168        F: Future<Output = Result<O, SdkError<E, R>>>,
169    {
170        Self {
171            input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
172            _ty: Default::default(),
173        }
174    }
175
176    /// Sets the function that determines if this rule matches a request.
177    pub fn match_requests<F>(mut self, filter: F) -> Self
178    where
179        F: Fn(&I) -> bool + Send + Sync + 'static,
180    {
181        self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
182            Some(typed_input) => filter(typed_input),
183            _ => false,
184        });
185        self
186    }
187
188    /// Start building a response sequence
189    ///
190    /// A sequence allows a single rule to generate multiple responses which can
191    /// be used to test retry behavior.
192    ///
193    /// # Examples
194    ///
195    /// With repetition using `times()`:
196    ///
197    /// ```rust,ignore
198    /// let rule = mock!(Client::get_object)
199    ///     .sequence()
200    ///     .http_status(503, None)
201    ///     .times(2)                                        // First two calls return 503
202    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
203    ///     .build();
204    /// ```
205    pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
206        ResponseSequenceBuilder::new(self.input_filter)
207    }
208
209    /// Creates a rule that returns a modeled output.
210    pub fn then_output<F>(self, output_fn: F) -> Rule
211    where
212        F: Fn() -> O + Send + Sync + 'static,
213    {
214        self.sequence().output(output_fn).build_simple()
215    }
216
217    /// Creates a rule that returns a modeled error.
218    pub fn then_error<F>(self, error_fn: F) -> Rule
219    where
220        F: Fn() -> E + Send + Sync + 'static,
221    {
222        self.sequence().error(error_fn).build_simple()
223    }
224
225    /// Creates a rule that returns an HTTP response.
226    pub fn then_http_response<F>(self, response_fn: F) -> Rule
227    where
228        F: Fn() -> HttpResponse + Send + Sync + 'static,
229    {
230        self.sequence().http_response(response_fn).build_simple()
231    }
232
233    /// Creates a rule that computes an output based on the input.
234    ///
235    /// This allows generating responses based on the input request.
236    ///
237    /// # Examples
238    ///
239    /// ```rust,ignore
240    /// let rule = mock!(Client::get_object)
241    ///     .compute_output(|req| {
242    ///         GetObjectOutput::builder()
243    ///             .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
244    ///             .build()
245    ///     })
246    ///     .build();
247    /// ```
248    pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
249    where
250        F: Fn(&I) -> O + Send + Sync + 'static,
251    {
252        self.sequence().compute_output(compute_fn).build_simple()
253    }
254
255    /// Creates a rule that computes an arbitrary response based on the input.
256    ///
257    /// This allows generating any type of response (output, error, or HTTP) based on the input request.
258    /// Unlike `then_compute_output`, this method can return errors or HTTP responses conditionally.
259    ///
260    /// # Examples
261    ///
262    /// ```rust,ignore
263    /// let rule = mock!(Client::get_object)
264    ///     .then_compute_response(|req| {
265    ///         if req.key() == Some("error") {
266    ///             MockResponse::Error(GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
267    ///         } else {
268    ///             MockResponse::Output(GetObjectOutput::builder()
269    ///                 .body(ByteStream::from_static(b"content"))
270    ///                 .build())
271    ///         }
272    ///     })
273    ///     .build();
274    /// ```
275    pub fn then_compute_response<F>(self, compute_fn: F) -> Rule
276    where
277        F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
278    {
279        self.sequence().compute_response(compute_fn).build_simple()
280    }
281}
282
/// A shareable generator that produces one [`MockResponse`] for a given request input.
type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
284
/// A builder for creating response sequences
pub struct ResponseSequenceBuilder<I, O, E> {
    /// The response generators in the sequence, each paired with its repeat count (the number
    /// of consecutive calls it serves)
    generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,

    /// Function that determines if this rule matches a request
    input_filter: MatchFn,

    /// flag indicating this is a "simple" rule
    is_simple: bool,

    /// Marker for the input type; the output and error types already appear in `generators`
    _marker: std::marker::PhantomData<I>,
}
299
/// Final sequence builder state - can only `build()`
///
/// Returned by `ResponseSequenceBuilder::repeatedly`; it exposes no methods for appending
/// further responses.
pub struct FinalizedResponseSequenceBuilder<I, O, E> {
    /// The wrapped builder holding the completed sequence.
    inner: ResponseSequenceBuilder<I, O, E>,
}
304
305impl<I, O, E> ResponseSequenceBuilder<I, O, E>
306where
307    I: fmt::Debug + Send + Sync + 'static,
308    O: fmt::Debug + Send + Sync + 'static,
309    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
310{
311    /// Create a new response sequence builder
312    pub(crate) fn new(input_filter: MatchFn) -> Self {
313        Self {
314            generators: Vec::new(),
315            input_filter,
316            is_simple: false,
317            _marker: std::marker::PhantomData,
318        }
319    }
320
321    /// Add a modeled output response to the sequence
322    ///
323    /// # Examples
324    ///
325    /// ```rust,ignore
326    /// let rule = mock!(Client::get_object)
327    ///     .sequence()
328    ///     .output(|| GetObjectOutput::builder().build())
329    ///     .build();
330    /// ```
331    pub fn output<F>(mut self, output_fn: F) -> Self
332    where
333        F: Fn() -> O + Send + Sync + 'static,
334    {
335        let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
336        self.generators.push((generator, 1));
337        self
338    }
339
340    /// Add a modeled error response to the sequence
341    pub fn error<F>(mut self, error_fn: F) -> Self
342    where
343        F: Fn() -> E + Send + Sync + 'static,
344    {
345        let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
346        self.generators.push((generator, 1));
347        self
348    }
349
350    /// Add an HTTP status code response to the sequence
351    pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
352        let status_code = StatusCode::try_from(status).unwrap();
353
354        let generator: SequenceGeneratorFn<O, E> = match body {
355            Some(body) => Arc::new(move |_input: &Input| {
356                MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
357            }),
358            None => Arc::new(move |_input: &Input| {
359                MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
360            }),
361        };
362
363        self.generators.push((generator, 1));
364        self
365    }
366
367    /// Add an HTTP response to the sequence
368    pub fn http_response<F>(mut self, response_fn: F) -> Self
369    where
370        F: Fn() -> HttpResponse + Send + Sync + 'static,
371    {
372        let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
373        self.generators.push((generator, 1));
374        self
375    }
376
377    /// Add a computed output response to the sequence.  Note that this is not `pub`
378    /// because creating computed output rules off of sequenced rules doesn't work,
379    /// as we can't preserve the input across retries.  So we only expose `compute_output`
380    /// on unsequenced rules above.
381    fn compute_output<F>(mut self, compute_fn: F) -> Self
382    where
383        F: Fn(&I) -> O + Send + Sync + 'static,
384    {
385        let generator = Arc::new(move |input: &Input| {
386            if let Some(typed_input) = input.downcast_ref::<I>() {
387                MockResponse::Output(compute_fn(typed_input))
388            } else {
389                panic!("Input type mismatch in compute_output")
390            }
391        });
392        self.generators.push((generator, 1));
393        self
394    }
395
396    /// Add a computed response to the sequence. Not `pub` for same reason as `compute_output`.
397    fn compute_response<F>(mut self, compute_fn: F) -> Self
398    where
399        F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
400    {
401        let generator = Arc::new(move |input: &Input| {
402            if let Some(typed_input) = input.downcast_ref::<I>() {
403                compute_fn(typed_input)
404            } else {
405                panic!("Input type mismatch in compute_response")
406            }
407        });
408        self.generators.push((generator, 1));
409        self
410    }
411
412    /// Repeat the last added response multiple times.
413    ///
414    /// This method sets the number of times the last response in the sequence will be used.
415    /// For example, if you add a response and then call `times(3)`, that response will be
416    /// returned for the next 3 calls to the rule.
417    ///
418    /// # Examples
419    ///
420    /// ```rust,ignore
421    /// // Create a rule that returns 503 twice, then succeeds
422    /// let rule = mock!(Client::get_object)
423    ///     .sequence()
424    ///     .http_status(503, None)
425    ///     .times(2)                                        // First two calls return 503
426    ///     .output(|| GetObjectOutput::builder().build())   // Third call succeeds
427    ///     .build();
428    /// ```
429    ///
430    /// # Panics
431    ///
432    /// Panics if:
433    /// - Called with a count of 0
434    /// - Called before adding any responses to the sequence
435    pub fn times(mut self, count: usize) -> Self {
436        if self.generators.is_empty() {
437            panic!("times(n) called before adding a response to the sequence");
438        }
439        match count {
440            0 => panic!("repeat count must be greater than zero"),
441            1 => {
442                return self;
443            }
444            _ => {}
445        }
446
447        // update the repeat count of the last generator
448        if let Some(last_generator) = self.generators.last_mut() {
449            last_generator.1 = count;
450        }
451        self
452    }
453    /// Make the last response in the sequence repeat indefinitely.
454    ///
455    /// This method causes the last response added to the sequence to be repeated
456    /// forever, making the rule never exhaust. After calling `repeatedly()`,
457    /// no more responses can be added to the sequence.
458    ///
459    /// # Examples
460    ///
461    /// ```rust,ignore
462    /// // Create a rule that returns an error once, then succeeds forever
463    /// let rule = mock!(Client::get_object)
464    ///     .sequence()
465    ///     .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
466    ///     .output(|| GetObjectOutput::builder().build())
467    ///     .repeatedly()
468    ///     .build();
469    ///
470    /// // First call will return NoSuchKey error
471    /// // All subsequent calls will return success
472    /// ```
473    ///
474    /// # Panics
475    ///
476    /// Panics if called before adding any responses to the sequence.
477    pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
478        if self.generators.is_empty() {
479            panic!("repeatedly() called before adding a response to the sequence");
480        }
481        let inner = self.times(usize::MAX);
482        FinalizedResponseSequenceBuilder { inner }
483    }
484
485    /// Build this a "simple" rule (internal detail)
486    pub(crate) fn build_simple(mut self) -> Rule {
487        self.is_simple = true;
488        self.repeatedly().build()
489    }
490
491    /// Build the rule with this response sequence
492    pub fn build(self) -> Rule {
493        let generators = self.generators;
494        let is_simple = self.is_simple;
495
496        // calculate total responses (sum of all repetitions)
497        let total_responses: usize = generators
498            .iter()
499            .map(|(_, count)| *count)
500            .fold(0, |acc, count| acc.saturating_add(count));
501
502        Rule::new(
503            self.input_filter,
504            Arc::new(move |idx, input| {
505                // find which generator to use
506                let mut current_idx = idx;
507                for (generator, repeat_count) in &generators {
508                    if current_idx < *repeat_count {
509                        return Some(generator(input));
510                    }
511                    current_idx -= repeat_count;
512                }
513                None
514            }),
515            total_responses,
516            is_simple,
517        )
518    }
519}
520
521impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
522where
523    I: fmt::Debug + Send + Sync + 'static,
524    O: fmt::Debug + Send + Sync + 'static,
525    E: fmt::Debug + Send + Sync + std::error::Error + 'static,
526{
527    /// Build the rule with this response sequence
528    pub fn build(self) -> Rule {
529        self.inner.build()
530    }
531}