aws_smithy_runtime/client/orchestrator/
operation.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17    AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
20use aws_smithy_runtime_api::client::http::HttpClient;
21use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
22use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
23use aws_smithy_runtime_api::client::interceptors::Intercept;
24use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
25use aws_smithy_runtime_api::client::orchestrator::{HttpResponse, Metadata};
26use aws_smithy_runtime_api::client::result::SdkError;
27use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
28use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
29use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
30use aws_smithy_runtime_api::client::runtime_plugin::{
31    RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
32};
33use aws_smithy_runtime_api::client::ser_de::{
34    DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
35};
36use aws_smithy_runtime_api::shared::IntoShared;
37use aws_smithy_runtime_api::{
38    box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
39};
40use aws_smithy_types::config_bag::{ConfigBag, Layer};
41use aws_smithy_types::retry::RetryConfig;
42use aws_smithy_types::timeout::TimeoutConfig;
43use std::borrow::Cow;
44use std::fmt;
45use std::marker::PhantomData;
46use tracing::{debug_span, Instrument};
47
/// Adapter that wraps a plain function or closure so it can be used as a
/// request serializer.
///
/// `I` is the concrete input type the wrapped function accepts. No value of `I`
/// is stored; the parameter only records which type to expect.
struct FnSerializer<F, I> {
    /// The wrapped serialization function.
    f: F,
    /// Marker tying the input type `I` to this serializer.
    _phantom: PhantomData<I>,
}
52impl<F, I> FnSerializer<F, I> {
53    fn new(f: F) -> Self {
54        Self {
55            f,
56            _phantom: Default::default(),
57        }
58    }
59}
60impl<F, I> SerializeRequest for FnSerializer<F, I>
61where
62    F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
63    I: fmt::Debug + Send + Sync + 'static,
64{
65    fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
66        let input: I = input.downcast().expect("correct type");
67        (self.f)(input)
68    }
69}
70impl<F, I> fmt::Debug for FnSerializer<F, I> {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        write!(f, "FnSerializer")
73    }
74}
75
/// Adapter that wraps a plain function or closure so it can be used as a
/// response deserializer.
///
/// `O` and `E` are the concrete output and error types the wrapped function
/// produces. No values of them are stored; the parameters only record the types.
struct FnDeserializer<F, O, E> {
    /// The wrapped deserialization function.
    f: F,
    /// Marker tying the output type `O` and error type `E` to this deserializer.
    _phantom: PhantomData<(O, E)>,
}
80impl<F, O, E> FnDeserializer<F, O, E> {
81    fn new(deserializer: F) -> Self {
82        Self {
83            f: deserializer,
84            _phantom: Default::default(),
85        }
86    }
87}
88impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
89where
90    F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
91    O: fmt::Debug + Send + Sync + 'static,
92    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
93{
94    fn deserialize_nonstreaming(
95        &self,
96        response: &HttpResponse,
97    ) -> Result<Output, OrchestratorError<Error>> {
98        (self.f)(response)
99            .map(|output| Output::erase(output))
100            .map_err(|err| err.map_operation_error(Error::erase))
101    }
102}
103impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        write!(f, "FnDeserializer")
106    }
107}
108
109/// Orchestrates execution of a HTTP request without any modeled input or output.
110#[derive(Debug)]
111pub struct Operation<I, O, E> {
112    service_name: Cow<'static, str>,
113    operation_name: Cow<'static, str>,
114    runtime_plugins: RuntimePlugins,
115    _phantom: PhantomData<(I, O, E)>,
116}
117
118// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E
119impl<I, O, E> Clone for Operation<I, O, E> {
120    fn clone(&self) -> Self {
121        Self {
122            service_name: self.service_name.clone(),
123            operation_name: self.operation_name.clone(),
124            runtime_plugins: self.runtime_plugins.clone(),
125            _phantom: self._phantom,
126        }
127    }
128}
129
130impl Operation<(), (), ()> {
131    /// Returns a new `OperationBuilder` for the `Operation`.
132    pub fn builder() -> OperationBuilder {
133        OperationBuilder::new()
134    }
135}
136
137impl<I, O, E> Operation<I, O, E>
138where
139    I: fmt::Debug + Send + Sync + 'static,
140    O: fmt::Debug + Send + Sync + 'static,
141    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
142{
143    /// Invokes this `Operation` with the given `input` and returns either an output for success
144    /// or an [`SdkError`] for failure
145    pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
146        let input = Input::erase(input);
147
148        let output = super::invoke(
149            &self.service_name,
150            &self.operation_name,
151            input,
152            &self.runtime_plugins,
153        )
154        .instrument(debug_span!(
155            "invoke",
156            "rpc.service" = &self.service_name.as_ref(),
157            "rpc.method" = &self.operation_name.as_ref()
158        ))
159        .await
160        .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
161
162        Ok(output.downcast().expect("correct type"))
163    }
164}
165
166/// Builder for [`Operation`].
167#[derive(Debug)]
168pub struct OperationBuilder<I = (), O = (), E = ()> {
169    service_name: Option<Cow<'static, str>>,
170    operation_name: Option<Cow<'static, str>>,
171    config: Layer,
172    runtime_components: RuntimeComponentsBuilder,
173    runtime_plugins: Vec<SharedRuntimePlugin>,
174    _phantom: PhantomData<(I, O, E)>,
175}
176
177impl Default for OperationBuilder<(), (), ()> {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl OperationBuilder<(), (), ()> {
184    /// Creates a new [`OperationBuilder`].
185    pub fn new() -> Self {
186        Self {
187            service_name: None,
188            operation_name: None,
189            config: Layer::new("operation"),
190            runtime_components: RuntimeComponentsBuilder::new("operation"),
191            runtime_plugins: Vec::new(),
192            _phantom: Default::default(),
193        }
194    }
195}
196
197impl<I, O, E> OperationBuilder<I, O, E> {
198    /// Configures the service name for the builder.
199    pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
200        self.service_name = Some(service_name.into());
201        self
202    }
203
204    /// Configures the operation name for the builder.
205    pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
206        self.operation_name = Some(operation_name.into());
207        self
208    }
209
210    /// Configures the http client for the builder.
211    pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
212        self.runtime_components.set_http_client(Some(connector));
213        self
214    }
215
216    /// Configures the endpoint URL for the builder.
217    pub fn endpoint_url(mut self, url: &str) -> Self {
218        self.config.store_put(EndpointResolverParams::new(()));
219        self.runtime_components
220            .set_endpoint_resolver(Some(SharedEndpointResolver::new(
221                StaticUriEndpointResolver::uri(url),
222            )));
223        self
224    }
225
226    /// Configures the retry classifier for the builder.
227    pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
228        self.runtime_components
229            .push_retry_classifier(retry_classifier);
230        self
231    }
232
233    /// Disables the retry for the operation.
234    pub fn no_retry(mut self) -> Self {
235        self.runtime_components
236            .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
237        self
238    }
239
240    /// Configures the standard retry for the builder.
241    pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
242        self.config.store_put(retry_config.clone());
243        self.runtime_components
244            .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
245        self
246    }
247
248    /// Configures the timeout configuration for the builder.
249    pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
250        self.config.store_put(timeout_config);
251        self
252    }
253
254    /// Disables auth for the operation.
255    pub fn no_auth(mut self) -> Self {
256        self.config
257            .store_put(AuthSchemeOptionResolverParams::new(()));
258        self.runtime_components
259            .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
260                StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
261            )));
262        self.runtime_components
263            .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
264        self.runtime_components
265            .set_identity_cache(Some(IdentityCache::no_cache()));
266        self.runtime_components.set_identity_resolver(
267            NO_AUTH_SCHEME_ID,
268            SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
269        );
270        self
271    }
272
273    /// Configures the sleep for the builder.
274    pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
275        self.runtime_components
276            .set_sleep_impl(Some(async_sleep.into_shared()));
277        self
278    }
279
280    /// Configures the time source for the builder.
281    pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
282        self.runtime_components
283            .set_time_source(Some(time_source.into_shared()));
284        self
285    }
286
287    /// Configures the interceptor for the builder.
288    pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
289        self.runtime_components.push_interceptor(interceptor);
290        self
291    }
292
293    /// Registers the [`ConnectionPoisoningInterceptor`].
294    pub fn with_connection_poisoning(self) -> Self {
295        self.interceptor(ConnectionPoisoningInterceptor::new())
296    }
297
298    /// Configures the runtime plugin for the builder.
299    pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
300        self.runtime_plugins.push(runtime_plugin.into_shared());
301        self
302    }
303
304    /// Configures stalled stream protection with the given config.
305    pub fn stalled_stream_protection(
306        mut self,
307        stalled_stream_protection: StalledStreamProtectionConfig,
308    ) -> Self {
309        self.config.store_put(stalled_stream_protection);
310        self
311    }
312
313    /// Configures the serializer for the builder.
314    pub fn serializer<I2>(
315        mut self,
316        serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
317    ) -> OperationBuilder<I2, O, E>
318    where
319        I2: fmt::Debug + Send + Sync + 'static,
320    {
321        self.config
322            .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
323        OperationBuilder {
324            service_name: self.service_name,
325            operation_name: self.operation_name,
326            config: self.config,
327            runtime_components: self.runtime_components,
328            runtime_plugins: self.runtime_plugins,
329            _phantom: Default::default(),
330        }
331    }
332
333    /// Configures the deserializer for the builder.
334    pub fn deserializer<O2, E2>(
335        mut self,
336        deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
337            + Send
338            + Sync
339            + 'static,
340    ) -> OperationBuilder<I, O2, E2>
341    where
342        O2: fmt::Debug + Send + Sync + 'static,
343        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
344    {
345        self.config
346            .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
347                deserializer,
348            )));
349        OperationBuilder {
350            service_name: self.service_name,
351            operation_name: self.operation_name,
352            config: self.config,
353            runtime_components: self.runtime_components,
354            runtime_plugins: self.runtime_plugins,
355            _phantom: Default::default(),
356        }
357    }
358
359    /// Configures the a deserializer implementation for the builder.
360    #[allow(clippy::implied_bounds_in_impls)] // for `Send` and `Sync`
361    pub fn deserializer_impl<O2, E2>(
362        mut self,
363        deserializer: impl DeserializeResponse + Send + Sync + 'static,
364    ) -> OperationBuilder<I, O2, E2>
365    where
366        O2: fmt::Debug + Send + Sync + 'static,
367        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
368    {
369        let deserializer: SharedResponseDeserializer = deserializer.into_shared();
370        self.config.store_put(deserializer);
371
372        OperationBuilder {
373            service_name: self.service_name,
374            operation_name: self.operation_name,
375            config: self.config,
376            runtime_components: self.runtime_components,
377            runtime_plugins: self.runtime_plugins,
378            _phantom: Default::default(),
379        }
380    }
381
382    /// Creates an `Operation` from the builder.
383    pub fn build(self) -> Operation<I, O, E> {
384        let service_name = self.service_name.expect("service_name required");
385        let operation_name = self.operation_name.expect("operation_name required");
386        let mut config = self.config;
387        config.store_put(Metadata::new(operation_name.clone(), service_name.clone()));
388        let mut runtime_plugins = RuntimePlugins::new()
389            .with_client_plugins(default_plugins(
390                DefaultPluginParams::new().with_retry_partition_name(service_name.clone()),
391            ))
392            .with_client_plugin(
393                StaticRuntimePlugin::new()
394                    .with_config(config.freeze())
395                    .with_runtime_components(self.runtime_components),
396            );
397        for runtime_plugin in self.runtime_plugins {
398            runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
399        }
400
401        #[cfg(debug_assertions)]
402        {
403            let mut config = ConfigBag::base();
404            let components = runtime_plugins
405                .apply_client_configuration(&mut config)
406                .expect("the runtime plugins should succeed");
407
408            assert!(
409                components.http_client().is_some(),
410                "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
411            );
412            assert!(
413                components.endpoint_resolver().is_some(),
414                "a endpoint_resolver is required"
415            );
416            assert!(
417                components.retry_strategy().is_some(),
418                "a retry_strategy is required"
419            );
420            assert!(
421                config.load::<SharedRequestSerializer>().is_some(),
422                "a serializer is required"
423            );
424            assert!(
425                config.load::<SharedResponseDeserializer>().is_some(),
426                "a deserializer is required"
427            );
428            assert!(
429                config.load::<EndpointResolverParams>().is_some(),
430                "endpoint resolver params are required"
431            );
432            assert!(
433                config.load::<TimeoutConfig>().is_some(),
434                "timeout config is required"
435            );
436        }
437
438        Operation {
439            service_name,
440            operation_name,
441            runtime_plugins,
442            _phantom: Default::default(),
443        }
444    }
445}
446
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    /// Endpoint configured on the operation under test.
    const ENDPOINT: &str = "http://localhost:1234";
    /// Request payload sent by both tests.
    const QUESTION: &str = "what are you?";
    /// Response payload expected on success.
    const ANSWER: &str = "I'm a teapot!";

    /// Builds a canned HTTP response with the given status and body.
    fn canned_response(status: u16, body: &'static str) -> http_1x::Response<SdkBody> {
        http_1x::Response::builder()
            .status(status)
            .body(SdkBody::from(body.as_bytes()))
            .unwrap()
    }

    /// Builds the HTTP request the replay client expects to receive.
    fn expected_request() -> http_1x::Request<SdkBody> {
        http_1x::Request::builder()
            .uri("http://localhost:1234/")
            .body(SdkBody::from(QUESTION.as_bytes()))
            .unwrap()
    }

    /// Serializes a `String` input into a request whose body is that string.
    fn serialize_string(input: String) -> Result<HttpRequest, BoxError> {
        Ok(HttpRequest::new(SdkBody::from(input.as_bytes())))
    }

    /// Reads the response body as UTF-8 text.
    fn body_text(response: &HttpResponse) -> String {
        let bytes = response.body().bytes().unwrap();
        std::str::from_utf8(bytes).unwrap().to_string()
    }

    #[tokio::test]
    async fn operation() {
        let (http_client, captured) = capture_request(Some(canned_response(418, ANSWER)));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(http_client)
            .endpoint_url(ENDPOINT)
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            .serializer(serialize_string)
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(body_text(response))
            })
            .build();

        let output = operation
            .invoke(QUESTION.to_string())
            .await
            .expect("success");
        assert_eq!(ANSWER, output);

        let sent = captured.expect_request();
        assert_eq!("http://localhost:1234/", sent.uri());
        assert_eq!(QUESTION.as_bytes(), sent.body().bytes().unwrap());
    }

    #[tokio::test]
    async fn operation_retries() {
        // The first attempt gets a 503, which the deserializer turns into a
        // connector error; the second attempt succeeds.
        let replay_client = StaticReplayClient::new(vec![
            ReplayEvent::new(expected_request(), canned_response(503, "")),
            ReplayEvent::new(expected_request(), canned_response(418, ANSWER)),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(replay_client.clone())
            .endpoint_url(ENDPOINT)
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(serialize_string)
            .deserializer::<_, Infallible>(|response| {
                let status = u16::from(response.status());
                if status == 503 {
                    return Err(OrchestratorError::connector(ConnectorError::io(
                        "test".into(),
                    )));
                }
                assert_eq!(418, status);
                Ok(body_text(response))
            })
            .build();

        let output = operation
            .invoke(QUESTION.to_string())
            .await
            .expect("success");
        assert_eq!(ANSWER, output);

        replay_client.assert_requests_match(&[]);
    }
}
550}