aws_smithy_http_server_python/middleware/
response.rs1#![allow(non_local_definitions)]
9
10use std::collections::HashMap;
11use std::mem;
12use std::sync::Arc;
13
14use aws_smithy_http_server::body::{to_boxed, BoxBody};
15use http::{response::Parts, Response};
16use pyo3::{exceptions::PyRuntimeError, prelude::*};
17use tokio::sync::Mutex;
18
19use super::{PyHeaderMap, PyMiddlewareError};
20
21#[pyclass(name = "Response")]
28pub struct PyResponse {
29 parts: Option<Parts>,
30 headers: PyHeaderMap,
31 body: Arc<Mutex<Option<BoxBody>>>,
32}
33
34impl PyResponse {
35 pub fn new(response: Response<BoxBody>) -> Self {
37 let (mut parts, body) = response.into_parts();
38 let headers = mem::take(&mut parts.headers);
39 Self {
40 parts: Some(parts),
41 headers: PyHeaderMap::new(headers),
42 body: Arc::new(Mutex::new(Some(body))),
43 }
44 }
45
46 pub fn take_inner(&mut self) -> Option<Response<BoxBody>> {
50 let headers = self.headers.take_inner()?;
51 let mut parts = self.parts.take()?;
52 parts.headers = headers;
53 let body = {
54 let body = mem::take(&mut self.body);
55 let body = Arc::try_unwrap(body).ok()?;
56 body.into_inner()?
57 };
58 Some(Response::from_parts(parts, body))
59 }
60}
61
62#[pymethods]
63impl PyResponse {
64 #[pyo3(text_signature = "($self, status, headers=None, body=None)")]
66 #[new]
67 fn newpy(
68 status: u16,
69 headers: Option<HashMap<String, String>>,
70 body: Option<Vec<u8>>,
71 ) -> PyResult<Self> {
72 let mut builder = Response::builder().status(status);
73
74 if let Some(headers) = headers {
75 for (k, v) in headers {
76 builder = builder.header(k, v);
77 }
78 }
79
80 let response = builder
81 .body(body.map(to_boxed).unwrap_or_default())
82 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
83
84 Ok(Self::new(response))
85 }
86
87 #[getter]
91 fn status(&self) -> PyResult<u16> {
92 self.parts
93 .as_ref()
94 .map(|parts| parts.status.as_u16())
95 .ok_or_else(|| PyMiddlewareError::ResponseGone.into())
96 }
97
98 #[getter]
102 fn version(&self) -> PyResult<String> {
103 self.parts
104 .as_ref()
105 .map(|parts| format!("{:?}", parts.version))
106 .ok_or_else(|| PyMiddlewareError::ResponseGone.into())
107 }
108
109 #[getter]
113 fn headers(&self) -> PyHeaderMap {
114 self.headers.clone()
115 }
116
117 #[getter]
122 fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
123 let body = self.body.clone();
124 pyo3_asyncio::tokio::future_into_py(py, async move {
125 let body = {
126 let mut body_guard = body.lock().await;
127 let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
128 let body = hyper::body::to_bytes(body)
129 .await
130 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
131 let buf = body.clone();
132 body_guard.replace(to_boxed(body));
133 buf
134 };
135 Ok(body.to_vec())
137 })
138 }
139
140 #[setter]
142 fn set_body(&mut self, buf: &[u8]) {
143 self.body = Arc::new(Mutex::new(Some(to_boxed(buf.to_owned()))));
144 }
145}