aws_smithy_mocks/rule.rs
1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24#[allow(clippy::large_enum_variant)]
25pub enum MockResponse<O, E> {
26 /// A successful modeled response.
27 Output(O),
28 /// A modeled error.
29 Error(E),
30 /// An HTTP response.
31 Http(HttpResponse),
32}
33
34/// A function that matches requests.
35type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
36type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
37
38/// A rule for matching requests and providing mock responses.
39///
40/// Rules are created using the `mock!` macro or the `RuleBuilder`.
41///
42#[derive(Clone)]
43pub struct Rule {
44 /// Function that determines if this rule matches a request.
45 pub(crate) matcher: MatchFn,
46
47 /// Handler function that generates responses.
48 response_handler: ServeFn,
49
50 /// Number of times this rule has been called.
51 call_count: Arc<AtomicUsize>,
52
53 /// Maximum number of responses this rule will provide.
54 pub(crate) max_responses: usize,
55
56 /// Flag indicating this is a "simple" rule which changes how it is interpreted
57 /// depending on the RuleMode.
58 ///
59 /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
60 is_simple: bool,
61}
62
63impl fmt::Debug for Rule {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(f, "Rule")
66 }
67}
68
69impl Rule {
70 /// Creates a new rule with the given matcher, response handler, and max responses.
71 #[allow(clippy::type_complexity)]
72 pub(crate) fn new<O, E>(
73 matcher: MatchFn,
74 response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
75 max_responses: usize,
76 is_simple: bool,
77 ) -> Self
78 where
79 O: fmt::Debug + Send + Sync + 'static,
80 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
81 {
82 Rule {
83 matcher,
84 response_handler: Arc::new(move |idx: usize, input: &Input| {
85 if idx < max_responses {
86 response_handler(idx, input).map(|resp| match resp {
87 MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
88 MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
89 MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
90 })
91 } else {
92 None
93 }
94 }),
95 call_count: Arc::new(AtomicUsize::new(0)),
96 max_responses,
97 is_simple,
98 }
99 }
100
101 /// Test if this is a "simple" rule (non-sequenced)
102 pub(crate) fn is_simple(&self) -> bool {
103 self.is_simple
104 }
105
106 /// Gets the next response.
107 pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
108 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
109 (self.response_handler)(idx, input)
110 }
111
112 /// Returns the number of times this rule has been called.
113 pub fn num_calls(&self) -> usize {
114 self.call_count.load(Ordering::SeqCst)
115 }
116
117 /// Checks if this rule is exhausted (has provided all its responses).
118 pub fn is_exhausted(&self) -> bool {
119 self.num_calls() >= self.max_responses
120 }
121}
122
/// RuleMode describes how rules will be interpreted.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will
///   remain unchanged.
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will be
///   removed from the list of rules **once it is exhausted**.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Rules are used in the order they were added: the first matching rule is applied, and that
    /// rule is removed from the list of rules **once it is exhausted**. Each rule can have
    /// multiple responses, and all responses in a rule are consumed before moving on to the next
    /// rule.
    Sequential,
    /// The first matching rule is applied, and the rules remain unchanged.
    MatchAny,
}
136
137/// A builder for creating rules.
138///
139/// This builder provides a fluent API for creating rules with different response types.
140///
141pub struct RuleBuilder<I, O, E> {
142 /// Function that determines if this rule matches a request.
143 pub(crate) input_filter: MatchFn,
144
145 /// Phantom data for the input type.
146 pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
147}
148
149impl<I, O, E> RuleBuilder<I, O, E>
150where
151 I: fmt::Debug + Send + Sync + 'static,
152 O: fmt::Debug + Send + Sync + 'static,
153 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
154{
155 /// Creates a new [`RuleBuilder`]
156 #[doc(hidden)]
157 pub fn new() -> Self {
158 RuleBuilder {
159 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
160 _ty: std::marker::PhantomData,
161 }
162 }
163
164 /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
165 #[doc(hidden)]
166 pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
167 where
168 F: Future<Output = Result<O, SdkError<E, R>>>,
169 {
170 Self {
171 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
172 _ty: Default::default(),
173 }
174 }
175
176 /// Sets the function that determines if this rule matches a request.
177 pub fn match_requests<F>(mut self, filter: F) -> Self
178 where
179 F: Fn(&I) -> bool + Send + Sync + 'static,
180 {
181 self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
182 Some(typed_input) => filter(typed_input),
183 _ => false,
184 });
185 self
186 }
187
188 /// Start building a response sequence
189 ///
190 /// A sequence allows a single rule to generate multiple responses which can
191 /// be used to test retry behavior.
192 ///
193 /// # Examples
194 ///
195 /// With repetition using `times()`:
196 ///
197 /// ```rust,ignore
198 /// let rule = mock!(Client::get_object)
199 /// .sequence()
200 /// .http_status(503, None)
201 /// .times(2) // First two calls return 503
202 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
203 /// .build();
204 /// ```
205 pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
206 ResponseSequenceBuilder::new(self.input_filter)
207 }
208
209 /// Creates a rule that returns a modeled output.
210 pub fn then_output<F>(self, output_fn: F) -> Rule
211 where
212 F: Fn() -> O + Send + Sync + 'static,
213 {
214 self.sequence().output(output_fn).build_simple()
215 }
216
217 /// Creates a rule that returns a modeled error.
218 pub fn then_error<F>(self, error_fn: F) -> Rule
219 where
220 F: Fn() -> E + Send + Sync + 'static,
221 {
222 self.sequence().error(error_fn).build_simple()
223 }
224
225 /// Creates a rule that returns an HTTP response.
226 pub fn then_http_response<F>(self, response_fn: F) -> Rule
227 where
228 F: Fn() -> HttpResponse + Send + Sync + 'static,
229 {
230 self.sequence().http_response(response_fn).build_simple()
231 }
232
233 /// Creates a rule that computes an output based on the input.
234 ///
235 /// This allows generating responses based on the input request.
236 ///
237 /// # Examples
238 ///
239 /// ```rust,ignore
240 /// let rule = mock!(Client::get_object)
241 /// .compute_output(|req| {
242 /// GetObjectOutput::builder()
243 /// .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
244 /// .build()
245 /// })
246 /// .build();
247 /// ```
248 pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
249 where
250 F: Fn(&I) -> O + Send + Sync + 'static,
251 {
252 self.sequence().compute_output(compute_fn).build_simple()
253 }
254
255 /// Creates a rule that computes an arbitrary response based on the input.
256 ///
257 /// This allows generating any type of response (output, error, or HTTP) based on the input request.
258 /// Unlike `then_compute_output`, this method can return errors or HTTP responses conditionally.
259 ///
260 /// # Examples
261 ///
262 /// ```rust,ignore
263 /// let rule = mock!(Client::get_object)
264 /// .then_compute_response(|req| {
265 /// if req.key() == Some("error") {
266 /// MockResponse::Error(GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
267 /// } else {
268 /// MockResponse::Output(GetObjectOutput::builder()
269 /// .body(ByteStream::from_static(b"content"))
270 /// .build())
271 /// }
272 /// })
273 /// .build();
274 /// ```
275 pub fn then_compute_response<F>(self, compute_fn: F) -> Rule
276 where
277 F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
278 {
279 self.sequence().compute_response(compute_fn).build_simple()
280 }
281}
282
283type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
284
285/// A builder for creating response sequences
286pub struct ResponseSequenceBuilder<I, O, E> {
287 /// The response generators in the sequence
288 generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
289
290 /// Function that determines if this rule matches a request
291 input_filter: MatchFn,
292
293 /// flag indicating this is a "simple" rule
294 is_simple: bool,
295
296 /// Marker for the input, output, and error types
297 _marker: std::marker::PhantomData<I>,
298}
299
300/// Final sequence builder state - can only `build()`
301pub struct FinalizedResponseSequenceBuilder<I, O, E> {
302 inner: ResponseSequenceBuilder<I, O, E>,
303}
304
305impl<I, O, E> ResponseSequenceBuilder<I, O, E>
306where
307 I: fmt::Debug + Send + Sync + 'static,
308 O: fmt::Debug + Send + Sync + 'static,
309 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
310{
311 /// Create a new response sequence builder
312 pub(crate) fn new(input_filter: MatchFn) -> Self {
313 Self {
314 generators: Vec::new(),
315 input_filter,
316 is_simple: false,
317 _marker: std::marker::PhantomData,
318 }
319 }
320
321 /// Add a modeled output response to the sequence
322 ///
323 /// # Examples
324 ///
325 /// ```rust,ignore
326 /// let rule = mock!(Client::get_object)
327 /// .sequence()
328 /// .output(|| GetObjectOutput::builder().build())
329 /// .build();
330 /// ```
331 pub fn output<F>(mut self, output_fn: F) -> Self
332 where
333 F: Fn() -> O + Send + Sync + 'static,
334 {
335 let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
336 self.generators.push((generator, 1));
337 self
338 }
339
340 /// Add a modeled error response to the sequence
341 pub fn error<F>(mut self, error_fn: F) -> Self
342 where
343 F: Fn() -> E + Send + Sync + 'static,
344 {
345 let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
346 self.generators.push((generator, 1));
347 self
348 }
349
350 /// Add an HTTP status code response to the sequence
351 pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
352 let status_code = StatusCode::try_from(status).unwrap();
353
354 let generator: SequenceGeneratorFn<O, E> = match body {
355 Some(body) => Arc::new(move |_input: &Input| {
356 MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
357 }),
358 None => Arc::new(move |_input: &Input| {
359 MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
360 }),
361 };
362
363 self.generators.push((generator, 1));
364 self
365 }
366
367 /// Add an HTTP response to the sequence
368 pub fn http_response<F>(mut self, response_fn: F) -> Self
369 where
370 F: Fn() -> HttpResponse + Send + Sync + 'static,
371 {
372 let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
373 self.generators.push((generator, 1));
374 self
375 }
376
377 /// Add a computed output response to the sequence. Note that this is not `pub`
378 /// because creating computed output rules off of sequenced rules doesn't work,
379 /// as we can't preserve the input across retries. So we only expose `compute_output`
380 /// on unsequenced rules above.
381 fn compute_output<F>(mut self, compute_fn: F) -> Self
382 where
383 F: Fn(&I) -> O + Send + Sync + 'static,
384 {
385 let generator = Arc::new(move |input: &Input| {
386 if let Some(typed_input) = input.downcast_ref::<I>() {
387 MockResponse::Output(compute_fn(typed_input))
388 } else {
389 panic!("Input type mismatch in compute_output")
390 }
391 });
392 self.generators.push((generator, 1));
393 self
394 }
395
396 /// Add a computed response to the sequence. Not `pub` for same reason as `compute_output`.
397 fn compute_response<F>(mut self, compute_fn: F) -> Self
398 where
399 F: Fn(&I) -> MockResponse<O, E> + Send + Sync + 'static,
400 {
401 let generator = Arc::new(move |input: &Input| {
402 if let Some(typed_input) = input.downcast_ref::<I>() {
403 compute_fn(typed_input)
404 } else {
405 panic!("Input type mismatch in compute_response")
406 }
407 });
408 self.generators.push((generator, 1));
409 self
410 }
411
412 /// Repeat the last added response multiple times.
413 ///
414 /// This method sets the number of times the last response in the sequence will be used.
415 /// For example, if you add a response and then call `times(3)`, that response will be
416 /// returned for the next 3 calls to the rule.
417 ///
418 /// # Examples
419 ///
420 /// ```rust,ignore
421 /// // Create a rule that returns 503 twice, then succeeds
422 /// let rule = mock!(Client::get_object)
423 /// .sequence()
424 /// .http_status(503, None)
425 /// .times(2) // First two calls return 503
426 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
427 /// .build();
428 /// ```
429 ///
430 /// # Panics
431 ///
432 /// Panics if:
433 /// - Called with a count of 0
434 /// - Called before adding any responses to the sequence
435 pub fn times(mut self, count: usize) -> Self {
436 if self.generators.is_empty() {
437 panic!("times(n) called before adding a response to the sequence");
438 }
439 match count {
440 0 => panic!("repeat count must be greater than zero"),
441 1 => {
442 return self;
443 }
444 _ => {}
445 }
446
447 // update the repeat count of the last generator
448 if let Some(last_generator) = self.generators.last_mut() {
449 last_generator.1 = count;
450 }
451 self
452 }
453 /// Make the last response in the sequence repeat indefinitely.
454 ///
455 /// This method causes the last response added to the sequence to be repeated
456 /// forever, making the rule never exhaust. After calling `repeatedly()`,
457 /// no more responses can be added to the sequence.
458 ///
459 /// # Examples
460 ///
461 /// ```rust,ignore
462 /// // Create a rule that returns an error once, then succeeds forever
463 /// let rule = mock!(Client::get_object)
464 /// .sequence()
465 /// .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
466 /// .output(|| GetObjectOutput::builder().build())
467 /// .repeatedly()
468 /// .build();
469 ///
470 /// // First call will return NoSuchKey error
471 /// // All subsequent calls will return success
472 /// ```
473 ///
474 /// # Panics
475 ///
476 /// Panics if called before adding any responses to the sequence.
477 pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
478 if self.generators.is_empty() {
479 panic!("repeatedly() called before adding a response to the sequence");
480 }
481 let inner = self.times(usize::MAX);
482 FinalizedResponseSequenceBuilder { inner }
483 }
484
485 /// Build this a "simple" rule (internal detail)
486 pub(crate) fn build_simple(mut self) -> Rule {
487 self.is_simple = true;
488 self.repeatedly().build()
489 }
490
491 /// Build the rule with this response sequence
492 pub fn build(self) -> Rule {
493 let generators = self.generators;
494 let is_simple = self.is_simple;
495
496 // calculate total responses (sum of all repetitions)
497 let total_responses: usize = generators
498 .iter()
499 .map(|(_, count)| *count)
500 .fold(0, |acc, count| acc.saturating_add(count));
501
502 Rule::new(
503 self.input_filter,
504 Arc::new(move |idx, input| {
505 // find which generator to use
506 let mut current_idx = idx;
507 for (generator, repeat_count) in &generators {
508 if current_idx < *repeat_count {
509 return Some(generator(input));
510 }
511 current_idx -= repeat_count;
512 }
513 None
514 }),
515 total_responses,
516 is_simple,
517 )
518 }
519}
520
521impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
522where
523 I: fmt::Debug + Send + Sync + 'static,
524 O: fmt::Debug + Send + Sync + 'static,
525 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
526{
527 /// Build the rule with this response sequence
528 pub fn build(self) -> Rule {
529 self.inner.build()
530 }
531}