aws_smithy_http_server_python/middleware/request.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Python-compatible middleware [http::Request] implementation.
7
8use std::mem;
9use std::sync::Arc;
10
11use aws_smithy_http_server::body::Body;
12use http::{request::Parts, Request};
13use pyo3::{
14    exceptions::{PyRuntimeError, PyValueError},
15    prelude::*,
16};
17use tokio::sync::Mutex;
18
19use super::{PyHeaderMap, PyMiddlewareError};
20
21/// Python-compatible [Request] object.
22#[pyclass(name = "Request")]
23#[derive(Debug)]
24pub struct PyRequest {
25    parts: Option<Parts>,
26    headers: PyHeaderMap,
27    body: Arc<Mutex<Option<Body>>>,
28}
29
30impl PyRequest {
31    /// Create a new Python-compatible [Request] structure from the Rust side.
32    pub fn new(request: Request<Body>) -> Self {
33        let (mut parts, body) = request.into_parts();
34        let headers = mem::take(&mut parts.headers);
35        Self {
36            parts: Some(parts),
37            headers: PyHeaderMap::new(headers),
38            body: Arc::new(Mutex::new(Some(body))),
39        }
40    }
41
42    // Consumes self by taking the inner Request.
43    // This method would have been `into_inner(self) -> Request<Body>`
44    // but we can't do that because we are crossing Python boundary.
45    pub fn take_inner(&mut self) -> Option<Request<Body>> {
46        let headers = self.headers.take_inner()?;
47        let mut parts = self.parts.take()?;
48        parts.headers = headers;
49        let body = {
50            let body = mem::take(&mut self.body);
51            let body = Arc::try_unwrap(body).ok()?;
52            body.into_inner().take()?
53        };
54        Some(Request::from_parts(parts, body))
55    }
56}
57
58#[pymethods]
59impl PyRequest {
60    /// Return the HTTP method of this request.
61    ///
62    /// :type str:
63    #[getter]
64    fn method(&self) -> PyResult<String> {
65        self.parts
66            .as_ref()
67            .map(|parts| parts.method.to_string())
68            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
69    }
70
71    /// Return the URI of this request.
72    ///
73    /// :type str:
74    #[getter]
75    fn uri(&self) -> PyResult<String> {
76        self.parts
77            .as_ref()
78            .map(|parts| parts.uri.to_string())
79            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
80    }
81
82    /// Sets the URI of this request.
83    ///
84    /// :type str:
85    #[setter]
86    fn set_uri(&mut self, uri_str: String) -> PyResult<()> {
87        self.parts.as_mut().map_or_else(
88            || Err(PyMiddlewareError::RequestGone.into()),
89            |parts| {
90                parts.uri = uri_str.parse().map_err(|e: http::uri::InvalidUri| {
91                    PyValueError::new_err(format!(
92                        "URI `{}` cannot be parsed. Error: {}",
93                        uri_str, e
94                    ))
95                })?;
96                Ok(())
97            },
98        )
99    }
100
101    /// Return the HTTP version of this request.
102    ///
103    /// :type str:
104    #[getter]
105    fn version(&self) -> PyResult<String> {
106        self.parts
107            .as_ref()
108            .map(|parts| format!("{:?}", parts.version))
109            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
110    }
111
112    /// Return the HTTP headers of this request.
113    ///
114    /// :type typing.MutableMapping[str, str]:
115    #[getter]
116    fn headers(&self) -> PyHeaderMap {
117        self.headers.clone()
118    }
119
120    /// Return the HTTP body of this request.
121    /// Note that this is a costly operation because the whole request body is cloned.
122    ///
123    /// :type typing.Awaitable[bytes]:
124    #[getter]
125    fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
126        let body = self.body.clone();
127        pyo3_asyncio::tokio::future_into_py(py, async move {
128            let body = {
129                let mut body_guard = body.lock().await;
130                let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
131                let body = hyper::body::to_bytes(body)
132                    .await
133                    .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
134                let buf = body.clone();
135                body_guard.replace(Body::from(body));
136                buf
137            };
138            // TODO(Perf): can we use `PyBytes` here?
139            Ok(body.to_vec())
140        })
141    }
142
143    /// Set the HTTP body of this request.
144    #[setter]
145    fn set_body(&mut self, buf: &[u8]) {
146        self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
147    }
148}