aws_smithy_http_server_python/middleware/
response.rs1use std::collections::HashMap;
9use std::mem;
10use std::sync::Arc;
11
12use aws_smithy_http_server::body::{to_boxed, BoxBody};
13use http::{response::Parts, Response};
14use pyo3::{exceptions::PyRuntimeError, prelude::*};
15use tokio::sync::Mutex;
16
17use super::{PyHeaderMap, PyMiddlewareError};
18
19#[pyclass(name = "Response")]
26pub struct PyResponse {
27 parts: Option<Parts>,
28 headers: PyHeaderMap,
29 body: Arc<Mutex<Option<BoxBody>>>,
30}
31
32impl PyResponse {
33 pub fn new(response: Response<BoxBody>) -> Self {
35 let (mut parts, body) = response.into_parts();
36 let headers = mem::take(&mut parts.headers);
37 Self {
38 parts: Some(parts),
39 headers: PyHeaderMap::new(headers),
40 body: Arc::new(Mutex::new(Some(body))),
41 }
42 }
43
44 pub fn take_inner(&mut self) -> Option<Response<BoxBody>> {
48 let headers = self.headers.take_inner()?;
49 let mut parts = self.parts.take()?;
50 parts.headers = headers;
51 let body = {
52 let body = mem::take(&mut self.body);
53 let body = Arc::try_unwrap(body).ok()?;
54 body.into_inner().take()?
55 };
56 Some(Response::from_parts(parts, body))
57 }
58}
59
60#[pymethods]
61impl PyResponse {
62 #[pyo3(text_signature = "($self, status, headers=None, body=None)")]
64 #[new]
65 fn newpy(
66 status: u16,
67 headers: Option<HashMap<String, String>>,
68 body: Option<Vec<u8>>,
69 ) -> PyResult<Self> {
70 let mut builder = Response::builder().status(status);
71
72 if let Some(headers) = headers {
73 for (k, v) in headers {
74 builder = builder.header(k, v);
75 }
76 }
77
78 let response = builder
79 .body(body.map(to_boxed).unwrap_or_default())
80 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
81
82 Ok(Self::new(response))
83 }
84
85 #[getter]
89 fn status(&self) -> PyResult<u16> {
90 self.parts
91 .as_ref()
92 .map(|parts| parts.status.as_u16())
93 .ok_or_else(|| PyMiddlewareError::ResponseGone.into())
94 }
95
96 #[getter]
100 fn version(&self) -> PyResult<String> {
101 self.parts
102 .as_ref()
103 .map(|parts| format!("{:?}", parts.version))
104 .ok_or_else(|| PyMiddlewareError::ResponseGone.into())
105 }
106
107 #[getter]
111 fn headers(&self) -> PyHeaderMap {
112 self.headers.clone()
113 }
114
115 #[getter]
120 fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
121 let body = self.body.clone();
122 pyo3_asyncio::tokio::future_into_py(py, async move {
123 let body = {
124 let mut body_guard = body.lock().await;
125 let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
126 let body = hyper::body::to_bytes(body)
127 .await
128 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
129 let buf = body.clone();
130 body_guard.replace(to_boxed(body));
131 buf
132 };
133 Ok(body.to_vec())
135 })
136 }
137
138 #[setter]
140 fn set_body(&mut self, buf: &[u8]) {
141 self.body = Arc::new(Mutex::new(Some(to_boxed(buf.to_owned()))));
142 }
143}