aws_smithy_http_server/
extension.rs1use std::hash::Hash;
23use std::{fmt, fmt::Debug, future::Future, ops::Deref, pin::Pin, task::Context, task::Poll};
24
25use futures_util::ready;
26use futures_util::TryFuture;
27use thiserror::Error;
28use tower::Service;
29
30use crate::operation::OperationShape;
31use crate::plugin::{HttpMarker, HttpPlugins, Plugin, PluginStack};
32use crate::shape_id::ShapeId;
33
34pub use crate::request::extension::{Extension, MissingExtension};
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct OperationExtension(pub ShapeId);
42
43#[derive(Debug, Clone, Error, PartialEq, Eq)]
45#[non_exhaustive]
46pub enum ParseError {
47 #[error("# was not found - missing namespace")]
48 MissingNamespace,
49}
50
51pin_project_lite::pin_project! {
52 pub struct OperationExtensionFuture<Fut> {
55 #[pin]
56 inner: Fut,
57 operation_extension: Option<OperationExtension>
58 }
59}
60
61impl<Fut, RespB> Future for OperationExtensionFuture<Fut>
62where
63 Fut: TryFuture<Ok = http::Response<RespB>>,
64{
65 type Output = Result<http::Response<RespB>, Fut::Error>;
66
67 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68 let this = self.project();
69 let resp = ready!(this.inner.try_poll(cx));
70 let ext = this
71 .operation_extension
72 .take()
73 .expect("futures cannot be polled after completion");
74 Poll::Ready(resp.map(|mut resp| {
75 resp.extensions_mut().insert(ext);
76 resp
77 }))
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct OperationExtensionService<S> {
84 inner: S,
85 operation_extension: OperationExtension,
86}
87
88impl<S, B, RespBody> Service<http::Request<B>> for OperationExtensionService<S>
89where
90 S: Service<http::Request<B>, Response = http::Response<RespBody>>,
91{
92 type Response = http::Response<RespBody>;
93 type Error = S::Error;
94 type Future = OperationExtensionFuture<S::Future>;
95
96 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97 self.inner.poll_ready(cx)
98 }
99
100 fn call(&mut self, req: http::Request<B>) -> Self::Future {
101 OperationExtensionFuture {
102 inner: self.inner.call(req),
103 operation_extension: Some(self.operation_extension.clone()),
104 }
105 }
106}
107
/// Field-less [`Plugin`] whose `apply` wraps a service in an [`OperationExtensionService`].
pub struct OperationExtensionPlugin;

impl fmt::Debug for OperationExtensionPlugin {
    /// Formats as a tuple struct with a single placeholder field, i.e.
    /// `OperationExtensionPlugin("...")`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = formatter.debug_tuple("OperationExtensionPlugin");
        tuple.field(&"...");
        tuple.finish()
    }
}
116
117impl<Ser, Op, T> Plugin<Ser, Op, T> for OperationExtensionPlugin
118where
119 Op: OperationShape,
120{
121 type Output = OperationExtensionService<T>;
122
123 fn apply(&self, inner: T) -> Self::Output {
124 OperationExtensionService {
125 inner,
126 operation_extension: OperationExtension(Op::ID),
127 }
128 }
129}
130
131impl HttpMarker for OperationExtensionPlugin {}
132
133pub trait OperationExtensionExt<CurrentPlugin> {
137 fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>>;
139}
140
141impl<CurrentPlugin> OperationExtensionExt<CurrentPlugin> for HttpPlugins<CurrentPlugin> {
142 fn insert_operation_extension(self) -> HttpPlugins<PluginStack<OperationExtensionPlugin, CurrentPlugin>> {
143 self.push(OperationExtensionPlugin)
144 }
145}
146
/// Newtype around a `&'static str`.
///
/// NOTE(review): judging by the name, this carries the name of a modeled error so it can be
/// stored as an HTTP extension; nothing in this file shows where it is inserted — confirm
/// against the callers.
#[derive(Debug, Clone)]
pub struct ModeledErrorExtension(&'static str);

impl ModeledErrorExtension {
    /// Wraps `value` in a new [`ModeledErrorExtension`].
    ///
    /// This is a `const fn`, so the extension can also be built in `const` and `static` items.
    pub const fn new(value: &'static str) -> ModeledErrorExtension {
        ModeledErrorExtension(value)
    }
}

impl Deref for ModeledErrorExtension {
    type Target = &'static str;

    /// Gives read access to the wrapped string slice.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
166
/// Newtype around an owned [`String`].
///
/// NOTE(review): judging by the name, this carries the name of a runtime error so it can be
/// stored as an HTTP extension; nothing in this file shows where it is inserted — confirm
/// against the callers.
#[derive(Clone, Debug)]
pub struct RuntimeErrorExtension(String);

impl RuntimeErrorExtension {
    /// Wraps `value` in a new [`RuntimeErrorExtension`], taking ownership of the string.
    pub fn new(value: String) -> RuntimeErrorExtension {
        Self(value)
    }
}

impl Deref for RuntimeErrorExtension {
    type Target = String;

    /// Gives read access to the wrapped string.
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
186
#[cfg(test)]
mod tests {
    use tower::{service_fn, Layer, ServiceExt};

    use crate::{plugin::PluginLayer, protocol::rest_json_1::RestJson1};

    use super::*;

    // A `ShapeId` built from its three parts reports each part back unchanged.
    #[test]
    fn ext_accept() {
        const ABSOLUTE: &str = "com.amazonaws.ebs#CompleteSnapshot";
        const NAMESPACE: &str = "com.amazonaws.ebs";
        const NAME: &str = "CompleteSnapshot";

        let shape_id = ShapeId::new(ABSOLUTE, NAMESPACE, NAME);

        assert_eq!(shape_id.absolute(), ABSOLUTE);
        assert_eq!(shape_id.namespace(), NAMESPACE);
        assert_eq!(shape_id.name(), NAME);
    }

    // Applying the plugin to a trivial service makes the operation's `ShapeId` show up in the
    // response extensions.
    #[tokio::test]
    async fn plugin() {
        struct DummyOp;

        impl OperationShape for DummyOp {
            const ID: ShapeId = ShapeId::new(
                "com.amazonaws.ebs#CompleteSnapshot",
                "com.amazonaws.ebs",
                "CompleteSnapshot",
            );

            type Input = ();
            type Output = ();
            type Error = ();
        }

        // Build a layer that applies only the operation-extension plugin.
        let layer =
            PluginLayer::new::<RestJson1, DummyOp>(HttpPlugins::new().insert_operation_extension());

        // A handler that answers every request with an empty response.
        let handler =
            service_fn(|_: http::Request<()>| async { Ok::<_, ()>(http::Response::new(())) });
        let service = layer.layer(handler);

        let response = service.oneshot(http::Request::new(())).await.unwrap();

        let stored = response.extensions().get::<OperationExtension>().unwrap();
        assert_eq!(stored.0, DummyOp::ID);
    }
}