aws_smithy_http_server_python/middleware/
header_map.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::{mem, str::FromStr, sync::Arc};
7
8use http::{header::HeaderName, HeaderMap, HeaderValue};
9use parking_lot::Mutex;
10use pyo3::{
11    exceptions::{PyKeyError, PyValueError},
12    pyclass, PyErr, PyResult,
13};
14
15use crate::{mutable_mapping_pymethods, util::collection::PyMutableMapping};
16
17/// Python-compatible [HeaderMap] object.
18#[pyclass(mapping)]
19#[derive(Clone, Debug)]
20pub struct PyHeaderMap {
21    inner: Arc<Mutex<HeaderMap>>,
22}
23
24impl PyHeaderMap {
25    pub fn new(inner: HeaderMap) -> Self {
26        Self {
27            inner: Arc::new(Mutex::new(inner)),
28        }
29    }
30
31    // Consumes self by taking the inner `HeaderMap`.
32    // This method would have been `into_inner(self) -> HeaderMap`
33    // but we can't do that because we are crossing Python boundary.
34    pub fn take_inner(&mut self) -> Option<HeaderMap> {
35        let header_map = mem::take(&mut self.inner);
36        let header_map = Arc::try_unwrap(header_map).ok()?;
37        let header_map = header_map.into_inner();
38        Some(header_map)
39    }
40}
41
42/// By implementing [PyMutableMapping] for [PyHeaderMap] we are making it to
43/// behave like a dictionary on the Python.
44impl PyMutableMapping for PyHeaderMap {
45    type Key = String;
46    type Value = String;
47
48    fn len(&self) -> PyResult<usize> {
49        Ok(self.inner.lock().len())
50    }
51
52    fn contains(&self, key: Self::Key) -> PyResult<bool> {
53        Ok(self.inner.lock().contains_key(key))
54    }
55
56    fn keys(&self) -> PyResult<Vec<Self::Key>> {
57        Ok(self.inner.lock().keys().map(|h| h.to_string()).collect())
58    }
59
60    fn values(&self) -> PyResult<Vec<Self::Value>> {
61        self.inner
62            .lock()
63            .values()
64            .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
65            .collect()
66    }
67
68    fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
69        self.inner
70            .lock()
71            .get(key)
72            .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
73            .transpose()
74    }
75
76    fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
77        self.inner.lock().insert(
78            HeaderName::from_str(&key).map_err(to_value_error)?,
79            HeaderValue::from_str(&value).map_err(to_value_error)?,
80        );
81        Ok(())
82    }
83
84    fn del(&mut self, key: Self::Key) -> PyResult<()> {
85        if self.inner.lock().remove(key).is_none() {
86            Err(PyKeyError::new_err("unknown key"))
87        } else {
88            Ok(())
89        }
90    }
91}
92
93mutable_mapping_pymethods!(PyHeaderMap, keys_iter: PyHeaderMapKeys);
94
95fn to_value_error(err: impl std::error::Error) -> PyErr {
96    PyValueError::new_err(err.to_string())
97}
98
#[cfg(test)]
mod tests {
    use http::header;
    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Round-trips a `HeaderMap` through Python: reads, overwrites, inserts and deletes
    /// headers via the mapping protocol, then takes the map back and checks it in Rust.
    #[test]
    fn py_header_map() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let initial: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("42")),
            (header::HOST, HeaderValue::from_static("localhost")),
        ]
        .into_iter()
        .collect();

        let expected: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("45")),
            (
                header::CONTENT_ENCODING,
                HeaderValue::from_static("application/json"),
            ),
        ]
        .into_iter()
        .collect();

        let round_tripped = Python::with_gil(|py| -> PyResult<HeaderMap> {
            let headers = PyCell::new(py, PyHeaderMap::new(initial))?;
            py_run!(
                py,
                headers,
                r#"
assert len(headers) == 2
assert headers["content-length"] == "42"
assert headers["host"] == "localhost"

headers["content-length"] = "45"
assert headers["content-length"] == "45"
headers["content-encoding"] = "application/json"
assert headers["content-encoding"] == "application/json"

del headers["host"]
assert headers.get("host") == None
assert len(headers) == 2

assert set(headers.items()) == set([
    ("content-length", "45"),
    ("content-encoding", "application/json")
])
"#
            );
            let taken = headers.borrow_mut().take_inner().unwrap();
            Ok(taken)
        })?;

        assert_eq!(round_tripped, expected);
        Ok(())
    }
}