1use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17 AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
20use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
21use aws_smithy_runtime_api::client::http::HttpClient;
22use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
23use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
24use aws_smithy_runtime_api::client::interceptors::Intercept;
25use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
26use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, Metadata};
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
29use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
30use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
31use aws_smithy_runtime_api::client::runtime_plugin::{
32 RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
33};
34use aws_smithy_runtime_api::client::ser_de::{
35 DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
36};
37use aws_smithy_runtime_api::shared::IntoShared;
38use aws_smithy_runtime_api::{
39 box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
40};
41use aws_smithy_types::config_bag::{ConfigBag, Layer};
42use aws_smithy_types::retry::RetryConfig;
43use aws_smithy_types::timeout::TimeoutConfig;
44use std::borrow::Cow;
45use std::fmt;
46use std::marker::PhantomData;
47use tracing::{debug_span, Instrument};
48
/// Adapts a closure of shape `Fn(I) -> Result<HttpRequest, BoxError>` so it can act as a
/// request serializer.
///
/// Only the closure is stored. The concrete input type `I` lives purely at the type level
/// (through `PhantomData`) so the `SerializeRequest` impl knows which type to downcast the
/// erased input into.
struct FnSerializer<F, I> {
    f: F,
    _phantom: PhantomData<I>,
}

impl<F, I> FnSerializer<F, I> {
    /// Wraps `f`; no other work is performed.
    fn new(f: F) -> Self {
        FnSerializer {
            _phantom: PhantomData,
            f,
        }
    }
}
61impl<F, I> SerializeRequest for FnSerializer<F, I>
62where
63 F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
64 I: fmt::Debug + Send + Sync + 'static,
65{
66 fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
67 let input: I = input.downcast().expect("correct type");
68 (self.f)(input)
69 }
70}
71impl<F, I> fmt::Debug for FnSerializer<F, I> {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 write!(f, "FnSerializer")
74 }
75}
76
/// Adapts a closure of shape `Fn(&HttpResponse) -> Result<O, OrchestratorError<E>>` so it
/// can act as a response deserializer.
///
/// The output type `O` and error type `E` are carried only at the type level; the closure
/// is the sole stored value.
struct FnDeserializer<F, O, E> {
    f: F,
    _phantom: PhantomData<(O, E)>,
}

impl<F, O, E> FnDeserializer<F, O, E> {
    /// Wraps `deserializer`; no other work is performed.
    fn new(deserializer: F) -> Self {
        let _phantom = PhantomData;
        FnDeserializer {
            f: deserializer,
            _phantom,
        }
    }
}
89impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
90where
91 F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
92 O: fmt::Debug + Send + Sync + 'static,
93 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
94{
95 fn deserialize_nonstreaming(
96 &self,
97 response: &HttpResponse,
98 ) -> Result<Output, OrchestratorError<Error>> {
99 (self.f)(response)
100 .map(|output| Output::erase(output))
101 .map_err(|err| err.map_operation_error(Error::erase))
102 }
103}
104impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 write!(f, "FnDeserializer")
107 }
108}
109
/// A fully configured operation, produced by [`OperationBuilder::build`].
///
/// `I`, `O`, and `E` are the typed input, output, and operation-error types. They are not
/// stored; `invoke` uses them to erase the input and to downcast the erased output and
/// service error. `invoke` takes `&self`, so one `Operation` can be invoked many times.
#[derive(Debug)]
pub struct Operation<I, O, E> {
    // Passed to the orchestrator and recorded on the `invoke` tracing span.
    service_name: Cow<'static, str>,
    operation_name: Cow<'static, str>,
    // Plugins assembled in `OperationBuilder::build`; supplied to every invocation.
    runtime_plugins: RuntimePlugins,
    _phantom: PhantomData<(I, O, E)>,
}
118
119impl<I, O, E> Clone for Operation<I, O, E> {
121 fn clone(&self) -> Self {
122 Self {
123 service_name: self.service_name.clone(),
124 operation_name: self.operation_name.clone(),
125 runtime_plugins: self.runtime_plugins.clone(),
126 _phantom: self._phantom,
127 }
128 }
129}
130
131impl Operation<(), (), ()> {
132 pub fn builder() -> OperationBuilder {
134 OperationBuilder::new()
135 }
136}
137
138impl<I, O, E> Operation<I, O, E>
139where
140 I: fmt::Debug + Send + Sync + 'static,
141 O: fmt::Debug + Send + Sync + 'static,
142 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
143{
144 pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
147 let input = Input::erase(input);
148
149 let output = super::invoke(
150 &self.service_name,
151 &self.operation_name,
152 input,
153 &self.runtime_plugins,
154 )
155 .instrument(debug_span!(
156 "invoke",
157 "rpc.service" = &self.service_name.as_ref(),
158 "rpc.method" = &self.operation_name.as_ref()
159 ))
160 .await
161 .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
162
163 Ok(output.downcast().expect("correct type"))
164 }
165}
166
/// Builder for [`Operation`].
///
/// The type parameters start as `()`. `serializer` replaces `I`, and `deserializer` or
/// `deserializer_impl` replaces `O` and `E`.
#[derive(Debug)]
pub struct OperationBuilder<I = (), O = (), E = ()> {
    // Required by `build`.
    service_name: Option<Cow<'static, str>>,
    // Required by `build`.
    operation_name: Option<Cow<'static, str>>,
    // When set, forwarded to the default plugin params in `build`.
    behavior_version: Option<BehaviorVersion>,
    // Operation-scoped config (serializer, deserializer, endpoint/auth params, timeouts, ...).
    config: Layer,
    // Operation-scoped components (HTTP client, endpoint resolver, retry strategy, ...).
    runtime_components: RuntimeComponentsBuilder,
    // User-supplied plugins; `build` adds them after the default and operation plugins.
    runtime_plugins: Vec<SharedRuntimePlugin>,
    _phantom: PhantomData<(I, O, E)>,
}
178
179impl Default for OperationBuilder<(), (), ()> {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185impl OperationBuilder<(), (), ()> {
186 pub fn new() -> Self {
188 Self {
189 service_name: None,
190 operation_name: None,
191 behavior_version: None,
192 config: Layer::new("operation"),
193 runtime_components: RuntimeComponentsBuilder::new("operation"),
194 runtime_plugins: Vec::new(),
195 _phantom: Default::default(),
196 }
197 }
198}
199
200impl<I, O, E> OperationBuilder<I, O, E> {
201 pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
203 self.service_name = Some(service_name.into());
204 self
205 }
206
207 pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
209 self.operation_name = Some(operation_name.into());
210 self
211 }
212
213 pub fn behavior_version(mut self, behavior_version: BehaviorVersion) -> Self {
215 self.behavior_version = Some(behavior_version);
216 self
217 }
218
219 pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
221 self.runtime_components.set_http_client(Some(connector));
222 self
223 }
224
225 pub fn endpoint_url(mut self, url: &str) -> Self {
227 self.config.store_put(EndpointResolverParams::new(()));
228 self.runtime_components
229 .set_endpoint_resolver(Some(SharedEndpointResolver::new(
230 StaticUriEndpointResolver::uri(url),
231 )));
232 self
233 }
234
235 pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
237 self.runtime_components
238 .push_retry_classifier(retry_classifier);
239 self
240 }
241
242 pub fn no_retry(mut self) -> Self {
244 self.runtime_components
245 .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
246 self
247 }
248
249 pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
251 self.config.store_put(retry_config.clone());
252 self.runtime_components
253 .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
254 self
255 }
256
257 pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
259 self.config.store_put(timeout_config);
260 self
261 }
262
263 pub fn no_auth(mut self) -> Self {
265 self.config
266 .store_put(AuthSchemeOptionResolverParams::new(()));
267 self.runtime_components
268 .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
269 StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
270 )));
271 self.runtime_components
272 .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
273 self.runtime_components
274 .set_identity_cache(Some(IdentityCache::no_cache()));
275 self.runtime_components.set_identity_resolver(
276 NO_AUTH_SCHEME_ID,
277 SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
278 );
279 self
280 }
281
282 pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
284 self.runtime_components
285 .set_sleep_impl(Some(async_sleep.into_shared()));
286 self
287 }
288
289 pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
291 self.runtime_components
292 .set_time_source(Some(time_source.into_shared()));
293 self
294 }
295
296 pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
298 self.runtime_components.push_interceptor(interceptor);
299 self
300 }
301
302 pub fn with_connection_poisoning(self) -> Self {
304 self.interceptor(ConnectionPoisoningInterceptor::new())
305 }
306
307 pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
309 self.runtime_plugins.push(runtime_plugin.into_shared());
310 self
311 }
312
313 pub fn stalled_stream_protection(
315 mut self,
316 stalled_stream_protection: StalledStreamProtectionConfig,
317 ) -> Self {
318 self.config.store_put(stalled_stream_protection);
319 self
320 }
321
322 pub fn serializer<I2>(
324 mut self,
325 serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
326 ) -> OperationBuilder<I2, O, E>
327 where
328 I2: fmt::Debug + Send + Sync + 'static,
329 {
330 self.config
331 .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
332 OperationBuilder {
333 service_name: self.service_name,
334 operation_name: self.operation_name,
335 behavior_version: self.behavior_version,
336 config: self.config,
337 runtime_components: self.runtime_components,
338 runtime_plugins: self.runtime_plugins,
339 _phantom: Default::default(),
340 }
341 }
342
343 pub fn deserializer<O2, E2>(
345 mut self,
346 deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
347 + Send
348 + Sync
349 + 'static,
350 ) -> OperationBuilder<I, O2, E2>
351 where
352 O2: fmt::Debug + Send + Sync + 'static,
353 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
354 {
355 self.config
356 .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
357 deserializer,
358 )));
359 OperationBuilder {
360 service_name: self.service_name,
361 operation_name: self.operation_name,
362 behavior_version: self.behavior_version,
363 config: self.config,
364 runtime_components: self.runtime_components,
365 runtime_plugins: self.runtime_plugins,
366 _phantom: Default::default(),
367 }
368 }
369
370 #[allow(clippy::implied_bounds_in_impls)] pub fn deserializer_impl<O2, E2>(
373 mut self,
374 deserializer: impl DeserializeResponse + Send + Sync + 'static,
375 ) -> OperationBuilder<I, O2, E2>
376 where
377 O2: fmt::Debug + Send + Sync + 'static,
378 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
379 {
380 let deserializer: SharedResponseDeserializer = deserializer.into_shared();
381 self.config.store_put(deserializer);
382
383 OperationBuilder {
384 service_name: self.service_name,
385 operation_name: self.operation_name,
386 behavior_version: self.behavior_version,
387 config: self.config,
388 runtime_components: self.runtime_components,
389 runtime_plugins: self.runtime_plugins,
390 _phantom: Default::default(),
391 }
392 }
393
394 pub fn build(self) -> Operation<I, O, E> {
396 let service_name = self.service_name.expect("service_name required");
397 let operation_name = self.operation_name.expect("operation_name required");
398 let mut config = self.config;
399 config.store_put(Metadata::new(operation_name.clone(), service_name.clone()));
400 let mut runtime_plugins = RuntimePlugins::new()
401 .with_client_plugins(default_plugins({
402 let mut params =
403 DefaultPluginParams::new().with_retry_partition_name(service_name.clone());
404 if let Some(bv) = self.behavior_version {
405 params = params.with_behavior_version(bv);
406 }
407 params
408 }))
409 .with_client_plugin(
410 StaticRuntimePlugin::new()
411 .with_config(config.freeze())
412 .with_runtime_components(self.runtime_components),
413 );
414 for runtime_plugin in self.runtime_plugins {
415 runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
416 }
417
418 #[cfg(debug_assertions)]
419 {
420 let mut config = ConfigBag::base();
421 let components = runtime_plugins
422 .apply_client_configuration(&mut config)
423 .expect("the runtime plugins should succeed");
424
425 assert!(
426 components.http_client().is_some(),
427 "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
428 );
429 assert!(
430 components.endpoint_resolver().is_some(),
431 "a endpoint_resolver is required"
432 );
433 assert!(
434 components.retry_strategy().is_some(),
435 "a retry_strategy is required"
436 );
437 assert!(
438 config.load::<SharedRequestSerializer>().is_some(),
439 "a serializer is required"
440 );
441 assert!(
442 config.load::<SharedResponseDeserializer>().is_some(),
443 "a deserializer is required"
444 );
445 assert!(
446 config.load::<EndpointResolverParams>().is_some(),
447 "endpoint resolver params are required"
448 );
449 assert!(
450 config.load::<TimeoutConfig>().is_some(),
451 "timeout config is required"
452 );
453 }
454
455 Operation {
456 service_name,
457 operation_name,
458 runtime_plugins,
459 _phantom: Default::default(),
460 }
461 }
462}
463
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    /// Serializer shared by both tests: the `String` input becomes the request body verbatim.
    fn string_body_serializer(input: String) -> Result<HttpRequest, BoxError> {
        Ok(HttpRequest::new(SdkBody::from(input.as_bytes())))
    }

    /// Decodes the response body as UTF-8. Panics if the body bytes are unavailable or are
    /// not valid UTF-8.
    fn body_as_string(response: &HttpResponse) -> String {
        let bytes = response.body().bytes().unwrap();
        std::str::from_utf8(bytes).unwrap().to_string()
    }

    /// A replay event that expects the test request and answers with `status` and `body`.
    fn replay_event(status: u16, body: &'static [u8]) -> ReplayEvent {
        let expected_request = http_1x::Request::builder()
            .uri("http://localhost:1234/")
            .body(SdkBody::from(&b"what are you?"[..]))
            .unwrap();
        let response = http_1x::Response::builder()
            .status(status)
            .body(SdkBody::from(body))
            .unwrap();
        ReplayEvent::new(expected_request, response)
    }

    #[tokio::test]
    async fn operation() {
        let teapot = http_1x::Response::builder()
            .status(418)
            .body(SdkBody::from(&b"I'm a teapot!"[..]))
            .unwrap();
        let (connector, request_rx) = capture_request(Some(teapot));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector)
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            .serializer(string_body_serializer)
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(body_as_string(response))
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        let request = request_rx.expect_request();
        assert_eq!("http://localhost:1234/", request.uri());
        assert_eq!(b"what are you?", request.body().bytes().unwrap());
    }

    #[tokio::test]
    async fn operation_retries() {
        // First attempt gets a 503 (classified as retryable), the retry gets the teapot.
        let connector = StaticReplayClient::new(vec![
            replay_event(503, b""),
            replay_event(418, b"I'm a teapot!"),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector.clone())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(string_body_serializer)
            .deserializer::<_, Infallible>(|response| {
                if u16::from(response.status()) == 503 {
                    Err(OrchestratorError::connector(ConnectorError::io(
                        "test".into(),
                    )))
                } else {
                    assert_eq!(418, u16::from(response.status()));
                    Ok(body_as_string(response))
                }
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        connector.assert_requests_match(&[]);
    }
}