1use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
7use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_runtime_api::http::StatusCode;
10use aws_smithy_types::body::SdkBody;
11use std::fmt;
12use std::future::Future;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::Arc;
15
/// A response that a mock rule can serve, either typed or as raw HTTP.
///
/// `O` and `E` are the operation's output and error types. [`Rule::new`] converts the typed
/// `Output`/`Error` variants into their type-erased [`Output`] / [`Error`] forms and passes
/// `Http` through unchanged.
#[derive(Debug)]
pub(crate) enum MockResponse<O, E> {
    /// A successful, typed operation output.
    Output(O),
    /// A typed operation error.
    Error(E),
    /// A raw HTTP response.
    Http(HttpResponse),
}
32
/// Predicate deciding whether a rule applies to a type-erased operation [`Input`].
type MatchFn = Arc<dyn Fn(&Input) -> bool + Send + Sync>;
/// Produces the type-erased response for a zero-based call index, or `None` when the rule
/// has no response for that index.
type ServeFn = Arc<dyn Fn(usize) -> Option<MockResponse<Output, Error>> + Send + Sync>;
36
/// A mock rule: a predicate selecting which requests it applies to, plus a handler that
/// serves one response per call.
///
/// Cloning a `Rule` shares the matcher, the handler and the call counter with the clone,
/// because all three are held behind `Arc`.
#[derive(Clone)]
pub struct Rule {
    /// Decides whether this rule applies to a given type-erased input.
    pub(crate) matcher: MatchFn,

    /// Produces the response for a zero-based call index, or `None` when no response is
    /// available for that index.
    pub(crate) response_handler: ServeFn,

    /// Number of responses requested so far. It is incremented on every request, including
    /// requests made after the rule has run out of responses.
    pub(crate) call_count: Arc<AtomicUsize>,

    /// Maximum number of responses this rule will serve.
    pub(crate) max_responses: usize,
}
55
56impl fmt::Debug for Rule {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 write!(f, "Rule")
59 }
60}
61
62impl Rule {
63 pub(crate) fn new<O, E>(
65 matcher: MatchFn,
66 response_handler: Arc<dyn Fn(usize) -> Option<MockResponse<O, E>> + Send + Sync>,
67 max_responses: usize,
68 ) -> Self
69 where
70 O: fmt::Debug + Send + Sync + 'static,
71 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
72 {
73 Rule {
74 matcher,
75 response_handler: Arc::new(move |idx: usize| {
76 if idx < max_responses {
77 response_handler(idx).map(|resp| match resp {
78 MockResponse::Output(o) => MockResponse::Output(Output::erase(o)),
79 MockResponse::Error(e) => MockResponse::Error(Error::erase(e)),
80 MockResponse::Http(http_resp) => MockResponse::Http(http_resp),
81 })
82 } else {
83 None
84 }
85 }),
86 call_count: Arc::new(AtomicUsize::new(0)),
87 max_responses,
88 }
89 }
90
91 pub(crate) fn next_response(&self) -> Option<MockResponse<Output, Error>> {
93 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
94 (self.response_handler)(idx)
95 }
96
97 pub fn num_calls(&self) -> usize {
99 self.call_count.load(Ordering::SeqCst)
100 }
101
102 pub fn is_exhausted(&self) -> bool {
104 self.num_calls() >= self.max_responses
105 }
106}
107
/// Strategy selector for a collection of mock rules.
///
/// NOTE(review): the logic that interprets this mode is not in this file. The variant
/// descriptions below are inferred from their names; confirm them against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleMode {
    /// Presumably rules are used in the order they were added, advancing once a rule is
    /// exhausted (TODO confirm).
    Sequential,
    /// Presumably any rule whose matcher accepts the request may serve it, regardless of
    /// order (TODO confirm).
    MatchAny,
}
121
/// Builder for a [`Rule`] targeting operation input type `I`, output type `O` and error
/// type `E`.
pub struct RuleBuilder<I, O, E> {
    /// Predicate over the type-erased input deciding whether the built rule applies.
    pub(crate) input_filter: MatchFn,

    /// Carries the `I`, `O`, `E` type parameters without storing values of those types.
    pub(crate) _ty: std::marker::PhantomData<(I, O, E)>,
}
133
134impl<I, O, E> RuleBuilder<I, O, E>
135where
136 I: fmt::Debug + Send + Sync + 'static,
137 O: fmt::Debug + Send + Sync + 'static,
138 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
139{
140 #[doc(hidden)]
142 pub fn new() -> Self {
143 RuleBuilder {
144 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
145 _ty: std::marker::PhantomData,
146 }
147 }
148
149 #[doc(hidden)]
151 pub fn new_from_mock<F, R>(_input_hint: impl Fn() -> I, _output_hint: impl Fn() -> F) -> Self
152 where
153 F: Future<Output = Result<O, SdkError<E, R>>>,
154 {
155 Self {
156 input_filter: Arc::new(|i: &Input| i.downcast_ref::<I>().is_some()),
157 _ty: Default::default(),
158 }
159 }
160
161 pub fn match_requests<F>(mut self, filter: F) -> Self
163 where
164 F: Fn(&I) -> bool + Send + Sync + 'static,
165 {
166 self.input_filter = Arc::new(move |i: &Input| match i.downcast_ref::<I>() {
167 Some(typed_input) => filter(typed_input),
168 _ => false,
169 });
170 self
171 }
172
173 pub fn sequence(self) -> ResponseSequenceBuilder<I, O, E> {
191 ResponseSequenceBuilder::new(self.input_filter)
192 }
193
194 pub fn then_output<F>(self, output_fn: F) -> Rule
196 where
197 F: Fn() -> O + Send + Sync + 'static,
198 {
199 self.sequence().output(output_fn).build()
200 }
201
202 pub fn then_error<F>(self, error_fn: F) -> Rule
204 where
205 F: Fn() -> E + Send + Sync + 'static,
206 {
207 self.sequence().error(error_fn).build()
208 }
209
210 pub fn then_http_response<F>(self, response_fn: F) -> Rule
212 where
213 F: Fn() -> HttpResponse + Send + Sync + 'static,
214 {
215 self.sequence().http_response(response_fn).build()
216 }
217}
218
/// Produces one response of a sequence; it is invoked each time that position is served.
type SequenceGeneratorFn<O, E> = Arc<dyn Fn() -> MockResponse<O, E> + Send + Sync>;
220
/// Builder for a [`Rule`] that serves an ordered sequence of responses, one per call.
pub struct ResponseSequenceBuilder<I, O, E> {
    /// Response generators in serving order. Entry `n` produces the response for the
    /// zero-based `n`-th call.
    generators: Vec<SequenceGeneratorFn<O, E>>,

    /// Predicate deciding whether the built rule applies to a request.
    input_filter: MatchFn,

    /// Ties the builder to input type `I` without storing a value of it.
    _marker: std::marker::PhantomData<I>,
}
232
233impl<I, O, E> ResponseSequenceBuilder<I, O, E>
234where
235 I: fmt::Debug + Send + Sync + 'static,
236 O: fmt::Debug + Send + Sync + 'static,
237 E: fmt::Debug + Send + Sync + std::error::Error + 'static,
238{
239 pub(crate) fn new(input_filter: MatchFn) -> Self {
241 Self {
242 generators: Vec::new(),
243 input_filter,
244 _marker: std::marker::PhantomData,
245 }
246 }
247
248 pub fn output<F>(mut self, output_fn: F) -> Self
259 where
260 F: Fn() -> O + Send + Sync + 'static,
261 {
262 let generator = Arc::new(move || MockResponse::Output(output_fn()));
263 self.generators.push(generator);
264 self
265 }
266
267 pub fn error<F>(mut self, error_fn: F) -> Self
269 where
270 F: Fn() -> E + Send + Sync + 'static,
271 {
272 let generator = Arc::new(move || MockResponse::Error(error_fn()));
273 self.generators.push(generator);
274 self
275 }
276
277 pub fn http_status(mut self, status: u16, body: Option<String>) -> Self {
279 let status_code = StatusCode::try_from(status).unwrap();
280
281 let generator: SequenceGeneratorFn<O, E> = match body {
282 Some(body) => Arc::new(move || {
283 MockResponse::Http(HttpResponse::new(status_code, SdkBody::from(body.clone())))
284 }),
285 None => Arc::new(move || {
286 MockResponse::Http(HttpResponse::new(status_code, SdkBody::empty()))
287 }),
288 };
289
290 self.generators.push(generator);
291 self
292 }
293
294 pub fn http_response<F>(mut self, response_fn: F) -> Self
296 where
297 F: Fn() -> HttpResponse + Send + Sync + 'static,
298 {
299 let generator = Arc::new(move || MockResponse::Http(response_fn()));
300 self.generators.push(generator);
301 self
302 }
303
304 pub fn times(mut self, count: usize) -> Self {
308 match count {
309 0 => panic!("repeat count must be greater than zero"),
310 1 => {
311 return self;
312 }
313 _ => {}
314 }
315
316 if let Some(last_generator) = self.generators.last().cloned() {
317 for _ in 1..count {
319 self.generators.push(last_generator.clone());
320 }
321 }
322 self
323 }
324
325 pub fn build(self) -> Rule {
327 let generators = self.generators;
328 let count = generators.len();
329
330 Rule::new(
331 self.input_filter,
332 Arc::new(move |idx| {
333 if idx < count {
334 Some(generators[idx]())
335 } else {
336 None
337 }
338 }),
339 count,
340 )
341 }
342}