aws_smithy_http_server_python/middleware/handler.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Execute pure-Python middleware handler.
7
8use aws_smithy_http_server::body::{Body, BoxBody};
9use http::{Request, Response};
10use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction};
11use pyo3_asyncio::TaskLocals;
12use tower::{util::BoxService, BoxError, Service};
13
14use crate::util::func_metadata;
15
16use super::{PyMiddlewareError, PyRequest, PyResponse};
17
18// PyNextInner represents the inner service Tower layer applied to.
19type PyNextInner = BoxService<Request<Body>, Response<BoxBody>, BoxError>;
20
21// PyNext wraps inner Tower service and makes it callable from Python.
22#[pyo3::pyclass]
23struct PyNext(Option<PyNextInner>);
24
25impl PyNext {
26    fn new(inner: PyNextInner) -> Self {
27        Self(Some(inner))
28    }
29
30    // Consumes self by taking the inner Tower service.
31    // This method would have been `into_inner(self) -> PyNextInner`
32    // but we can't do that because we are crossing Python boundary.
33    fn take_inner(&mut self) -> Option<PyNextInner> {
34        self.0.take()
35    }
36}
37
38#[pyo3::pymethods]
39impl PyNext {
40    // Calls the inner Tower service with the `Request` that is passed from Python.
41    // It returns a coroutine to be awaited on the Python side to complete the call.
42    // Note that it takes wrapped objects from both `PyRequest` and `PyNext`,
43    // so after calling `next`, consumer can't access to the `Request` or
44    // can't call the `next` again, this basically emulates consuming `self` and `Request`,
45    // but since we are crossing the Python boundary we can't express it in natural Rust terms.
46    //
47    // Naming the method `__call__` allows `next` to be called like `next(...)`.
48    fn __call__<'p>(&'p mut self, py: Python<'p>, py_req: Py<PyRequest>) -> PyResult<&'p PyAny> {
49        let req = py_req
50            .borrow_mut(py)
51            .take_inner()
52            .ok_or(PyMiddlewareError::RequestGone)?;
53        let mut inner = self
54            .take_inner()
55            .ok_or(PyMiddlewareError::NextAlreadyCalled)?;
56        pyo3_asyncio::tokio::future_into_py(py, async move {
57            let res = inner
58                .call(req)
59                .await
60                .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
61            Ok(Python::with_gil(|py| PyResponse::new(res).into_py(py)))
62        })
63    }
64}
65
66/// A Python middleware handler function representation.
67///
68/// The Python business logic implementation needs to carry some information
69/// to be executed properly like if it is a coroutine.
70#[derive(Debug, Clone)]
71pub struct PyMiddlewareHandler {
72    pub name: String,
73    pub func: PyObject,
74    pub is_coroutine: bool,
75}
76
77impl PyMiddlewareHandler {
78    pub fn new(py: Python, func: PyObject) -> PyResult<Self> {
79        let func_metadata = func_metadata(py, &func)?;
80        Ok(Self {
81            name: func_metadata.name,
82            func,
83            is_coroutine: func_metadata.is_coroutine,
84        })
85    }
86
87    // Calls pure-Python middleware handler with given `Request` and the next Tower service
88    // and returns the `Response` that returned from the pure-Python handler.
89    pub async fn call(
90        self,
91        req: Request<Body>,
92        next: PyNextInner,
93        locals: TaskLocals,
94    ) -> PyResult<Response<BoxBody>> {
95        let py_req = PyRequest::new(req);
96        let py_next = PyNext::new(next);
97
98        let handler = self.func;
99        let result = if self.is_coroutine {
100            pyo3_asyncio::tokio::scope(locals, async move {
101                Python::with_gil(|py| {
102                    let py_handler: &PyFunction = handler.extract(py)?;
103                    let output = py_handler.call1((py_req, py_next))?;
104                    pyo3_asyncio::tokio::into_future(output)
105                })?
106                .await
107            })
108            .await?
109        } else {
110            Python::with_gil(|py| {
111                let py_handler: &PyFunction = handler.extract(py)?;
112                let output = py_handler.call1((py_req, py_next))?;
113                Ok::<_, PyErr>(output.into())
114            })?
115        };
116
117        let response = Python::with_gil(|py| {
118            let py_res: Py<PyResponse> = result.extract(py)?;
119            let mut py_res = py_res.borrow_mut(py);
120            Ok::<_, PyErr>(py_res.take_inner())
121        })?;
122
123        response.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
124    }
125}