aws_smithy_mocks/rule.rs
1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24#[allow(clippy::large_enum_variant)]
25pub(crate) enum MockResponse<O, E> {
26 /// A successful modeled response.
27 Output(O),
28 /// A modeled error.
29 Error(E),
30 /// An HTTP response.
31 Http(HttpResponse),
32}
33
34/// A function that matches requests.
35type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
36type ServeFn = Arc<dyn Fn(usize, &Input) -> Option<MockResponse<Output, Error>> + Send + Sync>;
37
38/// A rule for matching requests and providing mock responses.
39///
40/// Rules are created using the `mock!` macro or the `RuleBuilder`.
41///
42#[derive(Clone)]
43pub struct Rule {
44 /// Function that determines if this rule matches a request.
45 pub(crate) matcher: MatchFn,
46
47 /// Handler function that generates responses.
48 response_handler: ServeFn,
49
50 /// Number of times this rule has been called.
51 call_count: Arc<AtomicUsize>,
52
53 /// Maximum number of responses this rule will provide.
54 pub(crate) max_responses: usize,
55
56 /// Flag indicating this is a "simple" rule which changes how it is interpreted
57 /// depending on the RuleMode.
58 ///
59 /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
60 is_simple: bool,
61}
62
63impl fmt::Debug for Rule {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(f, "Rule")
66 }
67}
68
69impl Rule {
70 /// Creates a new rule with the given matcher, response handler, and max responses.
71 #[allow(clippy::type_complexity)]
72 pub(crate) fn new<O, E>(
73 matcher: MatchFn,
74 response_handler: Arc<dyn Fn(usize, &Input) -> Option<MockResponse<O, E>> + Send + Sync>,
75 max_responses: usize,
76 is_simple: bool,
77 ) -> Self
78 where
79 O: fmt::Debug + Send + Sync + 'static,
80 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
81 {
82 Rule {
83 matcher,
84 response_handler: Arc::new(move |idx: usize, input: &Input| {
85 if idx < max_responses {
86 response_handler(idx, input).map(|resp| match resp {
87 MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
88 MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
89 MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
90 })
91 } else {
92 None
93 }
94 }),
95 call_count: Arc::new(AtomicUsize::new(0)),
96 max_responses,
97 is_simple,
98 }
99 }
100
101 /// Test if this is a "simple" rule (non-sequenced)
102 pub(crate) fn is_simple(&self) -> bool {
103 self.is_simple
104 }
105
106 /// Gets the next response.
107 pub(crate) fn next_response(&self, input: &Input) -> Option<MockResponse<Output, Error>> {
108 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
109 (self.response_handler)(idx, input)
110 }
111
112 /// Returns the number of times this rule has been called.
113 pub fn num_calls(&self) -> usize {
114 self.call_count.load(Ordering::SeqCst)
115 }
116
117 /// Checks if this rule is exhausted (has provided all its responses).
118 pub fn is_exhausted(&self) -> bool {
119 self.num_calls() >= self.max_responses
120 }
121}
122
/// RuleMode describes how rules will be interpreted.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will
///   remain unchanged.
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will be
///   removed from the list of rules **once it is exhausted**.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule will be applied, and
    /// that rule will be removed from the list of rules **once it is exhausted**. Each rule can
    /// have multiple responses, and all responses in a rule will be consumed before moving to the
    /// next rule.
    Sequential,
    /// The first matching rule will be applied, and the rules will remain unchanged.
    MatchAny,
}
136
137/// A builder for creating rules.
138///
139/// This builder provides a fluent API for creating rules with different response types.
140///
141pub struct RuleBuilder<I, O, E> {
142 /// Function that determines if this rule matches a request.
143 pub(crate) input_filter: MatchFn,
144
145 /// Phantom data for the input type.
146 pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
147}
148
149impl<I, O, E> RuleBuilder<I, O, E>
150where
151 I: fmt::Debug + Send + Sync + 'static,
152 O: fmt::Debug + Send + Sync + 'static,
153 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
154{
155 /// Creates a new [`RuleBuilder`]
156 #[doc(hidden)]
157 pub fn new() -> Self {
158 RuleBuilder {
159 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
160 _ty: std::marker::PhantomData,
161 }
162 }
163
164 /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
165 #[doc(hidden)]
166 pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
167 where
168 F: Future<Output = Result<O, SdkError<E, R>>>,
169 {
170 Self {
171 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
172 _ty: Default::default(),
173 }
174 }
175
176 /// Sets the function that determines if this rule matches a request.
177 pub fn match_requests<F>(mut self, filter: F) -> Self
178 where
179 F: Fn(&I) -> bool + Send + Sync + 'static,
180 {
181 self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
182 Some(typed_input) => filter(typed_input),
183 _ => false,
184 });
185 self
186 }
187
188 /// Start building a response sequence
189 ///
190 /// A sequence allows a single rule to generate multiple responses which can
191 /// be used to test retry behavior.
192 ///
193 /// # Examples
194 ///
195 /// With repetition using `times()`:
196 ///
197 /// ```rust,ignore
198 /// let rule = mock!(Client::get_object)
199 /// .sequence()
200 /// .http_status(503, None)
201 /// .times(2) // First two calls return 503
202 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
203 /// .build();
204 /// ```
205 pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
206 ResponseSequenceBuilder::new(self.input_filter)
207 }
208
209 /// Creates a rule that returns a modeled output.
210 pub fn then_output<F>(self, output_fn: F) -> Rule
211 where
212 F: Fn() -> O + Send + Sync + 'static,
213 {
214 self.sequence().output(output_fn).build_simple()
215 }
216
217 /// Creates a rule that returns a modeled error.
218 pub fn then_error<F>(self, error_fn: F) -> Rule
219 where
220 F: Fn() -> E + Send + Sync + 'static,
221 {
222 self.sequence().error(error_fn).build_simple()
223 }
224
225 /// Creates a rule that returns an HTTP response.
226 pub fn then_http_response<F>(self, response_fn: F) -> Rule
227 where
228 F: Fn() -> HttpResponse + Send + Sync + 'static,
229 {
230 self.sequence().http_response(response_fn).build_simple()
231 }
232
233 /// Creates a rule that computes an output based on the input.
234 ///
235 /// This allows generating responses based on the input request.
236 ///
237 /// # Examples
238 ///
239 /// ```rust,ignore
240 /// let rule = mock!(Client::get_object)
241 /// .compute_output(|req| {
242 /// GetObjectOutput::builder()
243 /// .body(ByteStream::from_static(format!("content for {}", req.key().unwrap_or("unknown")).as_bytes()))
244 /// .build()
245 /// })
246 /// .build();
247 /// ```
248 pub fn then_compute_output<F>(self, compute_fn: F) -> Rule
249 where
250 F: Fn(&I) -> O + Send + Sync + 'static,
251 {
252 self.sequence().compute_output(compute_fn).build_simple()
253 }
254}
255
256type SequenceGeneratorFn<O, E> = Arc<dyn Fn(&Input) -> MockResponse<O, E> + Send + Sync>;
257
258/// A builder for creating response sequences
259pub struct ResponseSequenceBuilder<I, O, E> {
260 /// The response generators in the sequence
261 generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
262
263 /// Function that determines if this rule matches a request
264 input_filter: MatchFn,
265
266 /// flag indicating this is a "simple" rule
267 is_simple: bool,
268
269 /// Marker for the input, output, and error types
270 _marker: std::marker::PhantomData<I>,
271}
272
273/// Final sequence builder state - can only `build()`
274pub struct FinalizedResponseSequenceBuilder<I, O, E> {
275 inner: ResponseSequenceBuilder<I, O, E>,
276}
277
278impl<I, O, E> ResponseSequenceBuilder<I, O, E>
279where
280 I: fmt::Debug + Send + Sync + 'static,
281 O: fmt::Debug + Send + Sync + 'static,
282 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
283{
284 /// Create a new response sequence builder
285 pub(crate) fn new(input_filter: MatchFn) -> Self {
286 Self {
287 generators: Vec::new(),
288 input_filter,
289 is_simple: false,
290 _marker: std::marker::PhantomData,
291 }
292 }
293
294 /// Add a modeled output response to the sequence
295 ///
296 /// # Examples
297 ///
298 /// ```rust,ignore
299 /// let rule = mock!(Client::get_object)
300 /// .sequence()
301 /// .output(|| GetObjectOutput::builder().build())
302 /// .build();
303 /// ```
304 pub fn output<F>(mut self, output_fn: F) -> Self
305 where
306 F: Fn() -> O + Send + Sync + 'static,
307 {
308 let generator = Arc::new(move |_input: &Input| MockResponse::Output(output_fn()));
309 self.generators.push((generator, 1));
310 self
311 }
312
313 /// Add a modeled error response to the sequence
314 pub fn error<F>(mut self, error_fn: F) -> Self
315 where
316 F: Fn() -> E + Send + Sync + 'static,
317 {
318 let generator = Arc::new(move |_input: &Input| MockResponse::Error(error_fn()));
319 self.generators.push((generator, 1));
320 self
321 }
322
323 /// Add an HTTP status code response to the sequence
324 pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
325 let status_code = StatusCode::try_from(status).unwrap();
326
327 let generator: SequenceGeneratorFn<O, E> = match body {
328 Some(body) => Arc::new(move |_input: &Input| {
329 MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
330 }),
331 None => Arc::new(move |_input: &Input| {
332 MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
333 }),
334 };
335
336 self.generators.push((generator, 1));
337 self
338 }
339
340 /// Add an HTTP response to the sequence
341 pub fn http_response<F>(mut self, response_fn: F) -> Self
342 where
343 F: Fn() -> HttpResponse + Send + Sync + 'static,
344 {
345 let generator = Arc::new(move |_input: &Input| MockResponse::Http(response_fn()));
346 self.generators.push((generator, 1));
347 self
348 }
349
350 /// Add a computed output response to the sequence. Note that this is not `pub`
351 /// because creating computed output rules off of sequenced rules doesn't work,
352 /// as we can't preserve the input across retries. So we only expose `compute_output`
353 /// on unsequenced rules above.
354 fn compute_output<F>(mut self, compute_fn: F) -> Self
355 where
356 F: Fn(&I) -> O + Send + Sync + 'static,
357 {
358 let generator = Arc::new(move |input: &Input| {
359 if let Some(typed_input) = input.downcast_ref::<I>() {
360 MockResponse::Output(compute_fn(typed_input))
361 } else {
362 panic!("Input type mismatch in compute_output")
363 }
364 });
365 self.generators.push((generator, 1));
366 self
367 }
368
369 /// Repeat the last added response multiple times.
370 ///
371 /// This method sets the number of times the last response in the sequence will be used.
372 /// For example, if you add a response and then call `times(3)`, that response will be
373 /// returned for the next 3 calls to the rule.
374 ///
375 /// # Examples
376 ///
377 /// ```rust,ignore
378 /// // Create a rule that returns 503 twice, then succeeds
379 /// let rule = mock!(Client::get_object)
380 /// .sequence()
381 /// .http_status(503, None)
382 /// .times(2) // First two calls return 503
383 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
384 /// .build();
385 /// ```
386 ///
387 /// # Panics
388 ///
389 /// Panics if:
390 /// - Called with a count of 0
391 /// - Called before adding any responses to the sequence
392 pub fn times(mut self, count: usize) -> Self {
393 if self.generators.is_empty() {
394 panic!("times(n) called before adding a response to the sequence");
395 }
396 match count {
397 0 => panic!("repeat count must be greater than zero"),
398 1 => {
399 return self;
400 }
401 _ => {}
402 }
403
404 // update the repeat count of the last generator
405 if let Some(last_generator) = self.generators.last_mut() {
406 last_generator.1 = count;
407 }
408 self
409 }
410 /// Make the last response in the sequence repeat indefinitely.
411 ///
412 /// This method causes the last response added to the sequence to be repeated
413 /// forever, making the rule never exhaust. After calling `repeatedly()`,
414 /// no more responses can be added to the sequence.
415 ///
416 /// # Examples
417 ///
418 /// ```rust,ignore
419 /// // Create a rule that returns an error once, then succeeds forever
420 /// let rule = mock!(Client::get_object)
421 /// .sequence()
422 /// .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
423 /// .output(|| GetObjectOutput::builder().build())
424 /// .repeatedly()
425 /// .build();
426 ///
427 /// // First call will return NoSuchKey error
428 /// // All subsequent calls will return success
429 /// ```
430 ///
431 /// # Panics
432 ///
433 /// Panics if called before adding any responses to the sequence.
434 pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
435 if self.generators.is_empty() {
436 panic!("repeatedly() called before adding a response to the sequence");
437 }
438 let inner = self.times(usize::MAX);
439 FinalizedResponseSequenceBuilder { inner }
440 }
441
442 /// Build this a "simple" rule (internal detail)
443 pub(crate) fn build_simple(mut self) -> Rule {
444 self.is_simple = true;
445 self.repeatedly().build()
446 }
447
448 /// Build the rule with this response sequence
449 pub fn build(self) -> Rule {
450 let generators = self.generators;
451 let is_simple = self.is_simple;
452
453 // calculate total responses (sum of all repetitions)
454 let total_responses: usize = generators
455 .iter()
456 .map(|(_, count)| *count)
457 .fold(0, |acc, count| acc.saturating_add(count));
458
459 Rule::new(
460 self.input_filter,
461 Arc::new(move |idx, input| {
462 // find which generator to use
463 let mut current_idx = idx;
464 for (generator, repeat_count) in &generators {
465 if current_idx < *repeat_count {
466 return Some(generator(input));
467 }
468 current_idx -= repeat_count;
469 }
470 None
471 }),
472 total_responses,
473 is_simple,
474 )
475 }
476}
477
478impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
479where
480 I: fmt::Debug + Send + Sync + 'static,
481 O: fmt::Debug + Send + Sync + 'static,
482 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
483{
484 /// Build the rule with this response sequence
485 pub fn build(self) -> Rule {
486 self.inner.build()
487 }
488}