aws_smithy_http_server_python/middleware/
layer.rs1use std::{
9 convert::Infallible,
10 marker::PhantomData,
11 mem,
12 task::{Context, Poll},
13};
14
15use aws_smithy_http_server::{
16 body::{Body, BoxBody},
17 response::IntoResponse,
18};
19use futures::{future::BoxFuture, TryFutureExt};
20use http::{Request, Response};
21use pyo3::Python;
22use pyo3_asyncio::TaskLocals;
23use tower::{util::BoxService, Layer, Service, ServiceExt};
24
25use super::PyMiddlewareHandler;
26use crate::{util::error::rich_py_err, PyMiddlewareException};
27
28#[derive(Debug, Clone)]
32pub struct PyMiddlewareLayer<P> {
33 handler: PyMiddlewareHandler,
34 locals: TaskLocals,
35 _protocol: PhantomData<P>,
36}
37
38impl<P> PyMiddlewareLayer<P> {
39 pub fn new(handler: PyMiddlewareHandler, locals: TaskLocals) -> Self {
40 Self {
41 handler,
42 locals,
43 _protocol: PhantomData,
44 }
45 }
46}
47
48impl<S, P> Layer<S> for PyMiddlewareLayer<P>
49where
50 PyMiddlewareException: IntoResponse<P>,
51{
52 type Service = PyMiddlewareService<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 PyMiddlewareService::new(
56 inner,
57 self.handler.clone(),
58 self.locals.clone(),
59 PyMiddlewareException::into_response,
60 )
61 }
62}
63
64#[derive(Clone, Debug)]
66pub struct PyMiddlewareService<S> {
67 inner: S,
68 handler: PyMiddlewareHandler,
69 locals: TaskLocals,
70 into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
71}
72
73impl<S> PyMiddlewareService<S> {
74 pub fn new(
75 inner: S,
76 handler: PyMiddlewareHandler,
77 locals: TaskLocals,
78 into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
79 ) -> PyMiddlewareService<S> {
80 Self {
81 inner,
82 handler,
83 locals,
84 into_response,
85 }
86 }
87}
88
89impl<S> Service<Request<Body>> for PyMiddlewareService<S>
90where
91 S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
92 + Clone
93 + Send
94 + 'static,
95 S::Future: Send,
96{
97 type Response = S::Response;
98 type Error = Infallible;
102 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
103
104 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 self.inner.poll_ready(cx)
106 }
107
108 fn call(&mut self, req: Request<Body>) -> Self::Future {
109 let inner = {
110 let clone = self.inner.clone();
112 mem::replace(&mut self.inner, clone)
113 };
114 let handler = self.handler.clone();
115 let handler_name = handler.name.clone();
116 let next = BoxService::new(inner.map_err(|err| err.into()));
117 let locals = self.locals.clone();
118 let into_response = self.into_response;
119
120 Box::pin(
121 handler
122 .call(req, next, locals)
123 .or_else(move |err| async move {
124 tracing::error!(error = ?rich_py_err(Python::with_gil(|py| err.clone_ref(py))), handler_name, "middleware failed");
125 let response = (into_response)(err.into());
126 Ok(response)
127 }),
128 )
129 }
130}