1use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17 AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
20use aws_smithy_runtime_api::client::http::HttpClient;
21use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
22use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
23use aws_smithy_runtime_api::client::interceptors::Intercept;
24use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
25use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, Metadata};
26use aws_smithy_runtime_api::client::result::SdkError;
27use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
28use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
29use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
30use aws_smithy_runtime_api::client::runtime_plugin::{
31 RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
32};
33use aws_smithy_runtime_api::client::ser_de::{
34 DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
35};
36use aws_smithy_runtime_api::shared::IntoShared;
37use aws_smithy_runtime_api::{
38 box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
39};
40use aws_smithy_types::config_bag::{ConfigBag, Layer};
41use aws_smithy_types::retry::RetryConfig;
42use aws_smithy_types::timeout::TimeoutConfig;
43use std::borrow::Cow;
44use std::fmt;
45use std::marker::PhantomData;
46use tracing::{debug_span, Instrument};
47
48struct FnSerializer<F, I> {
49 f: F,
50 _phantom: PhantomData<I>,
51}
52impl<F, I> FnSerializer<F, I> {
53 fn new(f: F) -> Self {
54 Self {
55 f,
56 _phantom: Default::default(),
57 }
58 }
59}
60impl<F, I> SerializeRequest for FnSerializer<F, I>
61where
62 F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
63 I: fmt::Debug + Send + Sync + 'static,
64{
65 fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
66 let input: I = input.downcast().expect("correct type");
67 (self.f)(input)
68 }
69}
70impl<F, I> fmt::Debug for FnSerializer<F, I> {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 write!(f, "FnSerializer")
73 }
74}
75
76struct FnDeserializer<F, O, E> {
77 f: F,
78 _phantom: PhantomData<(O, E)>,
79}
80impl<F, O, E> FnDeserializer<F, O, E> {
81 fn new(deserializer: F) -> Self {
82 Self {
83 f: deserializer,
84 _phantom: Default::default(),
85 }
86 }
87}
88impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
89where
90 F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
91 O: fmt::Debug + Send + Sync + 'static,
92 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
93{
94 fn deserialize_nonstreaming(
95 &self,
96 response: &HttpResponse,
97 ) -> Result<Output, OrchestratorError<Error>> {
98 (self.f)(response)
99 .map(|output| Output::erase(output))
100 .map_err(|err| err.map_operation_error(Error::erase))
101 }
102}
103impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 write!(f, "FnDeserializer")
106 }
107}
108
109#[derive(Debug)]
111pub struct Operation<I, O, E> {
112 service_name: Cow<'static, str>,
113 operation_name: Cow<'static, str>,
114 runtime_plugins: RuntimePlugins,
115 _phantom: PhantomData<(I, O, E)>,
116}
117
118impl<I, O, E> Clone for Operation<I, O, E> {
120 fn clone(&self) -> Self {
121 Self {
122 service_name: self.service_name.clone(),
123 operation_name: self.operation_name.clone(),
124 runtime_plugins: self.runtime_plugins.clone(),
125 _phantom: self._phantom,
126 }
127 }
128}
129
130impl Operation<(), (), ()> {
131 pub fn builder() -> OperationBuilder {
133 OperationBuilder::new()
134 }
135}
136
137impl<I, O, E> Operation<I, O, E>
138where
139 I: fmt::Debug + Send + Sync + 'static,
140 O: fmt::Debug + Send + Sync + 'static,
141 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
142{
143 pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
146 let input = Input::erase(input);
147
148 let output = super::invoke(
149 &self.service_name,
150 &self.operation_name,
151 input,
152 &self.runtime_plugins,
153 )
154 .instrument(debug_span!(
155 "invoke",
156 "rpc.service" = &self.service_name.as_ref(),
157 "rpc.method" = &self.operation_name.as_ref()
158 ))
159 .await
160 .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
161
162 Ok(output.downcast().expect("correct type"))
163 }
164}
165
166#[derive(Debug)]
168pub struct OperationBuilder<I = (), O = (), E = ()> {
169 service_name: Option<Cow<'static, str>>,
170 operation_name: Option<Cow<'static, str>>,
171 config: Layer,
172 runtime_components: RuntimeComponentsBuilder,
173 runtime_plugins: Vec<SharedRuntimePlugin>,
174 _phantom: PhantomData<(I, O, E)>,
175}
176
177impl Default for OperationBuilder<(), (), ()> {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183impl OperationBuilder<(), (), ()> {
184 pub fn new() -> Self {
186 Self {
187 service_name: None,
188 operation_name: None,
189 config: Layer::new("operation"),
190 runtime_components: RuntimeComponentsBuilder::new("operation"),
191 runtime_plugins: Vec::new(),
192 _phantom: Default::default(),
193 }
194 }
195}
196
197impl<I, O, E> OperationBuilder<I, O, E> {
198 pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
200 self.service_name = Some(service_name.into());
201 self
202 }
203
204 pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
206 self.operation_name = Some(operation_name.into());
207 self
208 }
209
210 pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
212 self.runtime_components.set_http_client(Some(connector));
213 self
214 }
215
216 pub fn endpoint_url(mut self, url: &str) -> Self {
218 self.config.store_put(EndpointResolverParams::new(()));
219 self.runtime_components
220 .set_endpoint_resolver(Some(SharedEndpointResolver::new(
221 StaticUriEndpointResolver::uri(url),
222 )));
223 self
224 }
225
226 pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
228 self.runtime_components
229 .push_retry_classifier(retry_classifier);
230 self
231 }
232
233 pub fn no_retry(mut self) -> Self {
235 self.runtime_components
236 .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
237 self
238 }
239
240 pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
242 self.config.store_put(retry_config.clone());
243 self.runtime_components
244 .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
245 self
246 }
247
248 pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
250 self.config.store_put(timeout_config);
251 self
252 }
253
254 pub fn no_auth(mut self) -> Self {
256 self.config
257 .store_put(AuthSchemeOptionResolverParams::new(()));
258 self.runtime_components
259 .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
260 StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
261 )));
262 self.runtime_components
263 .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
264 self.runtime_components
265 .set_identity_cache(Some(IdentityCache::no_cache()));
266 self.runtime_components.set_identity_resolver(
267 NO_AUTH_SCHEME_ID,
268 SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
269 );
270 self
271 }
272
273 pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
275 self.runtime_components
276 .set_sleep_impl(Some(async_sleep.into_shared()));
277 self
278 }
279
280 pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
282 self.runtime_components
283 .set_time_source(Some(time_source.into_shared()));
284 self
285 }
286
287 pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
289 self.runtime_components.push_interceptor(interceptor);
290 self
291 }
292
293 pub fn with_connection_poisoning(self) -> Self {
295 self.interceptor(ConnectionPoisoningInterceptor::new())
296 }
297
298 pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
300 self.runtime_plugins.push(runtime_plugin.into_shared());
301 self
302 }
303
304 pub fn stalled_stream_protection(
306 mut self,
307 stalled_stream_protection: StalledStreamProtectionConfig,
308 ) -> Self {
309 self.config.store_put(stalled_stream_protection);
310 self
311 }
312
313 pub fn serializer<I2>(
315 mut self,
316 serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
317 ) -> OperationBuilder<I2, O, E>
318 where
319 I2: fmt::Debug + Send + Sync + 'static,
320 {
321 self.config
322 .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
323 OperationBuilder {
324 service_name: self.service_name,
325 operation_name: self.operation_name,
326 config: self.config,
327 runtime_components: self.runtime_components,
328 runtime_plugins: self.runtime_plugins,
329 _phantom: Default::default(),
330 }
331 }
332
333 pub fn deserializer<O2, E2>(
335 mut self,
336 deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
337 + Send
338 + Sync
339 + 'static,
340 ) -> OperationBuilder<I, O2, E2>
341 where
342 O2: fmt::Debug + Send + Sync + 'static,
343 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
344 {
345 self.config
346 .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
347 deserializer,
348 )));
349 OperationBuilder {
350 service_name: self.service_name,
351 operation_name: self.operation_name,
352 config: self.config,
353 runtime_components: self.runtime_components,
354 runtime_plugins: self.runtime_plugins,
355 _phantom: Default::default(),
356 }
357 }
358
359 #[allow(clippy::implied_bounds_in_impls)] pub fn deserializer_impl<O2, E2>(
362 mut self,
363 deserializer: impl DeserializeResponse + Send + Sync + 'static,
364 ) -> OperationBuilder<I, O2, E2>
365 where
366 O2: fmt::Debug + Send + Sync + 'static,
367 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
368 {
369 let deserializer: SharedResponseDeserializer = deserializer.into_shared();
370 self.config.store_put(deserializer);
371
372 OperationBuilder {
373 service_name: self.service_name,
374 operation_name: self.operation_name,
375 config: self.config,
376 runtime_components: self.runtime_components,
377 runtime_plugins: self.runtime_plugins,
378 _phantom: Default::default(),
379 }
380 }
381
382 pub fn build(self) -> Operation<I, O, E> {
384 let service_name = self.service_name.expect("service_name required");
385 let operation_name = self.operation_name.expect("operation_name required");
386 let mut config = self.config;
387 config.store_put(Metadata::new(operation_name.clone(), service_name.clone()));
388 let mut runtime_plugins = RuntimePlugins::new()
389 .with_client_plugins(default_plugins(
390 DefaultPluginParams::new().with_retry_partition_name(service_name.clone()),
391 ))
392 .with_client_plugin(
393 StaticRuntimePlugin::new()
394 .with_config(config.freeze())
395 .with_runtime_components(self.runtime_components),
396 );
397 for runtime_plugin in self.runtime_plugins {
398 runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
399 }
400
401 #[cfg(debug_assertions)]
402 {
403 let mut config = ConfigBag::base();
404 let components = runtime_plugins
405 .apply_client_configuration(&mut config)
406 .expect("the runtime plugins should succeed");
407
408 assert!(
409 components.http_client().is_some(),
410 "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
411 );
412 assert!(
413 components.endpoint_resolver().is_some(),
414 "a endpoint_resolver is required"
415 );
416 assert!(
417 components.retry_strategy().is_some(),
418 "a retry_strategy is required"
419 );
420 assert!(
421 config.load::<SharedRequestSerializer>().is_some(),
422 "a serializer is required"
423 );
424 assert!(
425 config.load::<SharedResponseDeserializer>().is_some(),
426 "a deserializer is required"
427 );
428 assert!(
429 config.load::<EndpointResolverParams>().is_some(),
430 "endpoint resolver params are required"
431 );
432 assert!(
433 config.load::<TimeoutConfig>().is_some(),
434 "timeout config is required"
435 );
436 }
437
438 Operation {
439 service_name,
440 operation_name,
441 runtime_plugins,
442 _phantom: Default::default(),
443 }
444 }
445}
446
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    /// Body sent by both tests.
    const REQUEST_BODY: &[u8] = b"what are you?";
    /// Body of the successful (418) response.
    const TEAPOT_BODY: &[u8] = b"I'm a teapot!";

    /// Serializes a `String` input directly into the request body.
    fn string_serializer(input: String) -> Result<HttpRequest, BoxError> {
        Ok(HttpRequest::new(SdkBody::from(input.as_bytes())))
    }

    /// Decodes the response body as UTF-8, panicking if it is unavailable or invalid.
    fn body_text(response: &HttpResponse) -> String {
        std::str::from_utf8(response.body().bytes().unwrap())
            .unwrap()
            .to_string()
    }

    /// The request the replay client expects for every attempt.
    fn expected_request() -> http_1x::Request<SdkBody> {
        http_1x::Request::builder()
            .uri("http://localhost:1234/")
            .body(SdkBody::from(REQUEST_BODY))
            .unwrap()
    }

    /// A canned response with the given status code and body.
    fn canned_response(status: u16, body: &'static [u8]) -> http_1x::Response<SdkBody> {
        http_1x::Response::builder()
            .status(status)
            .body(SdkBody::from(body))
            .unwrap()
    }

    #[tokio::test]
    async fn operation() {
        let (connector, request_rx) = capture_request(Some(canned_response(418, TEAPOT_BODY)));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector)
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            .serializer(string_serializer)
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(body_text(response))
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        let request = request_rx.expect_request();
        assert_eq!("http://localhost:1234/", request.uri());
        assert_eq!(REQUEST_BODY, request.body().bytes().unwrap());
    }

    #[tokio::test]
    async fn operation_retries() {
        // The first attempt gets a 503, which the deserializer turns into a retryable
        // connector error; the second attempt succeeds with a 418.
        let connector = StaticReplayClient::new(vec![
            ReplayEvent::new(expected_request(), canned_response(503, b"")),
            ReplayEvent::new(expected_request(), canned_response(418, TEAPOT_BODY)),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector.clone())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(string_serializer)
            .deserializer::<_, Infallible>(|response| match u16::from(response.status()) {
                503 => Err(OrchestratorError::connector(ConnectorError::io(
                    "test".into(),
                ))),
                status => {
                    assert_eq!(418, status);
                    Ok(body_text(response))
                }
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        connector.assert_requests_match(&[]);
    }
}