aws_smithy_http_server_python/middleware/header_map.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6#![allow(non_local_definitions)]
7
8use std::{mem, str::FromStr, sync::Arc};
9
10use http::{header::HeaderName, HeaderMap, HeaderValue};
11use parking_lot::Mutex;
12use pyo3::{
13    exceptions::{PyKeyError, PyValueError},
14    pyclass, PyErr, PyResult,
15};
16
17use crate::{mutable_mapping_pymethods, util::collection::PyMutableMapping};
18
19/// Python-compatible [HeaderMap] object.
20#[pyclass(mapping)]
21#[derive(Clone, Debug)]
22pub struct PyHeaderMap {
23    inner: Arc<Mutex<HeaderMap>>,
24}
25
26impl PyHeaderMap {
27    pub fn new(inner: HeaderMap) -> Self {
28        Self {
29            inner: Arc::new(Mutex::new(inner)),
30        }
31    }
32
33    // Consumes self by taking the inner `HeaderMap`.
34    // This method would have been `into_inner(self) -> HeaderMap`
35    // but we can't do that because we are crossing Python boundary.
36    pub fn take_inner(&mut self) -> Option<HeaderMap> {
37        let header_map = mem::take(&mut self.inner);
38        let header_map = Arc::try_unwrap(header_map).ok()?;
39        let header_map = header_map.into_inner();
40        Some(header_map)
41    }
42}
43
44/// By implementing [PyMutableMapping] for [PyHeaderMap] we are making it to
45/// behave like a dictionary on the Python.
46impl PyMutableMapping for PyHeaderMap {
47    type Key = String;
48    type Value = String;
49
50    fn len(&self) -> PyResult<usize> {
51        Ok(self.inner.lock().len())
52    }
53
54    fn contains(&self, key: Self::Key) -> PyResult<bool> {
55        Ok(self.inner.lock().contains_key(key))
56    }
57
58    fn keys(&self) -> PyResult<Vec<Self::Key>> {
59        Ok(self.inner.lock().keys().map(|h| h.to_string()).collect())
60    }
61
62    fn values(&self) -> PyResult<Vec<Self::Value>> {
63        self.inner
64            .lock()
65            .values()
66            .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
67            .collect()
68    }
69
70    fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
71        self.inner
72            .lock()
73            .get(key)
74            .map(|h| h.to_str().map(|s| s.to_string()).map_err(to_value_error))
75            .transpose()
76    }
77
78    fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
79        self.inner.lock().insert(
80            HeaderName::from_str(&key).map_err(to_value_error)?,
81            HeaderValue::from_str(&value).map_err(to_value_error)?,
82        );
83        Ok(())
84    }
85
86    fn del(&mut self, key: Self::Key) -> PyResult<()> {
87        if self.inner.lock().remove(key).is_none() {
88            Err(PyKeyError::new_err("unknown key"))
89        } else {
90            Ok(())
91        }
92    }
93}
94
95mutable_mapping_pymethods!(PyHeaderMap, keys_iter: PyHeaderMapKeys);
96
97fn to_value_error(err: impl std::error::Error) -> PyErr {
98    PyValueError::new_err(err.to_string())
99}
100
#[cfg(test)]
mod tests {
    use http::header;
    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Round-trips a [HeaderMap] through Python. The script reads, overwrites,
    /// adds and deletes headers, and the map taken back out on the Rust side
    /// must reflect exactly those edits.
    #[test]
    fn py_header_map() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let initial: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("42")),
            (header::HOST, HeaderValue::from_static("localhost")),
        ]
        .into_iter()
        .collect();

        let edited = Python::with_gil(|py| -> PyResult<HeaderMap> {
            // The binding must be named `headers`: `py_run!` exposes it to the
            // script under that name.
            let headers = PyCell::new(py, PyHeaderMap::new(initial))?;
            py_run!(
                py,
                headers,
                r#"
assert len(headers) == 2
assert headers["content-length"] == "42"
assert headers["host"] == "localhost"

headers["content-length"] = "45"
assert headers["content-length"] == "45"
headers["content-encoding"] = "application/json"
assert headers["content-encoding"] == "application/json"

del headers["host"]
assert headers.get("host") == None
assert len(headers) == 2

assert set(headers.items()) == set([
    ("content-length", "45"),
    ("content-encoding", "application/json")
])
"#
            );

            let taken = headers.borrow_mut().take_inner().unwrap();
            Ok(taken)
        })?;

        let expected: HeaderMap = vec![
            (header::CONTENT_LENGTH, HeaderValue::from_static("45")),
            (
                header::CONTENT_ENCODING,
                HeaderValue::from_static("application/json"),
            ),
        ]
        .into_iter()
        .collect();
        assert_eq!(edited, expected);

        Ok(())
    }
}