aws_smithy_runtime/client/orchestrator/
operation.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17    AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
20use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
21use aws_smithy_runtime_api::client::http::HttpClient;
22use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
23use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
24use aws_smithy_runtime_api::client::interceptors::Intercept;
25use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
26use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, Metadata};
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
29use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
30use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
31use aws_smithy_runtime_api::client::runtime_plugin::{
32    RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
33};
34use aws_smithy_runtime_api::client::ser_de::{
35    DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
36};
37use aws_smithy_runtime_api::shared::IntoShared;
38use aws_smithy_runtime_api::{
39    box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
40};
41use aws_smithy_types::config_bag::{ConfigBag, Layer};
42use aws_smithy_types::retry::RetryConfig;
43use aws_smithy_types::timeout::TimeoutConfig;
44use std::borrow::Cow;
45use std::fmt;
46use std::marker::PhantomData;
47use tracing::{debug_span, Instrument};
48
/// Adapts a plain `Fn(I) -> Result<HttpRequest, BoxError>` closure into a
/// [`SerializeRequest`] implementation.
///
/// `I` is only tracked at the type level so the erased input can be downcast back to it.
struct FnSerializer<F, I> {
    /// The wrapped serialization function.
    f: F,
    /// Marker for the concrete input type `I`.
    _phantom: PhantomData<I>,
}

impl<F, I> FnSerializer<F, I> {
    /// Wraps `func` as a serializer for inputs of type `I`.
    fn new(func: F) -> Self {
        Self {
            f: func,
            _phantom: PhantomData,
        }
    }
}
61impl<F, I> SerializeRequest for FnSerializer<F, I>
62where
63    F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
64    I: fmt::Debug + Send + Sync + 'static,
65{
66    fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
67        let input: I = input.downcast().expect("correct type");
68        (self.f)(input)
69    }
70}
71impl<F, I> fmt::Debug for FnSerializer<F, I> {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        write!(f, "FnSerializer")
74    }
75}
76
/// Adapts a plain `Fn(&HttpResponse) -> Result<O, OrchestratorError<E>>` closure into a
/// [`DeserializeResponse`] implementation.
///
/// `O` and `E` are only tracked at the type level.
struct FnDeserializer<F, O, E> {
    /// The wrapped deserialization function.
    f: F,
    /// Marker for the concrete output type `O` and error type `E`.
    _phantom: PhantomData<(O, E)>,
}

impl<F, O, E> FnDeserializer<F, O, E> {
    /// Wraps `f` as a deserializer producing `O` or an operation error `E`.
    fn new(f: F) -> Self {
        Self {
            f,
            _phantom: PhantomData,
        }
    }
}
89impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
90where
91    F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
92    O: fmt::Debug + Send + Sync + 'static,
93    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
94{
95    fn deserialize_nonstreaming(
96        &self,
97        response: &HttpResponse,
98    ) -> Result<Output, OrchestratorError<Error>> {
99        (self.f)(response)
100            .map(|output| Output::erase(output))
101            .map_err(|err| err.map_operation_error(Error::erase))
102    }
103}
104impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        write!(f, "FnDeserializer")
107    }
108}
109
110/// Orchestrates execution of a HTTP request without any modeled input or output.
111#[derive(Debug)]
112pub struct Operation<I, O, E> {
113    service_name: Cow<'static, str>,
114    operation_name: Cow<'static, str>,
115    runtime_plugins: RuntimePlugins,
116    _phantom: PhantomData<(I, O, E)>,
117}
118
119// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E
120impl<I, O, E> Clone for Operation<I, O, E> {
121    fn clone(&self) -> Self {
122        Self {
123            service_name: self.service_name.clone(),
124            operation_name: self.operation_name.clone(),
125            runtime_plugins: self.runtime_plugins.clone(),
126            _phantom: self._phantom,
127        }
128    }
129}
130
131impl Operation<(), (), ()> {
132    /// Returns a new `OperationBuilder` for the `Operation`.
133    pub fn builder() -> OperationBuilder {
134        OperationBuilder::new()
135    }
136}
137
138impl<I, O, E> Operation<I, O, E>
139where
140    I: fmt::Debug + Send + Sync + 'static,
141    O: fmt::Debug + Send + Sync + 'static,
142    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
143{
144    /// Invokes this `Operation` with the given `input` and returns either an output for success
145    /// or an [`SdkError`] for failure
146    pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
147        let input = Input::erase(input);
148
149        let output = super::invoke(
150            &self.service_name,
151            &self.operation_name,
152            input,
153            &self.runtime_plugins,
154        )
155        .instrument(debug_span!(
156            "invoke",
157            "rpc.service" = &self.service_name.as_ref(),
158            "rpc.method" = &self.operation_name.as_ref()
159        ))
160        .await
161        .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
162
163        Ok(output.downcast().expect("correct type"))
164    }
165}
166
167/// Builder for [`Operation`].
168#[derive(Debug)]
169pub struct OperationBuilder<I = (), O = (), E = ()> {
170    service_name: Option<Cow<'static, str>>,
171    operation_name: Option<Cow<'static, str>>,
172    behavior_version: Option<BehaviorVersion>,
173    config: Layer,
174    runtime_components: RuntimeComponentsBuilder,
175    runtime_plugins: Vec<SharedRuntimePlugin>,
176    _phantom: PhantomData<(I, O, E)>,
177}
178
179impl Default for OperationBuilder<(), (), ()> {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185impl OperationBuilder<(), (), ()> {
186    /// Creates a new [`OperationBuilder`].
187    pub fn new() -> Self {
188        Self {
189            service_name: None,
190            operation_name: None,
191            behavior_version: None,
192            config: Layer::new("operation"),
193            runtime_components: RuntimeComponentsBuilder::new("operation"),
194            runtime_plugins: Vec::new(),
195            _phantom: Default::default(),
196        }
197    }
198}
199
200impl<I, O, E> OperationBuilder<I, O, E> {
201    /// Configures the service name for the builder.
202    pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
203        self.service_name = Some(service_name.into());
204        self
205    }
206
207    /// Configures the operation name for the builder.
208    pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
209        self.operation_name = Some(operation_name.into());
210        self
211    }
212
213    /// Configures the behavior version for the builder.
214    pub fn behavior_version(mut self, behavior_version: BehaviorVersion) -> Self {
215        self.behavior_version = Some(behavior_version);
216        self
217    }
218
219    /// Configures the http client for the builder.
220    pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
221        self.runtime_components.set_http_client(Some(connector));
222        self
223    }
224
225    /// Configures the endpoint URL for the builder.
226    pub fn endpoint_url(mut self, url: &str) -> Self {
227        self.config.store_put(EndpointResolverParams::new(()));
228        self.runtime_components
229            .set_endpoint_resolver(Some(SharedEndpointResolver::new(
230                StaticUriEndpointResolver::uri(url),
231            )));
232        self
233    }
234
235    /// Configures the retry classifier for the builder.
236    pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
237        self.runtime_components
238            .push_retry_classifier(retry_classifier);
239        self
240    }
241
242    /// Disables the retry for the operation.
243    pub fn no_retry(mut self) -> Self {
244        self.runtime_components
245            .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
246        self
247    }
248
249    /// Configures the standard retry for the builder.
250    pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
251        self.config.store_put(retry_config.clone());
252        self.runtime_components
253            .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
254        self
255    }
256
257    /// Configures the timeout configuration for the builder.
258    pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
259        self.config.store_put(timeout_config);
260        self
261    }
262
263    /// Disables auth for the operation.
264    pub fn no_auth(mut self) -> Self {
265        self.config
266            .store_put(AuthSchemeOptionResolverParams::new(()));
267        self.runtime_components
268            .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
269                StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
270            )));
271        self.runtime_components
272            .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
273        self.runtime_components
274            .set_identity_cache(Some(IdentityCache::no_cache()));
275        self.runtime_components.set_identity_resolver(
276            NO_AUTH_SCHEME_ID,
277            SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
278        );
279        self
280    }
281
282    /// Configures the sleep for the builder.
283    pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
284        self.runtime_components
285            .set_sleep_impl(Some(async_sleep.into_shared()));
286        self
287    }
288
289    /// Configures the time source for the builder.
290    pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
291        self.runtime_components
292            .set_time_source(Some(time_source.into_shared()));
293        self
294    }
295
296    /// Configures the interceptor for the builder.
297    pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
298        self.runtime_components.push_interceptor(interceptor);
299        self
300    }
301
302    /// Registers the [`ConnectionPoisoningInterceptor`].
303    pub fn with_connection_poisoning(self) -> Self {
304        self.interceptor(ConnectionPoisoningInterceptor::new())
305    }
306
307    /// Configures the runtime plugin for the builder.
308    pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
309        self.runtime_plugins.push(runtime_plugin.into_shared());
310        self
311    }
312
313    /// Configures stalled stream protection with the given config.
314    pub fn stalled_stream_protection(
315        mut self,
316        stalled_stream_protection: StalledStreamProtectionConfig,
317    ) -> Self {
318        self.config.store_put(stalled_stream_protection);
319        self
320    }
321
322    /// Configures the serializer for the builder.
323    pub fn serializer<I2>(
324        mut self,
325        serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
326    ) -> OperationBuilder<I2, O, E>
327    where
328        I2: fmt::Debug + Send + Sync + 'static,
329    {
330        self.config
331            .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
332        OperationBuilder {
333            service_name: self.service_name,
334            operation_name: self.operation_name,
335            behavior_version: self.behavior_version,
336            config: self.config,
337            runtime_components: self.runtime_components,
338            runtime_plugins: self.runtime_plugins,
339            _phantom: Default::default(),
340        }
341    }
342
343    /// Configures the deserializer for the builder.
344    pub fn deserializer<O2, E2>(
345        mut self,
346        deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
347            + Send
348            + Sync
349            + 'static,
350    ) -> OperationBuilder<I, O2, E2>
351    where
352        O2: fmt::Debug + Send + Sync + 'static,
353        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
354    {
355        self.config
356            .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
357                deserializer,
358            )));
359        OperationBuilder {
360            service_name: self.service_name,
361            operation_name: self.operation_name,
362            behavior_version: self.behavior_version,
363            config: self.config,
364            runtime_components: self.runtime_components,
365            runtime_plugins: self.runtime_plugins,
366            _phantom: Default::default(),
367        }
368    }
369
370    /// Configures the a deserializer implementation for the builder.
371    #[allow(clippy::implied_bounds_in_impls)] // for `Send` and `Sync`
372    pub fn deserializer_impl<O2, E2>(
373        mut self,
374        deserializer: impl DeserializeResponse + Send + Sync + 'static,
375    ) -> OperationBuilder<I, O2, E2>
376    where
377        O2: fmt::Debug + Send + Sync + 'static,
378        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
379    {
380        let deserializer: SharedResponseDeserializer = deserializer.into_shared();
381        self.config.store_put(deserializer);
382
383        OperationBuilder {
384            service_name: self.service_name,
385            operation_name: self.operation_name,
386            behavior_version: self.behavior_version,
387            config: self.config,
388            runtime_components: self.runtime_components,
389            runtime_plugins: self.runtime_plugins,
390            _phantom: Default::default(),
391        }
392    }
393
394    /// Creates an `Operation` from the builder.
395    pub fn build(self) -> Operation<I, O, E> {
396        let service_name = self.service_name.expect("service_name required");
397        let operation_name = self.operation_name.expect("operation_name required");
398        let mut config = self.config;
399        config.store_put(Metadata::new(operation_name.clone(), service_name.clone()));
400        let mut runtime_plugins = RuntimePlugins::new()
401            .with_client_plugins(default_plugins({
402                let mut params =
403                    DefaultPluginParams::new().with_retry_partition_name(service_name.clone());
404                if let Some(bv) = self.behavior_version {
405                    params = params.with_behavior_version(bv);
406                }
407                params
408            }))
409            .with_client_plugin(
410                StaticRuntimePlugin::new()
411                    .with_config(config.freeze())
412                    .with_runtime_components(self.runtime_components),
413            );
414        for runtime_plugin in self.runtime_plugins {
415            runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
416        }
417
418        #[cfg(debug_assertions)]
419        {
420            let mut config = ConfigBag::base();
421            let components = runtime_plugins
422                .apply_client_configuration(&mut config)
423                .expect("the runtime plugins should succeed");
424
425            assert!(
426                components.http_client().is_some(),
427                "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
428            );
429            assert!(
430                components.endpoint_resolver().is_some(),
431                "a endpoint_resolver is required"
432            );
433            assert!(
434                components.retry_strategy().is_some(),
435                "a retry_strategy is required"
436            );
437            assert!(
438                config.load::<SharedRequestSerializer>().is_some(),
439                "a serializer is required"
440            );
441            assert!(
442                config.load::<SharedResponseDeserializer>().is_some(),
443                "a deserializer is required"
444            );
445            assert!(
446                config.load::<EndpointResolverParams>().is_some(),
447                "endpoint resolver params are required"
448            );
449            assert!(
450                config.load::<TimeoutConfig>().is_some(),
451                "timeout config is required"
452            );
453        }
454
455        Operation {
456            service_name,
457            operation_name,
458            runtime_plugins,
459            _phantom: Default::default(),
460        }
461    }
462}
463
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    /// Serializes a `String` input directly into the request body.
    fn string_body_serializer(input: String) -> Result<HttpRequest, BoxError> {
        Ok(HttpRequest::new(SdkBody::from(input.as_bytes())))
    }

    /// Reads the response body as UTF-8 text.
    fn body_text(response: &HttpResponse) -> String {
        std::str::from_utf8(response.body().bytes().unwrap())
            .unwrap()
            .to_string()
    }

    /// Builds a canned response with the given status code and body.
    fn canned_response(status: u16, body: &'static [u8]) -> http_1x::Response<SdkBody> {
        http_1x::Response::builder()
            .status(status)
            .body(SdkBody::from(body))
            .unwrap()
    }

    /// The request both replay events expect to receive.
    fn expected_request() -> http_1x::Request<SdkBody> {
        http_1x::Request::builder()
            .uri("http://localhost:1234/")
            .body(SdkBody::from(&b"what are you?"[..]))
            .unwrap()
    }

    #[tokio::test]
    async fn operation() {
        let (http_client, captured) =
            capture_request(Some(canned_response(418, &b"I'm a teapot!"[..])));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(http_client)
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            .serializer(string_body_serializer)
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(body_text(response))
            })
            .build();

        let answer = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", answer);

        let sent = captured.expect_request();
        assert_eq!("http://localhost:1234/", sent.uri());
        assert_eq!(b"what are you?", sent.body().bytes().unwrap());
    }

    #[tokio::test]
    async fn operation_retries() {
        let replay = StaticReplayClient::new(vec![
            ReplayEvent::new(expected_request(), canned_response(503, &b""[..])),
            ReplayEvent::new(
                expected_request(),
                canned_response(418, &b"I'm a teapot!"[..]),
            ),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(replay.clone())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(string_body_serializer)
            .deserializer::<_, Infallible>(|response| match u16::from(response.status()) {
                // A 503 is surfaced as a connector error so the retry strategy tries again.
                503 => Err(OrchestratorError::connector(ConnectorError::io(
                    "test".into(),
                ))),
                status => {
                    assert_eq!(418, status);
                    Ok(body_text(response))
                }
            })
            .build();

        let answer = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", answer);

        replay.assert_requests_match(&[]);
    }
}