1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
|
7 + | use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
|
8 + | use aws_smithy_runtime_api::client::result::SdkError;
|
9 + | use aws_smithy_runtime_api::http::StatusCode;
|
10 + | use aws_smithy_types::body::SdkBody;
|
11 + | use std::fmt;
|
12 + | use std::future::Future;
|
13 + | use std::sync::atomic::{AtomicUsize, Ordering};
|
14 + | use std::sync::Arc;
|
15 + |
|
16 + | /// A mock response that can be returned by a rule.
|
17 + | ///
|
18 + | /// This enum represents the different types of responses that can be returned by a mock rule:
|
19 + | /// - `Output`: A successful modeled response
|
20 + | /// - `Error`: A modeled error
|
21 + | /// - `Http`: An HTTP response
|
22 + | ///
|
23 + | #[derive(Debug)]
|
24 + | pub(crate) enum MockResponse<O, E> {
|
25 + | /// A successful modeled response.
|
26 + | Output(O),
|
27 + | /// A modeled error.
|
28 + | Error(E),
|
29 + | /// An HTTP response.
|
30 + | Http(HttpResponse),
|
31 + | }
|
32 + |
|
33 + | /// A function that matches requests.
|
34 + | type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
|
35 + | type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
|
36 + |
|
37 + | /// A rule for matching requests and providing mock responses.
|
38 + | ///
|
39 + | /// Rules are created using the `mock!` macro or the `RuleBuilder`.
|
40 + | ///
|
41 + | #[derive(Clone)]
|
42 + | pub struct Rule {
|
43 + | /// Function that determines if this rule matches a request.
|
44 + | pub(crate) matcher: MatchFn,
|
45 + |
|
46 + | /// Handler function that generates responses.
|
47 + | pub(crate) response_handler: ServeFn,
|
48 + |
|
49 + | /// Number of times this rule has been called.
|
50 + | pub(crate) call_count: Arc<AtomicUsize>,
|
51 + |
|
52 + | /// Maximum number of responses this rule will provide.
|
53 + | pub(crate) max_responses: usize,
|
54 + | }
|
55 + |
|
56 + | impl fmt::Debug for Rule {
|
57 + | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
58 + | write!(f, "Rule")
|
59 + | }
|
60 + | }
|
61 + |
|
62 + | impl Rule {
|
63 + | /// Creates a new rule with the given matcher, response handler, and max responses.
|
64 + | pub(crate) fn new<O, E>(
|
65 + | matcher: MatchFn,
|
66 + | response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
|
67 + | max_responses: usize,
|
68 + | ) -> Self
|
69 + | where
|
70 + | O: fmt::Debug + Send + Sync + 'static,
|
71 + | E: fmt::Debug + Send + Sync + std::error::Error + 'static,
|
72 + | {
|
73 + | Rule {
|
74 + | matcher,
|
75 + | response_handler: Arc::new(move |idx: usize| {
|
76 + | if idx < max_responses {
|
77 + | response_handler(idx).map(|resp| match resp {
|
78 + | MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
|
79 + | MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
|
80 + | MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
|
81 + | })
|
82 + | } else {
|
83 + | None
|
84 + | }
|
85 + | }),
|
86 + | call_count: Arc::new(AtomicUsize::new(0)),
|
87 + | max_responses,
|
88 + | }
|
89 + | }
|
90 + |
|
91 + | /// Gets the next response.
|
92 + | pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
|
93 + | let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
|
94 + | (self.response_handler)(idx)
|
95 + | }
|
96 + |
|
97 + | /// Returns the number of times this rule has been called.
|
98 + | pub fn num_calls(&self) -> usize {
|
99 + | self.call_count.load(Ordering::SeqCst)
|
100 + | }
|
101 + |
|
102 + | /// Checks if this rule is exhausted (has provided all its responses).
|
103 + | pub fn is_exhausted(&self) -> bool {
|
104 + | self.num_calls() >= self.max_responses
|
105 + | }
|
106 + | }
|
107 + |
|
/// `RuleMode` describes how a set of rules is interpreted when matching requests.
///
/// - [`RuleMode::Sequential`]: the first matching rule is applied, and that rule is removed
///   from the list of rules **once it is exhausted**.
/// - [`RuleMode::MatchAny`]: the first matching rule is applied, and the list of rules
///   remains unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Match rules in the order they were added. The first matching rule is applied, and that
    /// rule is removed from the list of rules **once it is exhausted**. Each rule can have
    /// multiple responses, and all responses in a rule are consumed before moving to the next
    /// rule.
    Sequential,
    /// The first matching rule is applied, and the list of rules remains unchanged.
    MatchAny,
}
|
121 + |
|
122 + | /// A builder for creating rules.
|
123 + | ///
|
124 + | /// This builder provides a fluent API for creating rules with different response types.
|
125 + | ///
|
126 + | pub struct RuleBuilder<I, O, E> {
|
127 + | /// Function that determines if this rule matches a request.
|
128 + | pub(crate) input_filter: MatchFn,
|
129 + |
|
130 + | /// Phantom data for the input type.
|
131 + | pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
|
132 + | }
|
133 + |
|
134 + | impl<I, O, E> RuleBuilder<I, O, E>
|
135 + | where
|
136 + | I: fmt::Debug + Send + Sync + 'static,
|
137 + | O: fmt::Debug + Send + Sync + 'static,
|
138 + | E: fmt::Debug + Send + Sync + std::error::Error + 'static,
|
139 + | {
|
140 + | /// Creates a new [`RuleBuilder`]
|
141 + | #[doc(hidden)]
|
142 + | pub fn new() -> Self {
|
143 + | RuleBuilder {
|
144 + | input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
|
145 + | _ty: std::marker::PhantomData,
|
146 + | }
|
147 + | }
|
148 + |
|
149 + | /// Creates a new [`RuleBuilder`]. This is normally constructed with the [`mock!`] macro
|
150 + | #[doc(hidden)]
|
151 + | pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
|
152 + | where
|
153 + | F: Future<Output = Result<O, SdkError<E, R>>>,
|
154 + | {
|
155 + | Self {
|
156 + | input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
|
157 + | _ty: Default::default(),
|
158 + | }
|
159 + | }
|
160 + |
|
161 + | /// Sets the function that determines if this rule matches a request.
|
162 + | pub fn match_requests<F>(mut self, filter: F) -> Self
|
163 + | where
|
164 + | F: Fn(&I) -> bool + Send + Sync + 'static,
|
165 + | {
|
166 + | self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
|
167 + | Some(typed_input) => filter(typed_input),
|
168 + | _ => false,
|
169 + | });
|
170 + | self
|
171 + | }
|
172 + |
|
173 + | /// Start building a response sequence
|
174 + | ///
|
175 + | /// A sequence allows a single rule to generate multiple responses which can
|
176 + | /// be used to test retry behavior.
|
177 + | ///
|
178 + | /// # Examples
|
179 + | ///
|
180 + | /// With repetition using `times()`:
|
181 + | ///
|
182 + | /// ```rust,ignore
|
183 + | /// let rule = mock!(Client::get_object)
|
184 + | /// .sequence()
|
185 + | /// .http_status(503, None)
|
186 + | /// .times(2) // First two calls return 503
|
187 + | /// .output(|| GetObjectOutput::builder().build()) // Third call succeeds
|
188 + | /// .build();
|
189 + | /// ```
|
190 + | pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
|
191 + | ResponseSequenceBuilder::new(self.input_filter)
|
192 + | }
|
193 + |
|
194 + | /// Creates a rule that returns a modeled output.
|
195 + | pub fn then_output<F>(self, output_fn: F) -> Rule
|
196 + | where
|
197 + | F: Fn() -> O + Send + Sync + 'static,
|
198 + | {
|
199 + | self.sequence().output(output_fn).build()
|
200 + | }
|
201 + |
|
202 + | /// Creates a rule that returns a modeled error.
|
203 + | pub fn then_error<F>(self, error_fn: F) -> Rule
|
204 + | where
|
205 + | F: Fn() -> E + Send + Sync + 'static,
|
206 + | {
|
207 + | self.sequence().error(error_fn).build()
|
208 + | }
|
209 + |
|
210 + | /// Creates a rule that returns an HTTP response.
|
211 + | pub fn then_http_response<F>(self, response_fn: F) -> Rule
|
212 + | where
|
213 + | F: Fn() -> HttpResponse + Send + Sync + 'static,
|
214 + | {
|
215 + | self.sequence().http_response(response_fn).build()
|
216 + | }
|
217 + | }
|
218 + |
|
219 + | type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
|
220 + |
|
221 + | /// A builder for creating response sequences
|
222 + | pub struct ResponseSequenceBuilder<I, O, E> {
|
223 + | /// The response generators in the sequence
|
224 + | generators: Vec<SequenceGeneratorFn<O, E>>,
|
225 + |
|
226 + | /// Function that determines if this rule matches a request
|
227 + | input_filter: MatchFn,
|
228 + |
|
229 + | /// Marker for the input, output, and error types
|
230 + | _marker: std::marker::PhantomData<I>,
|
231 + | }
|
232 + |
|
233 + | impl<I, O, E> ResponseSequenceBuilder<I, O, E>
|
234 + | where
|
235 + | I: fmt::Debug + Send + Sync + 'static,
|
236 + | O: fmt::Debug + Send + Sync + 'static,
|
237 + | E: fmt::Debug + Send + Sync + std::error::Error + 'static,
|
238 + | {
|
239 + | /// Create a new response sequence builder
|
240 + | pub(crate) fn new(input_filter: MatchFn) -> Self {
|
241 + | Self {
|
242 + | generators: Vec::new(),
|
243 + | input_filter,
|
244 + | _marker: std::marker::PhantomData,
|
245 + | }
|
246 + | }
|
247 + |
|
248 + | /// Add a modeled output response to the sequence
|
249 + | ///
|
250 + | /// # Examples
|
251 + | ///
|
252 + | /// ```rust,ignore
|
253 + | /// let rule = mock!(Client::get_object)
|
254 + | /// .sequence()
|
255 + | /// .output(|| GetObjectOutput::builder().build())
|
256 + | /// .build();
|
257 + | /// ```
|
258 + | pub fn output<F>(mut self, output_fn: F) -> Self
|
259 + | where
|
260 + | F: Fn() -> O + Send + Sync + 'static,
|
261 + | {
|
262 + | let generator = Arc::new(move || MockResponse::Output(output_fn()));
|
263 + | self.generators.push(generator);
|
264 + | self
|
265 + | }
|
266 + |
|
267 + | /// Add a modeled error response to the sequence
|
268 + | pub fn error<F>(mut self, error_fn: F) -> Self
|
269 + | where
|
270 + | F: Fn() -> E + Send + Sync + 'static,
|
271 + | {
|
272 + | let generator = Arc::new(move || MockResponse::Error(error_fn()));
|
273 + | self.generators.push(generator);
|
274 + | self
|
275 + | }
|
276 + |
|
277 + | /// Add an HTTP status code response to the sequence
|
278 + | pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
|
279 + | let status_code = StatusCode::try_from(status).unwrap();
|
280 + |
|
281 + | let generator: SequenceGeneratorFn<O, E> = match body {
|
282 + | Some(body) => Arc::new(move || {
|
283 + | MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
|
284 + | }),
|
285 + | None => Arc::new(move || {
|
286 + | MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
|
287 + | }),
|
288 + | };
|
289 + |
|
290 + | self.generators.push(generator);
|
291 + | self
|
292 + | }
|
293 + |
|
294 + | /// Add an HTTP response to the sequence
|
295 + | pub fn http_response<F>(mut self, response_fn: F) -> Self
|
296 + | where
|
297 + | F: Fn() -> HttpResponse + Send + Sync + 'static,
|
298 + | {
|
299 + | let generator = Arc::new(move || MockResponse::Http(response_fn()));
|
300 + | self.generators.push(generator);
|
301 + | self
|
302 + | }
|
303 + |
|
304 + | /// Repeat the last added response multiple times (total count)
|
305 + | pub fn times(mut self, count: usize) -> Self {
|
306 + | if count <= 1 {
|
307 + | return self;
|
308 + | }
|
309 + |
|
310 + | if let Some(last_generator) = self.generators.last().cloned() {
|
311 + | // Add count-1 more copies (we already have one)
|
312 + | for _ in 1..count {
|
313 + | self.generators.push(last_generator.clone());
|
314 + | }
|
315 + | }
|
316 + | self
|
317 + | }
|
318 + |
|
319 + | /// Build the rule with this response sequence
|
320 + | pub fn build(self) -> Rule {
|
321 + | let generators = self.generators;
|
322 + | let count = generators.len();
|
323 + |
|
324 + | Rule::new(
|
325 + | self.input_filter,
|
326 + | Arc::new(move |idx| {
|
327 + | if idx < count {
|
328 + | Some(generators[idx]())
|
329 + | } else {
|
330 + | None
|
331 + | }
|
332 + | }),
|
333 + | count,
|
334 + | )
|
335 + | }
|
336 + | }
|