1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Executes pure-Python middleware handlers.

use aws_smithy_http_server::body::{Body, BoxBody};
use http::{Request, Response};
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction};
use pyo3_asyncio::TaskLocals;
use tower::{util::BoxService, BoxError, Service};

use crate::util::func_metadata;

use super::{PyMiddlewareError, PyRequest, PyResponse};

// `PyNextInner` is the inner Tower service that the middleware layer is applied to:
// a boxed service turning a `Request<Body>` into a `Response<BoxBody>`.
type PyNextInner = BoxService<Request<Body>, Response<BoxBody>, BoxError>;

// `PyNext` wraps the inner Tower service and makes it callable from Python; it is
// handed to the Python middleware handler as its `next` argument.
// The service is stored in an `Option` so it can be taken out exactly once:
// `None` means `next` has already been called.
#[pyo3::pyclass]
struct PyNext(Option<PyNextInner>);

impl PyNext {
    fn new(inner: PyNextInner) -> Self {
        Self(Some(inner))
    }

    // Consumes self by taking the inner Tower service.
    // This method would have been `into_inner(self) -> PyNextInner`
    // but we can't do that because we are crossing Python boundary.
    fn take_inner(&mut self) -> Option<PyNextInner> {
        self.0.take()
    }
}

#[pyo3::pymethods]
impl PyNext {
    // Calls the inner Tower service with the `Request` that is passed from Python.
    // It returns a coroutine to be awaited on the Python side to complete the call.
    // Note that it takes wrapped objects from both `PyRequest` and `PyNext`,
    // so after calling `next`, consumer can't access to the `Request` or
    // can't call the `next` again, this basically emulates consuming `self` and `Request`,
    // but since we are crossing the Python boundary we can't express it in natural Rust terms.
    //
    // Naming the method `__call__` allows `next` to be called like `next(...)`.
    fn __call__<'p>(&'p mut self, py: Python<'p>, py_req: Py<PyRequest>) -> PyResult<&'p PyAny> {
        let req = py_req
            .borrow_mut(py)
            .take_inner()
            .ok_or(PyMiddlewareError::RequestGone)?;
        let mut inner = self
            .take_inner()
            .ok_or(PyMiddlewareError::NextAlreadyCalled)?;
        pyo3_asyncio::tokio::future_into_py(py, async move {
            let res = inner
                .call(req)
                .await
                .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
            Ok(Python::with_gil(|py| PyResponse::new(res).into_py(py)))
        })
    }
}

/// A Python middleware handler function representation.
///
/// Besides the Python callable itself, this carries the metadata needed to
/// invoke it correctly, such as whether it is a coroutine function whose
/// result must be awaited.
#[derive(Debug, Clone)]
pub struct PyMiddlewareHandler {
    /// Name of the handler, as reported by `func_metadata`.
    pub name: String,
    /// The Python function, called with `(request, next)` by `PyMiddlewareHandler::call`.
    pub func: PyObject,
    /// Whether `func` is a coroutine function. If so, its return value is
    /// converted into a Rust future and awaited instead of being used directly.
    pub is_coroutine: bool,
}

impl PyMiddlewareHandler {
    /// Builds a handler from a Python callable. It records the callable's name and
    /// whether it is a coroutine function, as reported by `func_metadata`.
    ///
    /// # Errors
    /// Propagates any error returned while inspecting `func`.
    pub fn new(py: Python, func: PyObject) -> PyResult<Self> {
        let metadata = func_metadata(py, &func)?;
        let handler = Self {
            name: metadata.name,
            is_coroutine: metadata.is_coroutine,
            func,
        };
        Ok(handler)
    }

    /// Invokes the pure-Python middleware handler with `req` and the `next` Tower
    /// service, and returns the `Response` carried by the object the handler returned.
    ///
    /// Coroutine handlers run inside `pyo3_asyncio::tokio::scope(locals, ...)`, so the
    /// awaitable they produce is converted with those task locals in effect. Other
    /// handlers are called directly while holding the GIL.
    ///
    /// # Errors
    /// Fails if `func` is not a Python function, if the handler raises, if its
    /// return value is not a `PyResponse`, or with `ResponseGone` when that
    /// response's inner value was already taken.
    pub async fn call(
        self,
        req: Request<Body>,
        next: PyNextInner,
        locals: TaskLocals,
    ) -> PyResult<Response<BoxBody>> {
        let Self {
            func: handler,
            is_coroutine,
            ..
        } = self;
        // Positional arguments handed to the Python handler: `(request, next)`.
        let args = (PyRequest::new(req), PyNext::new(next));

        let output: PyObject = if is_coroutine {
            let scoped = pyo3_asyncio::tokio::scope(locals, async move {
                // `into_future` must run inside the scope so it sees `locals`.
                let awaitable = Python::with_gil(|py| {
                    let function: &PyFunction = handler.extract(py)?;
                    pyo3_asyncio::tokio::into_future(function.call1(args)?)
                })?;
                awaitable.await
            });
            scoped.await?
        } else {
            Python::with_gil(|py| -> PyResult<PyObject> {
                let function: &PyFunction = handler.extract(py)?;
                Ok(function.call1(args)?.into())
            })?
        };

        let taken = Python::with_gil(|py| -> PyResult<_> {
            let response: Py<PyResponse> = output.extract(py)?;
            let inner = response.borrow_mut(py).take_inner();
            Ok(inner)
        })?;

        taken.ok_or_else(|| PyMiddlewareError::ResponseGone.into())
    }
}