use std::mem;
use std::sync::Arc;
use aws_smithy_http_server::body::Body;
use http::{request::Parts, Request};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
prelude::*,
};
use tokio::sync::Mutex;
use super::{PyHeaderMap, PyMiddlewareError};
/// Python-facing wrapper around an HTTP request, exposed to Python as `Request`.
///
/// The request is split into its components so Python code can inspect and
/// modify them; `take_inner` reassembles them into a `Request<Body>`.
#[pyclass(name = "Request")]
#[derive(Debug)]
pub struct PyRequest {
    /// Method, URI, version, etc. The headers are moved out into `headers` by
    /// `new`. `None` once the request has been taken back out by `take_inner`.
    parts: Option<Parts>,
    /// The request headers, kept apart from `parts` so they can be handed to Python.
    headers: PyHeaderMap,
    /// The request body, shared so the async `body` getter can move a handle into
    /// its future; the inner `Option` is `None` once the body has been taken.
    body: Arc<Mutex<Option<Body>>>,
}
impl PyRequest {
pub fn new(request: Request<Body>) -> Self {
let (mut parts, body) = request.into_parts();
let headers = mem::take(&mut parts.headers);
Self {
parts: Some(parts),
headers: PyHeaderMap::new(headers),
body: Arc::new(Mutex::new(Some(body))),
}
}
pub fn take_inner(&mut self) -> Option<Request<Body>> {
let headers = self.headers.take_inner()?;
let mut parts = self.parts.take()?;
parts.headers = headers;
let body = {
let body = mem::take(&mut self.body);
let body = Arc::try_unwrap(body).ok()?;
body.into_inner().take()?
};
Some(Request::from_parts(parts, body))
}
}
#[pymethods]
impl PyRequest {
#[getter]
fn method(&self) -> PyResult<String> {
self.parts
.as_ref()
.map(|parts| parts.method.to_string())
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}
#[getter]
fn uri(&self) -> PyResult<String> {
self.parts
.as_ref()
.map(|parts| parts.uri.to_string())
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}
#[setter]
fn set_uri(&mut self, uri_str: String) -> PyResult<()> {
self.parts.as_mut().map_or_else(
|| Err(PyMiddlewareError::RequestGone.into()),
|parts| {
parts.uri = uri_str.parse().map_err(|e: http::uri::InvalidUri| {
PyValueError::new_err(format!(
"URI `{}` cannot be parsed. Error: {}",
uri_str, e
))
})?;
Ok(())
},
)
}
#[getter]
fn version(&self) -> PyResult<String> {
self.parts
.as_ref()
.map(|parts| format!("{:?}", parts.version))
.ok_or_else(|| PyMiddlewareError::RequestGone.into())
}
#[getter]
fn headers(&self) -> PyHeaderMap {
self.headers.clone()
}
#[getter]
fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
let body = self.body.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
let body = {
let mut body_guard = body.lock().await;
let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
let body = hyper::body::to_bytes(body)
.await
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
let buf = body.clone();
body_guard.replace(Body::from(body));
buf
};
Ok(body.to_vec())
})
}
#[setter]
fn set_body(&mut self, buf: &[u8]) {
self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
}
}