aws_smithy_http_server/
extension.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Extension types.
7//!
//! Extension types are types that are stored in and extracted from the [extensions] of _both_
//! requests and responses.
10//!
11//! There is only one _generic_ extension type _for requests_, [`Extension`].
12//!
13//! On the other hand, the server SDK uses multiple concrete extension types for responses in order
14//! to store a variety of information, like the operation that was executed, the operation error
15//! that got returned, or the runtime error that happened, among others. The information stored in
16//! these types may be useful to [`tower::Layer`]s that post-process the response: for instance, a
//! particular metrics layer implementation might want to emit metrics about the number of times
//! an operation got executed.
19//!
20//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
21
22use std::hash::Hash;
23use std::{fmt, fmt::Debug, future::Future, ops::Deref, pin::Pin, task::Context, task::Poll};
24
25use futures_util::ready;
26use futures_util::TryFuture;
27use thiserror::Error;
28use tower::Service;
29
30use crate::operation::OperationShape;
31use crate::plugin::{HttpMarker, HttpPlugins, Plugin, PluginStack};
32use crate::shape_id::ShapeId;
33
34pub use crate::request::extension::{Extension, MissingExtension};
35
36/// Extension type used to store information about Smithy operations in HTTP responses.
37/// This extension type is inserted, via the [`OperationExtensionPlugin`], whenever it has been correctly determined
38/// that the request should be routed to a particular operation. The operation handler might not even get invoked
39/// because the request fails to deserialize into the modeled operation input.
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct OperationExtension(pub ShapeId);
42
43/// An error occurred when parsing an absolute operation shape ID.
44#[derive(Debug, Clone, Error, PartialEq, Eq)]
45#[non_exhaustive]
46pub enum ParseError {
47    #[error("# was not found - missing namespace")]
48    MissingNamespace,
49}
50
51pin_project_lite::pin_project! {
52    /// The [`Service::Future`] of [`OperationExtensionService`] - inserts an [`OperationExtension`] into the
53    /// [`http::Response]`.
54    pub struct OperationExtensionFuture<Fut> {
55        #[pin]
56        inner: Fut,
57        operation_extension: Option<OperationExtension>
58    }
59}
60
61impl<Fut, RespB> Future for OperationExtensionFuture<Fut>
62where
63    Fut: TryFuture<Ok = http::Response<RespB>>,
64{
65    type Output = Result<http::Response<RespB>, Fut::Error>;
66
67    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68        let this = self.project();
69        let resp = ready!(this.inner.try_poll(cx));
70        let ext = this
71            .operation_extension
72            .take()
73            .expect("futures cannot be polled after completion");
74        Poll::Ready(resp.map(|mut resp| {
75            resp.extensions_mut().insert(ext);
76            resp
77        }))
78    }
79}
80
81/// Inserts a [`OperationExtension`] into the extensions of the [`http::Response`].
82#[derive(Debug, Clone)]
83pub struct OperationExtensionService<S> {
84    inner: S,
85    operation_extension: OperationExtension,
86}
87
88impl<S, B, RespBody> Service<http::Request<B>> for OperationExtensionService<S>
89where
90    S: Service<http::Request<B>, Response = http::Response<RespBody>>,
91{
92    type Response = http::Response<RespBody>;
93    type Error = S::Error;
94    type Future = OperationExtensionFuture<S::Future>;
95
96    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97        self.inner.poll_ready(cx)
98    }
99
100    fn call(&mut self, req: http::Request<B>) -> Self::Future {
101        OperationExtensionFuture {
102            inner: self.inner.call(req),
103            operation_extension: Some(self.operation_extension.clone()),
104        }
105    }
106}
107
/// A [`Plugin`] that wraps every operation's service in an [`OperationExtensionService`].
pub struct OperationExtensionPlugin;

impl fmt::Debug for OperationExtensionPlugin {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = formatter.debug_tuple("OperationExtensionPlugin");
        tuple.field(&"...");
        tuple.finish()
    }
}
116
117impl<Ser, Op, T> Plugin<Ser, Op, T> for OperationExtensionPlugin
118where
119    Op: OperationShape,
120{
121    type Output = OperationExtensionService<T>;
122
123    fn apply(&self, inner: T) -> Self::Output {
124        OperationExtensionService {
125            inner,
126            operation_extension: OperationExtension(Op::ID),
127        }
128    }
129}
130
131impl HttpMarker for OperationExtensionPlugin {}
132
133/// An extension trait on [`HttpPlugins`] allowing the application of [`OperationExtensionPlugin`].
134///
135/// See [`module`](crate::extension) documentation for more info.
136pub trait OperationExtensionExt<CurrentPlugin> {
137    /// Apply the [`OperationExtensionPlugin`], which inserts the [`OperationExtension`] into every [`http::Response`].
138    fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>>;
139}
140
141impl<CurrentPlugin> OperationExtensionExt<CurrentPlugin> for HttpPlugins<CurrentPlugin> {
142    fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>> {
143        self.push(OperationExtensionPlugin)
144    }
145}
146
/// Extension type recording which user-modeled error an operation handler returned.
///
/// The wrapped string identifies an error defined in the Smithy model.
#[derive(Debug, Clone)]
pub struct ModeledErrorExtension(&'static str);

impl ModeledErrorExtension {
    /// Wraps `value` in a new `ModeledErrorExtension`.
    pub fn new(value: &'static str) -> Self {
        Self(value)
    }
}

impl Deref for ModeledErrorExtension {
    type Target = &'static str;

    fn deref(&self) -> &&'static str {
        let Self(inner) = self;
        inner
    }
}
166
/// Extension type recording the _name_ of a runtime error.
///
/// Runtime errors are _unmodeled_: they occur without the operation handler being invoked.
#[derive(Debug, Clone)]
pub struct RuntimeErrorExtension(String);

impl RuntimeErrorExtension {
    /// Wraps `value` in a new `RuntimeErrorExtension`.
    pub fn new(value: String) -> Self {
        Self(value)
    }
}

impl Deref for RuntimeErrorExtension {
    type Target = String;

    fn deref(&self) -> &String {
        let Self(name) = self;
        name
    }
}
186
#[cfg(test)]
mod tests {
    use tower::{service_fn, Layer, ServiceExt};

    use super::*;
    use crate::{plugin::PluginLayer, protocol::rest_json_1::RestJson1};

    // Components of the shape ID used by both tests.
    const ABSOLUTE_ID: &str = "com.amazonaws.ebs#CompleteSnapshot";
    const NAMESPACE: &str = "com.amazonaws.ebs";
    const NAME: &str = "CompleteSnapshot";

    #[test]
    fn ext_accept() {
        let shape_id = ShapeId::new(ABSOLUTE_ID, NAMESPACE, NAME);

        assert_eq!(shape_id.absolute(), ABSOLUTE_ID);
        assert_eq!(shape_id.namespace(), NAMESPACE);
        assert_eq!(shape_id.name(), NAME);
    }

    #[tokio::test]
    async fn plugin() {
        struct DummyOp;

        impl OperationShape for DummyOp {
            const ID: ShapeId = ShapeId::new(ABSOLUTE_ID, NAMESPACE, NAME);

            type Input = ();
            type Output = ();
            type Error = ();
        }

        // A trivial inner service that always succeeds with an empty response.
        let inner = service_fn(|_: http::Request<()>| async { Ok::<_, ()>(http::Response::new(())) });

        // Build the plugin stack and apply it, as a `Layer`, to the inner service.
        let plugins = HttpPlugins::new().insert_operation_extension();
        let svc = PluginLayer::new::<RestJson1, DummyOp>(plugins).layer(inner);

        // The response must carry the `OperationExtension` for `DummyOp`.
        let response = svc.oneshot(http::Request::new(())).await.unwrap();
        let extension = response.extensions().get::<OperationExtension>().unwrap();
        assert_eq!(extension.0, DummyOp::ID);
    }
}