aws_smithy_http_server_python/middleware/
handler.rs1use aws_smithy_http_server::body::{Body, BoxBody};
9use http::{Request, Response};
10use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction};
11use pyo3_asyncio::TaskLocals;
12use tower::{util::BoxService, BoxError, Service};
13
14use crate::util::func_metadata;
15
16use super::{PyMiddlewareError, PyRequest, PyResponse};
17
18type PyNextInner = BoxService<Request<Body>, Response<BoxBody>, BoxError>;
20
21#[pyo3::pyclass]
23struct PyNext(Option<PyNextInner>);
24
25impl PyNext {
26 fn new(inner: PyNextInner) -> Self {
27 Self(Some(inner))
28 }
29
30 fn take_inner(&mut self) -> Option<PyNextInner> {
34 self.0.take()
35 }
36}
37
38#[pyo3::pymethods]
39impl PyNext {
40 fn __call__<'p>(&'p mut self, py: Python<'p>, py_req: Py<PyRequest>) -> PyResult<&'p PyAny> {
49 let req = py_req
50 .borrow_mut(py)
51 .take_inner()
52 .ok_or(PyMiddlewareError::RequestGone)?;
53 let mut inner = self
54 .take_inner()
55 .ok_or(PyMiddlewareError::NextAlreadyCalled)?;
56 pyo3_asyncio::tokio::future_into_py(py, async move {
57 let res = inner
58 .call(req)
59 .await
60 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
61 Ok(Python::with_gil(|py| PyResponse::new(res).into_py(py)))
62 })
63 }
64}
65
66#[derive(Debug, Clone)]
71pub struct PyMiddlewareHandler {
72 pub name: String,
73 pub func: PyObject,
74 pub is_coroutine: bool,
75}
76
77impl PyMiddlewareHandler {
78 pub fn new(py: Python, func: PyObject) -> PyResult<Self> {
79 let func_metadata = func_metadata(py, &func)?;
80 Ok(Self {
81 name: func_metadata.name,
82 func,
83 is_coroutine: func_metadata.is_coroutine,
84 })
85 }
86
87 pub async fn call(
90 self,
91 req: Request<Body>,
92 next: PyNextInner,
93 locals: TaskLocals,
94 ) -> PyResult<Response<BoxBody>> {
95 let py_req = PyRequest::new(req);
96 let py_next = PyNext::new(next);
97
98 let handler = self.func;
99 let result = if self.is_coroutine {
100 pyo3_asyncio::tokio::scope(locals, async move {
101 Python::with_gil(|py| {
102 let py_handler: &PyFunction = handler.extract(py)?;
103 let output = py_handler.call1((py_req, py_next))?;
104 pyo3_asyncio::tokio::into_future(output)
105 })?
106 .await
107 })
108 .await?
109 } else {
110 Python::with_gil(|py| {
111 let py_handler: &PyFunction = handler.extract(py)?;
112 let output = py_handler.call1((py_req, py_next))?;
113 Ok::<_, PyErr>(output.into())
114 })?
115 };
116
117 let response = Python::with_gil(|py| {
118 let py_res: Py<PyResponse> = result.extract(py)?;
119 let mut py_res = py_res.borrow_mut(py);
120 Ok::<_, PyErr>(py_res.take_inner())
121 })?;
122
123 response.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
124 }
125}