aws_smithy_mocks/rule.rs
1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24pub(crate) enum MockResponse<O, E> {
25 /// A successful modeled response.
26 Output(O),
27 /// A modeled error.
28 Error(E),
29 /// An HTTP response.
30 Http(HttpResponse),
31}
32
33/// A function that matches requests.
34type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
37/// A rule for matching requests and providing mock responses.
38///
39/// Rules are created using the `mock!` macro or the `RuleBuilder`.
40///
41#[derive(Clone)]
42pub struct Rule {
43 /// Function that determines if this rule matches a request.
44 pub(crate) matcher: MatchFn,
45
46 /// Handler function that generates responses.
47 response_handler: ServeFn,
48
49 /// Number of times this rule has been called.
50 call_count: Arc<AtomicUsize>,
51
52 /// Maximum number of responses this rule will provide.
53 pub(crate) max_responses: usize,
54
55 /// Flag indicating this is a "simple" rule which changes how it is interpreted
56 /// depending on the RuleMode.
57 ///
58 /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
59 is_simple: bool,
60}
61
62impl fmt::Debug for Rule {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(f, "Rule")
65 }
66}
67
68impl Rule {
69 /// Creates a new rule with the given matcher, response handler, and max responses.
70 #[allow(clippy::type_complexity)]
71 pub(crate) fn new<O, E>(
72 matcher: MatchFn,
73 response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
74 max_responses: usize,
75 is_simple: bool,
76 ) -> Self
77 where
78 O: fmt::Debug + Send + Sync + 'static,
79 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
80 {
81 Rule {
82 matcher,
83 response_handler: Arc::new(move |idx: usize, input: &Input| {
84 if idx < max_responses {
85 response_handler(idx, input).map(|resp| match resp {
86 MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
87 MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
88 MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
89 })
90 } else {
91 None
92 }
93 }),
94 call_count: Arc::new(AtomicUsize::new(0)),
95 max_responses,
96 is_simple,
97 }
98 }
99
100 /// Test if this is a "simple" rule (non-sequenced)
101 pub(crate) fn is_simple(&self) -> bool {
102 self.is_simple
103 }
104
105 /// Gets the next response.
106 pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
107 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
108 (self.response_handler)(idx, input)
109 }
110
111 /// Returns the number of times this rule has been called.
112 pub fn num_calls(&self) -> usize {
113 self.call_count.load(Ordering::SeqCst)
114 }
115
116 /// Checks if this rule is exhausted (has provided all its responses).
117 pub fn is_exhausted(&self) -> bool {
118 self.num_calls() >= self.max_responses
119 }
120}
121
/// Describes how a set of mock rules is interpreted.
///
/// - In [`RuleMode::Sequential`], the first matching rule is applied, and that rule is removed
///   from the list of rules **once it is exhausted**. Each rule can have multiple responses, and
///   all responses in a rule are consumed before moving to the next rule.
/// - In [`RuleMode::MatchAny`], the first matching rule is applied, and the rules remain
///   unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Rules are used in the order they were added. The first matching rule is applied, and that
    /// rule is removed from the list of rules **once it is exhausted**. Each rule can have
    /// multiple responses, and all responses in a rule are consumed before moving to the next
    /// rule.
    Sequential,
    /// The first matching rule is applied, and the rules remain unchanged.
    MatchAny,
}
135
136/// A builder for creating rules.
137///
138/// This builder provides a fluent API for creating rules with different response types.
139///
140pub struct RuleBuilder<I, O, E> {
141 /// Function that determines if this rule matches a request.
142 pub(crate) input_filter: MatchFn,
143
144 /// Phantom data for the input type.
145 pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
146}
147
148impl<I, O, E> RuleBuilder<I, O, E>
149where
150 I: fmt::Debug + Send + Sync + 'static,
151 O: fmt::Debug + Send + Sync + 'static,
152 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
153{
154 /// Creates a new [`RuleBuilder`]
155 #[doc(hidden)]
156 pub fn new() -> Self {
157 RuleBuilder {
158 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
159 _ty: std::marker::PhantomData,
160 }
161 }
162
163 /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
164 #[doc(hidden)]
165 pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
166 where
167 F: Future<Output = Result<O, SdkError<E, R>>>,
168 {
169 Self {
170 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
171 _ty: Default::default(),
172 }
173 }
174
175 /// Sets the function that determines if this rule matches a request.
176 pub fn match_requests<F>(mut self, filter: F) -> Self
177 where
178 F: Fn(&I) -> bool + Send + Sync + 'static,
179 {
180 self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
181 Some(typed_input) => filter(typed_input),
182 _ => false,
183 });
184 self
185 }
186
187 /// Start building a response sequence
188 ///
189 /// A sequence allows a single rule to generate multiple responses which can
190 /// be used to test retry behavior.
191 ///
192 /// # Examples
193 ///
194 /// With repetition using `times()`:
195 ///
196 /// ```rust,ignore
197 /// let rule = mock!(Client::get_object)
198 /// .sequence()
199 /// .http_status(503, None)
200 /// .times(2) // First two calls return 503
201 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
202 /// .build();
203 /// ```
204 pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
205 ResponseSequenceBuilder::new(self.input_filter)
206 }
207
208 /// Creates a rule that returns a modeled output.
209 pub fn then_output<F>(self, output_fn: F) -> Rule
210 where
211 F: Fn() -> O + Send + Sync + 'static,
212 {
213 self.sequence().output(output_fn).build_simple()
214 }
215
216 /// Creates a rule that returns a modeled error.
217 pub fn then_error<F>(self, error_fn: F) -> Rule
218 where
219 F: Fn() -> E + Send + Sync + 'static,
220 {
221 self.sequence().error(error_fn).build_simple()
222 }
223
224 /// Creates a rule that returns an HTTP response.
225 pub fn then_http_response<F>(self, response_fn: F) -> Rule
226 where
227 F: Fn() -> HttpResponse + Send + Sync + 'static,
228 {
229 self.sequence().http_response(response_fn).build_simple()
230 }
231
232 /// Creates a rule that computes an output based on the input.
233 ///
234 /// This allows generating responses based on the input request.
235 ///
236 /// # Examples
237 ///
238 /// ```rust,ignore
239 /// let rule = mock!(Client::get_object)
240 /// .compute_output(|req| {
241 /// GetObjectOutput::builder()
242 /// .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
243 /// .build()
244 /// })
245 /// .build();
246 /// ```
247 pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
248 where
249 F: Fn(&I) -> O + Send + Sync + 'static,
250 {
251 self.sequence().compute_output(compute_fn).build_simple()
252 }
253}
254
255type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
256
257/// A builder for creating response sequences
258pub struct ResponseSequenceBuilder<I, O, E> {
259 /// The response generators in the sequence
260 generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
261
262 /// Function that determines if this rule matches a request
263 input_filter: MatchFn,
264
265 /// flag indicating this is a "simple" rule
266 is_simple: bool,
267
268 /// Marker for the input, output, and error types
269 _marker: std::marker::PhantomData<I>,
270}
271
272/// Final sequence builder state - can only `build()`
273pub struct FinalizedResponseSequenceBuilder<I, O, E> {
274 inner: ResponseSequenceBuilder<I, O, E>,
275}
276
277impl<I, O, E> ResponseSequenceBuilder<I, O, E>
278where
279 I: fmt::Debug + Send + Sync + 'static,
280 O: fmt::Debug + Send + Sync + 'static,
281 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
282{
283 /// Create a new response sequence builder
284 pub(crate) fn new(input_filter: MatchFn) -> Self {
285 Self {
286 generators: Vec::new(),
287 input_filter,
288 is_simple: false,
289 _marker: std::marker::PhantomData,
290 }
291 }
292
293 /// Add a modeled output response to the sequence
294 ///
295 /// # Examples
296 ///
297 /// ```rust,ignore
298 /// let rule = mock!(Client::get_object)
299 /// .sequence()
300 /// .output(|| GetObjectOutput::builder().build())
301 /// .build();
302 /// ```
303 pub fn output<F>(mut self, output_fn: F) -> Self
304 where
305 F: Fn() -> O + Send + Sync + 'static,
306 {
307 let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
308 self.generators.push((generator, 1));
309 self
310 }
311
312 /// Add a modeled error response to the sequence
313 pub fn error<F>(mut self, error_fn: F) -> Self
314 where
315 F: Fn() -> E + Send + Sync + 'static,
316 {
317 let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
318 self.generators.push((generator, 1));
319 self
320 }
321
322 /// Add an HTTP status code response to the sequence
323 pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
324 let status_code = StatusCode::try_from(status).unwrap();
325
326 let generator: SequenceGeneratorFn<O, E> = match body {
327 Some(body) => Arc::new(move |_input: &Input| {
328 MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
329 }),
330 None => Arc::new(move |_input: &Input| {
331 MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
332 }),
333 };
334
335 self.generators.push((generator, 1));
336 self
337 }
338
339 /// Add an HTTP response to the sequence
340 pub fn http_response<F>(mut self, response_fn: F) -> Self
341 where
342 F: Fn() -> HttpResponse + Send + Sync + 'static,
343 {
344 let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
345 self.generators.push((generator, 1));
346 self
347 }
348
349 /// Add a computed output response to the sequence. Note that this is not `pub`
350 /// because creating computed output rules off of sequenced rules doesn't work,
351 /// as we can't preserve the input across retries. So we only expose `compute_output`
352 /// on unsequenced rules above.
353 fn compute_output<F>(mut self, compute_fn: F) -> Self
354 where
355 F: Fn(&I) -> O + Send + Sync + 'static,
356 {
357 let generator = Arc::new(move |input: &Input| {
358 if let Some(typed_input) = input.downcast_ref::<I>() {
359 MockResponse::Output(compute_fn(typed_input))
360 } else {
361 panic!("Input type mismatch in compute_output")
362 }
363 });
364 self.generators.push((generator, 1));
365 self
366 }
367
368 /// Repeat the last added response multiple times.
369 ///
370 /// This method sets the number of times the last response in the sequence will be used.
371 /// For example, if you add a response and then call `times(3)`, that response will be
372 /// returned for the next 3 calls to the rule.
373 ///
374 /// # Examples
375 ///
376 /// ```rust,ignore
377 /// // Create a rule that returns 503 twice, then succeeds
378 /// let rule = mock!(Client::get_object)
379 /// .sequence()
380 /// .http_status(503, None)
381 /// .times(2) // First two calls return 503
382 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
383 /// .build();
384 /// ```
385 ///
386 /// # Panics
387 ///
388 /// Panics if:
389 /// - Called with a count of 0
390 /// - Called before adding any responses to the sequence
391 pub fn times(mut self, count: usize) -> Self {
392 if self.generators.is_empty() {
393 panic!("times(n) called before adding a response to the sequence");
394 }
395 match count {
396 0 => panic!("repeat count must be greater than zero"),
397 1 => {
398 return self;
399 }
400 _ => {}
401 }
402
403 // update the repeat count of the last generator
404 if let Some(last_generator) = self.generators.last_mut() {
405 last_generator.1 = count;
406 }
407 self
408 }
409 /// Make the last response in the sequence repeat indefinitely.
410 ///
411 /// This method causes the last response added to the sequence to be repeated
412 /// forever, making the rule never exhaust. After calling `repeatedly()`,
413 /// no more responses can be added to the sequence.
414 ///
415 /// # Examples
416 ///
417 /// ```rust,ignore
418 /// // Create a rule that returns an error once, then succeeds forever
419 /// let rule = mock!(Client::get_object)
420 /// .sequence()
421 /// .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
422 /// .output(|| GetObjectOutput::builder().build())
423 /// .repeatedly()
424 /// .build();
425 ///
426 /// // First call will return NoSuchKey error
427 /// // All subsequent calls will return success
428 /// ```
429 ///
430 /// # Panics
431 ///
432 /// Panics if called before adding any responses to the sequence.
433 pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
434 if self.generators.is_empty() {
435 panic!("repeatedly() called before adding a response to the sequence");
436 }
437 let inner = self.times(usize::MAX);
438 FinalizedResponseSequenceBuilder { inner }
439 }
440
441 /// Build this a "simple" rule (internal detail)
442 pub(crate) fn build_simple(mut self) -> Rule {
443 self.is_simple = true;
444 self.repeatedly().build()
445 }
446
447 /// Build the rule with this response sequence
448 pub fn build(self) -> Rule {
449 let generators = self.generators;
450 let is_simple = self.is_simple;
451
452 // calculate total responses (sum of all repetitions)
453 let total_responses: usize = generators
454 .iter()
455 .map(|(_, count)| *count)
456 .fold(0, |acc, count| acc.saturating_add(count));
457
458 Rule::new(
459 self.input_filter,
460 Arc::new(move |idx, input| {
461 // find which generator to use
462 let mut current_idx = idx;
463 for (generator, repeat_count) in &generators {
464 if current_idx < *repeat_count {
465 return Some(generator(input));
466 }
467 current_idx -= repeat_count;
468 }
469 None
470 }),
471 total_responses,
472 is_simple,
473 )
474 }
475}
476
477impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
478where
479 I: fmt::Debug + Send + Sync + 'static,
480 O: fmt::Debug + Send + Sync + 'static,
481 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
482{
483 /// Build the rule with this response sequence
484 pub fn build(self) -> Rule {
485 self.inner.build()
486 }
487}