aws_smithy_http_server_python/middleware/
header_map.rs1#![allow(non_local_definitions)]
7
8use std::{mem, str::FromStr, sync::Arc};
9
10use http::{header::HeaderName, HeaderMap, HeaderValue};
11use parking_lot::Mutex;
12use pyo3::{
13 exceptions::{PyKeyError, PyValueError},
14 pyclass, PyErr, PyResult,
15};
16
17use crate::{mutable_mapping_pymethods, util::collection::PyMutableMapping};
18
19#[pyclass(mapping)]
21#[derive(Clone, Debug)]
22pub struct PyHeaderMap {
23 inner: Arc<Mutex<HeaderMap>>,
24}
25
26impl PyHeaderMap {
27 pub fn new(inner: HeaderMap) -> Self {
28 Self {
29 inner: Arc::new(Mutex::new(inner)),
30 }
31 }
32
33 pub fn take_inner(&mut self) -> Option<HeaderMap> {
37 let header_map = mem::take(&mut self.inner);
38 let header_map = Arc::try_unwrap(header_map).ok()?;
39 let header_map = header_map.into_inner();
40 Some(header_map)
41 }
42}
43
44impl PyMutableMapping for PyHeaderMap {
47 type Key = String;
48 type Value = String;
49
50 fn len(&self) -> PyResult<usize> {
51 Ok(self.inner.lock().len())
52 }
53
54 fn contains(&self, key: Self::Key) -> PyResult<bool> {
55 Ok(self.inner.lock().contains_key(key))
56 }
57
58 fn keys(&self) -> PyResult<Vec<Self::Key>> {
59 Ok(self.inner.lock().keys().map(|h| h.to_string()).collect())
60 }
61
62 fn values(&self) -> PyResult<Vec<Self::Value>> {
63 self.inner
64 .lock()
65 .values()
66 .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
67 .collect()
68 }
69
70 fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
71 self.inner
72 .lock()
73 .get(key)
74 .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
75 .transpose()
76 }
77
78 fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
79 self.inner.lock().insert(
80 HeaderName::from_str(&key).map_err(to_value_error)?,
81 HeaderValue::from_str(&value).map_err(to_value_error)?,
82 );
83 Ok(())
84 }
85
86 fn del(&mut self, key: Self::Key) -> PyResult<()> {
87 if self.inner.lock().remove(key).is_none() {
88 Err(PyKeyError::new_err("unknown key"))
89 } else {
90 Ok(())
91 }
92 }
93}
94
// Generate the Python-visible methods for `PyHeaderMap` from its
// `PyMutableMapping` implementation, using the crate's
// `mutable_mapping_pymethods!` macro.
// NOTE(review): `PyHeaderMapKeys` appears to name the keys-iterator type the
// macro generates (inferred from the `keys_iter:` argument). Confirm against
// the macro definition in `util::collection`.
mutable_mapping_pymethods!(PyHeaderMap, keys_iter: PyHeaderMapKeys);
96
97fn to_value_error(err: impl std::error::Error) -> PyErr {
98 PyValueError::new_err(err.to_string())
99}
100
#[cfg(test)]
mod tests {
    use http::header;
    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Drives a `PyHeaderMap` from a Python script (read, overwrite, insert,
    /// delete, iterate). It then checks that the `HeaderMap` recovered through
    /// `take_inner` reflects every mutation the script made.
    #[test]
    fn py_header_map() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let seed: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("42")),
            (header::HOST, HeaderValue::from_static("localhost")),
        ]
        .into_iter()
        .collect();

        let round_tripped = Python::with_gil(|py| -> PyResult<HeaderMap> {
            let headers = PyCell::new(py, PyHeaderMap::new(seed))?;
            py_run!(
                py,
                headers,
                r#"
assert len(headers) == 2
assert headers["content-length"] == "42"
assert headers["host"] == "localhost"

headers["content-length"] = "45"
assert headers["content-length"] == "45"
headers["content-encoding"] = "application/json"
assert headers["content-encoding"] == "application/json"

del headers["host"]
assert headers.get("host") == None
assert len(headers) == 2

assert set(headers.items()) == set([
    ("content-length", "45"),
    ("content-encoding", "application/json")
])
"#
            );

            let mut wrapper = headers.borrow_mut();
            Ok(wrapper.take_inner().unwrap())
        })?;

        let expected: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("45")),
            (
                header::CONTENT_ENCODING,
                HeaderValue::from_static("application/json"),
            ),
        ]
        .into_iter()
        .collect();
        assert_eq!(round_tripped, expected);

        Ok(())
    }
}