// aws_smithy_http_server_python/middleware/header_map.rs

use std::{mem, str::FromStr, sync::Arc};

use http::{header::HeaderName, HeaderMap, HeaderValue};
use parking_lot::Mutex;
use pyo3::{
    exceptions::{PyKeyError, PyValueError},
    pyclass, PyErr, PyResult,
};

use crate::{mutable_mapping_pymethods, util::collection::PyMutableMapping};

/// Python-visible wrapper around a [`HeaderMap`].
///
/// The map sits behind `Arc<Mutex<_>>`, so cloning a `PyHeaderMap` (via the derived
/// `Clone`) yields another handle to the *same* headers rather than an independent copy.
#[pyclass(mapping)]
#[derive(Clone, Debug)]
pub struct PyHeaderMap {
    inner: Arc<Mutex<HeaderMap>>,
}

24impl PyHeaderMap {
25 pub fn new(inner: HeaderMap) -> Self {
26 Self {
27 inner: Arc::new(Mutex::new(inner)),
28 }
29 }
30
31 pub fn take_inner(&mut self) -> Option<HeaderMap> {
35 let header_map = mem::take(&mut self.inner);
36 let header_map = Arc::try_unwrap(header_map).ok()?;
37 let header_map = header_map.into_inner();
38 Some(header_map)
39 }
40}
41
42impl PyMutableMapping for PyHeaderMap {
45 type Key = String;
46 type Value = String;
47
48 fn len(&self) -> PyResult<usize> {
49 Ok(self.inner.lock().len())
50 }
51
52 fn contains(&self, key: Self::Key) -> PyResult<bool> {
53 Ok(self.inner.lock().contains_key(key))
54 }
55
56 fn keys(&self) -> PyResult<Vec<Self::Key>> {
57 Ok(self.inner.lock().keys().map(|h| h.to_string()).collect())
58 }
59
60 fn values(&self) -> PyResult<Vec<Self::Value>> {
61 self.inner
62 .lock()
63 .values()
64 .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
65 .collect()
66 }
67
68 fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
69 self.inner
70 .lock()
71 .get(key)
72 .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
73 .transpose()
74 }
75
76 fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
77 self.inner.lock().insert(
78 HeaderName::from_str(&key).map_err(to_value_error)?,
79 HeaderValue::from_str(&value).map_err(to_value_error)?,
80 );
81 Ok(())
82 }
83
84 fn del(&mut self, key: Self::Key) -> PyResult<()> {
85 if self.inner.lock().remove(key).is_none() {
86 Err(PyKeyError::new_err("unknown key"))
87 } else {
88 Ok(())
89 }
90 }
91}
92
// Presumably generates the Python-facing mapping methods for `PyHeaderMap` from its
// `PyMutableMapping` impl; the macro is defined elsewhere in this crate. The tests below
// drive the result from Python through `len()`, `[]` get/set, `del`, `.get()` and `.items()`.
// NOTE(review): `keys_iter` looks like the name given to a generated key-iterator type —
// confirm against the macro definition.
mutable_mapping_pymethods!(PyHeaderMap, keys_iter: PyHeaderMapKeys);

95fn to_value_error(err: impl std::error::Error) -> PyErr {
96 PyValueError::new_err(err.to_string())
97}
98
#[cfg(test)]
mod tests {
    use http::header;
    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Round-trips a `HeaderMap` through `PyHeaderMap`.
    ///
    /// Python reads, overwrites, inserts and deletes headers through the mapping interface.
    /// Rust then takes the map back with `take_inner` and checks the final contents.
    #[test]
    fn py_header_map() -> PyResult<()> {
        // Initialise the embedded Python interpreter before acquiring the GIL below.
        pyo3::prepare_freethreaded_python();

        let mut header_map = HeaderMap::new();
        header_map.insert(header::CONTENT_LENGTH, "42".parse().unwrap());
        header_map.insert(header::HOST, "localhost".parse().unwrap());

        let header_map = Python::with_gil(|py| {
            let py_header_map = PyHeaderMap::new(header_map);
            let headers = PyCell::new(py, py_header_map)?;
            // Runs the script with the Rust value bound to the Python name `headers`.
            py_run!(
                py,
                headers,
                r#"
assert len(headers) == 2
assert headers["content-length"] == "42"
assert headers["host"] == "localhost"

headers["content-length"] = "45"
assert headers["content-length"] == "45"
headers["content-encoding"] = "application/json"
assert headers["content-encoding"] == "application/json"

del headers["host"]
assert headers.get("host") == None
assert len(headers) == 2

assert set(headers.items()) == set([
    ("content-length", "45"),
    ("content-encoding", "application/json")
])
"#
            );

            // `take_inner` returns `None` if anything else still shares the map; the
            // `unwrap` asserts that nothing does at this point.
            Ok::<_, PyErr>(headers.borrow_mut().take_inner().unwrap())
        })?;

        // The Python-side edits must be visible in Rust: content-length overwritten,
        // content-encoding added, host removed.
        assert_eq!(
            header_map,
            vec![
                (header::CONTENT_LENGTH, "45".parse().unwrap()),
                (
                    header::CONTENT_ENCODING,
                    "application/json".parse().unwrap()
                ),
            ]
            .into_iter()
            .collect()
        );

        Ok(())
    }
}