aws_smithy_http_server/
extension.rs1use std::hash::Hash;
23use std::{fmt, fmt::Debug, future::Future, ops::Deref, pin::Pin, task::Context, task::Poll};
24
25use futures_util::ready;
26use futures_util::TryFuture;
27use thiserror::Error;
28use tower::Service;
29
30use crate::operation::OperationShape;
31use crate::plugin::{HttpMarker, HttpPlugins, Plugin, PluginStack};
32use crate::shape_id::ShapeId;
33
34pub use crate::request::extension::{Extension, MissingExtension};
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct OperationExtension(pub ShapeId);
42
43#[derive(Debug, Clone, Error, PartialEq, Eq)]
45#[non_exhaustive]
46pub enum ParseError {
47    #[error("# was not found - missing namespace")]
48    MissingNamespace,
49}
50
51pin_project_lite::pin_project! {
52    pub struct OperationExtensionFuture<Fut> {
55        #[pin]
56        inner: Fut,
57        operation_extension: Option<OperationExtension>
58    }
59}
60
61impl<Fut, RespB> Future for OperationExtensionFuture<Fut>
62where
63    Fut: TryFuture<Ok = http::Response<RespB>>,
64{
65    type Output = Result<http::Response<RespB>, Fut::Error>;
66
67    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68        let this = self.project();
69        let resp = ready!(this.inner.try_poll(cx));
70        let ext = this
71            .operation_extension
72            .take()
73            .expect("futures cannot be polled after completion");
74        Poll::Ready(resp.map(|mut resp| {
75            resp.extensions_mut().insert(ext);
76            resp
77        }))
78    }
79}
80
81#[derive(Debug, Clone)]
83pub struct OperationExtensionService<S> {
84    inner: S,
85    operation_extension: OperationExtension,
86}
87
88impl<S, B, RespBody> Service<http::Request<B>> for OperationExtensionService<S>
89where
90    S: Service<http::Request<B>, Response = http::Response<RespBody>>,
91{
92    type Response = http::Response<RespBody>;
93    type Error = S::Error;
94    type Future = OperationExtensionFuture<S::Future>;
95
96    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97        self.inner.poll_ready(cx)
98    }
99
100    fn call(&mut self, req: http::Request<B>) -> Self::Future {
101        OperationExtensionFuture {
102            inner: self.inner.call(req),
103            operation_extension: Some(self.operation_extension.clone()),
104        }
105    }
106}
107
/// A [`Plugin`] that wraps every operation's service in an [`OperationExtensionService`], so that
/// each successful response carries an [`OperationExtension`] naming that operation.
///
/// The plugin holds no state, so it is freely copyable.
#[derive(Clone, Copy)]
pub struct OperationExtensionPlugin;

impl fmt::Debug for OperationExtensionPlugin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The plugin is a stateless unit struct: render it exactly like a derived `Debug` would,
        // rather than with a placeholder field that suggests hidden state.
        f.write_str("OperationExtensionPlugin")
    }
}
116
117impl<Ser, Op, T> Plugin<Ser, Op, T> for OperationExtensionPlugin
118where
119    Op: OperationShape,
120{
121    type Output = OperationExtensionService<T>;
122
123    fn apply(&self, inner: T) -> Self::Output {
124        OperationExtensionService {
125            inner,
126            operation_extension: OperationExtension(Op::ID),
127        }
128    }
129}
130
131impl HttpMarker for OperationExtensionPlugin {}
132
133pub trait OperationExtensionExt<CurrentPlugin> {
137    fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>>;
139}
140
141impl<CurrentPlugin> OperationExtensionExt<CurrentPlugin> for HttpPlugins<CurrentPlugin> {
142    fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>> {
143        self.push(OperationExtensionPlugin)
144    }
145}
146
/// Extension value holding a `&'static str` that identifies a modeled error.
///
/// NOTE(review): presumably the string is the error shape's name from the Smithy model —
/// confirm at the call sites that construct it.
///
/// Derives `PartialEq`, `Eq` and `Hash` for consistency with [`OperationExtension`], so values
/// read back from a response's extensions can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModeledErrorExtension(&'static str);

impl ModeledErrorExtension {
    /// Wraps `value` in a new `ModeledErrorExtension`.
    ///
    /// This is a `const fn`, so extensions for statically known error names can be built in
    /// `const`/`static` items.
    pub const fn new(value: &'static str) -> ModeledErrorExtension {
        ModeledErrorExtension(value)
    }
}

impl Deref for ModeledErrorExtension {
    type Target = &'static str;

    /// Exposes the wrapped error identifier.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
166
/// Extension value holding an owned `String` that identifies a runtime error.
///
/// NOTE(review): presumably the string names an error raised by the framework rather than one
/// modeled in Smithy — confirm at the call sites that construct it.
///
/// Derives `PartialEq`, `Eq` and `Hash` for consistency with [`OperationExtension`], so values
/// read back from a response's extensions can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RuntimeErrorExtension(String);

impl RuntimeErrorExtension {
    /// Wraps `value` in a new `RuntimeErrorExtension`, taking ownership of the string.
    pub fn new(value: String) -> RuntimeErrorExtension {
        RuntimeErrorExtension(value)
    }
}

impl Deref for RuntimeErrorExtension {
    type Target = String;

    /// Exposes the wrapped error identifier.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
186
#[cfg(test)]
mod tests {
    use tower::{service_fn, Layer, ServiceExt};

    use crate::{plugin::PluginLayer, protocol::rest_json_1::RestJson1};

    use super::*;

    // Components of the shape ID shared by every test in this module.
    const SNAPSHOT_ABSOLUTE: &str = "com.amazonaws.ebs#CompleteSnapshot";
    const SNAPSHOT_NAMESPACE: &str = "com.amazonaws.ebs";
    const SNAPSHOT_NAME: &str = "CompleteSnapshot";

    /// `ShapeId` accessors hand back the components it was constructed from.
    #[test]
    fn ext_accept() {
        let shape_id = ShapeId::new(SNAPSHOT_ABSOLUTE, SNAPSHOT_NAMESPACE, SNAPSHOT_NAME);

        assert_eq!(shape_id.absolute(), SNAPSHOT_ABSOLUTE);
        assert_eq!(shape_id.namespace(), SNAPSHOT_NAMESPACE);
        assert_eq!(shape_id.name(), SNAPSHOT_NAME);
    }

    /// The plugin stamps each response with the ID of the operation it was applied to.
    #[tokio::test]
    async fn plugin() {
        struct DummyOp;

        impl OperationShape for DummyOp {
            const ID: ShapeId = ShapeId::new(SNAPSHOT_ABSOLUTE, SNAPSHOT_NAMESPACE, SNAPSHOT_NAME);

            type Input = ();
            type Output = ();
            type Error = ();
        }

        // Apply only the operation-extension plugin on top of a trivial service.
        let plugin_stack = HttpPlugins::new().insert_operation_extension();
        let plugin_layer = PluginLayer::new::<RestJson1, DummyOp>(plugin_stack);
        let wrapped = plugin_layer.layer(service_fn(|_: http::Request<()>| async {
            Ok::<_, ()>(http::Response::new(()))
        }));

        let response = wrapped.oneshot(http::Request::new(())).await.unwrap();

        let stored = response.extensions().get::<OperationExtension>().unwrap();
        assert_eq!(stored.0, DummyOp::ID);
    }
}