aws_smithy_http_server_python/middleware/layer.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Tower layer implementation of Python middleware handling.

8use std::{
9    convert::Infallible,
10    marker::PhantomData,
11    mem,
12    task::{Context, Poll},
13};
14
15use aws_smithy_http_server::{
16    body::{Body, BoxBody},
17    response::IntoResponse,
18};
19use futures::{future::BoxFuture, TryFutureExt};
20use http::{Request, Response};
21use pyo3::Python;
22use pyo3_asyncio::TaskLocals;
23use tower::{util::BoxService, Layer, Service, ServiceExt};
24
25use super::PyMiddlewareHandler;
26use crate::{util::error::rich_py_err, PyMiddlewareException};
27
28/// Tower [Layer] implementation of Python middleware handling.
29///
30/// Middleware stored in the `handler` attribute will be executed inside an async Tower middleware.
31#[derive(Debug, Clone)]
32pub struct PyMiddlewareLayer<P> {
33    handler: PyMiddlewareHandler,
34    locals: TaskLocals,
35    _protocol: PhantomData<P>,
36}
37
38impl<P> PyMiddlewareLayer<P> {
39    pub fn new(handler: PyMiddlewareHandler, locals: TaskLocals) -> Self {
40        Self {
41            handler,
42            locals,
43            _protocol: PhantomData,
44        }
45    }
46}
47
48impl<S, P> Layer<S> for PyMiddlewareLayer<P>
49where
50    PyMiddlewareException: IntoResponse<P>,
51{
52    type Service = PyMiddlewareService<S>;
53
54    fn layer(&self, inner: S) -> Self::Service {
55        PyMiddlewareService::new(
56            inner,
57            self.handler.clone(),
58            self.locals.clone(),
59            PyMiddlewareException::into_response,
60        )
61    }
62}
63
64/// Tower [Service] wrapping the Python middleware [Layer].
65#[derive(Clone, Debug)]
66pub struct PyMiddlewareService<S> {
67    inner: S,
68    handler: PyMiddlewareHandler,
69    locals: TaskLocals,
70    into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
71}
72
73impl<S> PyMiddlewareService<S> {
74    pub fn new(
75        inner: S,
76        handler: PyMiddlewareHandler,
77        locals: TaskLocals,
78        into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
79    ) -> PyMiddlewareService<S> {
80        Self {
81            inner,
82            handler,
83            locals,
84            into_response,
85        }
86    }
87}
88
89impl<S> Service<Request<Body>> for PyMiddlewareService<S>
90where
91    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
92        + Clone
93        + Send
94        + 'static,
95    S::Future: Send,
96{
97    type Response = S::Response;
98    // We are making `Service` `Infallible` because we convert errors to responses via
99    // `PyMiddlewareException::into_response` which has `IntoResponse<Protocol>` bound,
100    // so we always return a protocol specific error response instead of erroring out.
101    type Error = Infallible;
102    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
103
104    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        self.inner.poll_ready(cx)
106    }
107
108    fn call(&mut self, req: Request<Body>) -> Self::Future {
109        let inner = {
110            // https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
111            let clone = self.inner.clone();
112            mem::replace(&mut self.inner, clone)
113        };
114        let handler = self.handler.clone();
115        let handler_name = handler.name.clone();
116        let next = BoxService::new(inner.map_err(|err| err.into()));
117        let locals = self.locals.clone();
118        let into_response = self.into_response;
119
120        Box::pin(
121            handler
122                .call(req, next, locals)
123                .or_else(move |err| async move {
124                    tracing::error!(error = ?rich_py_err(Python::with_gil(|py| err.clone_ref(py))), handler_name, "middleware failed");
125                    let response = (into_response)(err.into());
126                    Ok(response)
127                }),
128        )
129    }
130}