aws_smithy_mocks/rule.rs
1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
16/// A mock response that can be returned by a rule.
17///
18/// This enum represents the different types of responses that can be returned by a mock rule:
19/// - `Output`: A successful modeled response
20/// - `Error`: A modeled error
21/// - `Http`: An HTTP response
22///
23#[derive(Debug)]
24pub(crate) enum MockResponse<O, E> {
25 /// A successful modeled response.
26 Output(O),
27 /// A modeled error.
28 Error(E),
29 /// An HTTP response.
30 Http(HttpResponse),
31}
32
33/// A function that matches requests.
34type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
35type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
37/// A rule for matching requests and providing mock responses.
38///
39/// Rules are created using the `mock!` macro or the `RuleBuilder`.
40///
41#[derive(Clone)]
42pub struct Rule {
43 /// Function that determines if this rule matches a request.
44 pub(crate) matcher: MatchFn,
45
46 /// Handler function that generates responses.
47 response_handler: ServeFn,
48
49 /// Number of times this rule has been called.
50 call_count: Arc<AtomicUsize>,
51
52 /// Maximum number of responses this rule will provide.
53 pub(crate) max_responses: usize,
54
55 /// Flag indicating this is a "simple" rule which changes how it is interpreted
56 /// depending on the RuleMode.
57 ///
58 /// See [smithy-rs#4135](https://github.com/smithy-lang/smithy-rs/issues/4135)
59 is_simple: bool,
60}
61
62impl fmt::Debug for Rule {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(f, "Rule")
65 }
66}
67
68impl Rule {
69 /// Creates a new rule with the given matcher, response handler, and max responses.
70 pub(crate) fn new<O, E>(
71 matcher: MatchFn,
72 response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
73 max_responses: usize,
74 is_simple: bool,
75 ) -> Self
76 where
77 O: fmt::Debug + Send + Sync + 'static,
78 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
79 {
80 Rule {
81 matcher,
82 response_handler: Arc::new(move |idx: usize| {
83 if idx < max_responses {
84 response_handler(idx).map(|resp| match resp {
85 MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
86 MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
87 MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
88 })
89 } else {
90 None
91 }
92 }),
93 call_count: Arc::new(AtomicUsize::new(0)),
94 max_responses,
95 is_simple,
96 }
97 }
98
99 /// Test if this is a "simple" rule (non-sequenced)
100 pub(crate) fn is_simple(&self) -> bool {
101 self.is_simple
102 }
103
104 /// Gets the next response.
105 pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
106 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
107 (self.response_handler)(idx)
108 }
109
110 /// Returns the number of times this rule has been called.
111 pub fn num_calls(&self) -> usize {
112 self.call_count.load(Ordering::SeqCst)
113 }
114
115 /// Checks if this rule is exhausted (has provided all its responses).
116 pub fn is_exhausted(&self) -> bool {
117 self.num_calls() >= self.max_responses
118 }
119}
120
/// RuleMode describes how rules will be interpreted.
/// - In [`RuleMode::MatchAny`], the first matching rule will be applied, and the rules will
///   remain unchanged.
/// - In [`RuleMode::Sequential`], the first matching rule will be applied, and that rule will
///   be removed from the list of rules **once it is exhausted**.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule will be applied, and
    /// that rule will be removed from the list of rules **once it is exhausted**. Each rule can
    /// have multiple responses, and all responses in a rule will be consumed before moving to
    /// the next rule.
    Sequential,
    /// The first matching rule will be applied, and the rules will remain unchanged.
    MatchAny,
}
134
135/// A builder for creating rules.
136///
137/// This builder provides a fluent API for creating rules with different response types.
138///
139pub struct RuleBuilder<I, O, E> {
140 /// Function that determines if this rule matches a request.
141 pub(crate) input_filter: MatchFn,
142
143 /// Phantom data for the input type.
144 pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
145}
146
147impl<I, O, E> RuleBuilder<I, O, E>
148where
149 I: fmt::Debug + Send + Sync + 'static,
150 O: fmt::Debug + Send + Sync + 'static,
151 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
152{
153 /// Creates a new [`RuleBuilder`]
154 #[doc(hidden)]
155 pub fn new() -> Self {
156 RuleBuilder {
157 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
158 _ty: std::marker::PhantomData,
159 }
160 }
161
162 /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
163 #[doc(hidden)]
164 pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
165 where
166 F: Future<Output = Result<O, SdkError<E, R>>>,
167 {
168 Self {
169 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
170 _ty: Default::default(),
171 }
172 }
173
174 /// Sets the function that determines if this rule matches a request.
175 pub fn match_requests<F>(mut self, filter: F) -> Self
176 where
177 F: Fn(&I) -> bool + Send + Sync + 'static,
178 {
179 self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
180 Some(typed_input) => filter(typed_input),
181 _ => false,
182 });
183 self
184 }
185
186 /// Start building a response sequence
187 ///
188 /// A sequence allows a single rule to generate multiple responses which can
189 /// be used to test retry behavior.
190 ///
191 /// # Examples
192 ///
193 /// With repetition using `times()`:
194 ///
195 /// ```rust,ignore
196 /// let rule = mock!(Client::get_object)
197 /// .sequence()
198 /// .http_status(503, None)
199 /// .times(2) // First two calls return 503
200 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
201 /// .build();
202 /// ```
203 pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
204 ResponseSequenceBuilder::new(self.input_filter)
205 }
206
207 /// Creates a rule that returns a modeled output.
208 pub fn then_output<F>(self, output_fn: F) -> Rule
209 where
210 F: Fn() -> O + Send + Sync + 'static,
211 {
212 self.sequence().output(output_fn).build_simple()
213 }
214
215 /// Creates a rule that returns a modeled error.
216 pub fn then_error<F>(self, error_fn: F) -> Rule
217 where
218 F: Fn() -> E + Send + Sync + 'static,
219 {
220 self.sequence().error(error_fn).build_simple()
221 }
222
223 /// Creates a rule that returns an HTTP response.
224 pub fn then_http_response<F>(self, response_fn: F) -> Rule
225 where
226 F: Fn() -> HttpResponse + Send + Sync + 'static,
227 {
228 self.sequence().http_response(response_fn).build_simple()
229 }
230}
231
232type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
233
234/// A builder for creating response sequences
235pub struct ResponseSequenceBuilder<I, O, E> {
236 /// The response generators in the sequence
237 generators: Vec<(SequenceGeneratorFn<O, E>, usize)>,
238
239 /// Function that determines if this rule matches a request
240 input_filter: MatchFn,
241
242 /// flag indicating this is a "simple" rule
243 is_simple: bool,
244
245 /// Marker for the input, output, and error types
246 _marker: std::marker::PhantomData<I>,
247}
248
249/// Final sequence builder state - can only `build()`
250pub struct FinalizedResponseSequenceBuilder<I, O, E> {
251 inner: ResponseSequenceBuilder<I, O, E>,
252}
253
254impl<I, O, E> ResponseSequenceBuilder<I, O, E>
255where
256 I: fmt::Debug + Send + Sync + 'static,
257 O: fmt::Debug + Send + Sync + 'static,
258 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
259{
260 /// Create a new response sequence builder
261 pub(crate) fn new(input_filter: MatchFn) -> Self {
262 Self {
263 generators: Vec::new(),
264 input_filter,
265 is_simple: false,
266 _marker: std::marker::PhantomData,
267 }
268 }
269
270 /// Add a modeled output response to the sequence
271 ///
272 /// # Examples
273 ///
274 /// ```rust,ignore
275 /// let rule = mock!(Client::get_object)
276 /// .sequence()
277 /// .output(|| GetObjectOutput::builder().build())
278 /// .build();
279 /// ```
280 pub fn output<F>(mut self, output_fn: F) -> Self
281 where
282 F: Fn() -> O + Send + Sync + 'static,
283 {
284 let generator = Arc::new(move || MockResponse::Output(output_fn()));
285 self.generators.push((generator, 1));
286 self
287 }
288
289 /// Add a modeled error response to the sequence
290 pub fn error<F>(mut self, error_fn: F) -> Self
291 where
292 F: Fn() -> E + Send + Sync + 'static,
293 {
294 let generator = Arc::new(move || MockResponse::Error(error_fn()));
295 self.generators.push((generator, 1));
296 self
297 }
298
299 /// Add an HTTP status code response to the sequence
300 pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
301 let status_code = StatusCode::try_from(status).unwrap();
302
303 let generator: SequenceGeneratorFn<O, E> = match body {
304 Some(body) => Arc::new(move || {
305 MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
306 }),
307 None => Arc::new(move || {
308 MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
309 }),
310 };
311
312 self.generators.push((generator, 1));
313 self
314 }
315
316 /// Add an HTTP response to the sequence
317 pub fn http_response<F>(mut self, response_fn: F) -> Self
318 where
319 F: Fn() -> HttpResponse + Send + Sync + 'static,
320 {
321 let generator = Arc::new(move || MockResponse::Http(response_fn()));
322 self.generators.push((generator, 1));
323 self
324 }
325
326 /// Repeat the last added response multiple times.
327 ///
328 /// This method sets the number of times the last response in the sequence will be used.
329 /// For example, if you add a response and then call `times(3)`, that response will be
330 /// returned for the next 3 calls to the rule.
331 ///
332 /// # Examples
333 ///
334 /// ```rust,ignore
335 /// // Create a rule that returns 503 twice, then succeeds
336 /// let rule = mock!(Client::get_object)
337 /// .sequence()
338 /// .http_status(503, None)
339 /// .times(2) // First two calls return 503
340 /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
341 /// .build();
342 /// ```
343 ///
344 /// # Panics
345 ///
346 /// Panics if:
347 /// - Called with a count of 0
348 /// - Called before adding any responses to the sequence
349 pub fn times(mut self, count: usize) -> Self {
350 if self.generators.is_empty() {
351 panic!("times(n) called before adding a response to the sequence");
352 }
353 match count {
354 0 => panic!("repeat count must be greater than zero"),
355 1 => {
356 return self;
357 }
358 _ => {}
359 }
360
361 // update the repeat count of the last generator
362 if let Some(last_generator) = self.generators.last_mut() {
363 last_generator.1 = count;
364 }
365 self
366 }
367 /// Make the last response in the sequence repeat indefinitely.
368 ///
369 /// This method causes the last response added to the sequence to be repeated
370 /// forever, making the rule never exhaust. After calling `repeatedly()`,
371 /// no more responses can be added to the sequence.
372 ///
373 /// # Examples
374 ///
375 /// ```rust,ignore
376 /// // Create a rule that returns an error once, then succeeds forever
377 /// let rule = mock!(Client::get_object)
378 /// .sequence()
379 /// .error(|| GetObjectError::NoSuchKey(NoSuchKey::builder().build()))
380 /// .output(|| GetObjectOutput::builder().build())
381 /// .repeatedly()
382 /// .build();
383 ///
384 /// // First call will return NoSuchKey error
385 /// // All subsequent calls will return success
386 /// ```
387 ///
388 /// # Panics
389 ///
390 /// Panics if called before adding any responses to the sequence.
391 pub fn repeatedly(self) -> FinalizedResponseSequenceBuilder<I, O, E> {
392 if self.generators.is_empty() {
393 panic!("repeatedly() called before adding a response to the sequence");
394 }
395 let inner = self.times(usize::MAX);
396 FinalizedResponseSequenceBuilder { inner }
397 }
398
399 /// Build this a "simple" rule (internal detail)
400 pub(crate) fn build_simple(mut self) -> Rule {
401 self.is_simple = true;
402 self.repeatedly().build()
403 }
404
405 /// Build the rule with this response sequence
406 pub fn build(self) -> Rule {
407 let generators = self.generators;
408 let is_simple = self.is_simple;
409
410 // calculate total responses (sum of all repetitions)
411 let total_responses: usize = generators
412 .iter()
413 .map(|(_, count)| *count)
414 .fold(0, |acc, count| acc.saturating_add(count));
415
416 Rule::new(
417 self.input_filter,
418 Arc::new(move |idx| {
419 // find which generator to use
420 let mut current_idx = idx;
421 for (generator, repeat_count) in &generators {
422 if current_idx < *repeat_count {
423 return Some(generator());
424 }
425 current_idx -= repeat_count;
426 }
427 None
428 }),
429 total_responses,
430 is_simple,
431 )
432 }
433}
434
435impl<I, O, E> FinalizedResponseSequenceBuilder<I, O, E>
436where
437 I: fmt::Debug + Send + Sync + 'static,
438 O: fmt::Debug + Send + Sync + 'static,
439 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
440{
441 /// Build the rule with this response sequence
442 pub fn build(self) -> Rule {
443 self.inner.build()
444 }
445}