aws_smithy_http_server_python/util/
collection.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Provides Rust equivalents of [collections.abc] Python classes.
7//!
//! In Python, a custom container is created by subclassing one of the `collections.abc.*` classes:
9//! ```python
10//! class MySeq(collections.abc.Sequence):
11//!     def __getitem__(self, index):  ...  # Required abstract method
12//!     def __len__(self):  ...             # Required abstract method
13//! ```
//! You only need to implement the required abstract methods; the mixin methods
//! come for free.
16//!
//! Ideally we would simply extend these abstract base classes from Python, but
//! that is not supported yet: <https://github.com/PyO3/pyo3/issues/991>.
19//!
//! Until then, we provide traits declaring the required methods, and macros that
//! take types implementing those traits and generate the mixin methods for them.
22//!
23//! [collections.abc]: https://docs.python.org/3/library/collections.abc.html
24
25use pyo3::PyResult;
26
27/// Rust version of [collections.abc.MutableMapping].
28///
29/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping
30pub trait PyMutableMapping {
31    type Key;
32    type Value;
33
34    fn len(&self) -> PyResult<usize>;
35    fn contains(&self, key: Self::Key) -> PyResult<bool>;
36    fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>>;
37    fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()>;
38    fn del(&mut self, key: Self::Key) -> PyResult<()>;
39
40    // TODO(Perf): This methods should return iterators instead of `Vec`s.
41    fn keys(&self) -> PyResult<Vec<Self::Key>>;
42    fn values(&self) -> PyResult<Vec<Self::Value>>;
43}
44
/// Macro that provides mixin methods of [collections.abc.MutableMapping] to the implementing type.
///
/// Invoke as `mutable_mapping_pymethods!(MyType, keys_iter: MyTypeKeys);` where:
/// - `MyType` is a `#[pyclass]` implementing `PyMutableMapping` whose `Key` and `Value` types
///   are `Clone` (the generated `pop`, `popitem`, `items` and `setdefault` clone them);
/// - `MyTypeKeys` is the name of a new `#[pyclass]` that this macro defines and returns from
///   `__iter__`.
///
/// The expansion refers to `PyMutableMapping` unqualified, so the trait must be in scope at
/// the call site.
///
/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping
#[macro_export]
macro_rules! mutable_mapping_pymethods {
    ($ty:ident, keys_iter: $keys_iter: ident) => {
        // Compile-time assertion that `$ty` implements `PyMutableMapping`.
        const _: fn() = || {
            fn assert_impl<T: PyMutableMapping>() {}
            assert_impl::<$ty>();
        };

        /// Iterator over a snapshot of the mapping's keys, returned by `__iter__`.
        #[pyo3::pyclass]
        struct $keys_iter(std::vec::IntoIter<<$ty as PyMutableMapping>::Key>);

        #[pyo3::pymethods]
        impl $keys_iter {
            // The Python iterator protocol requires iterators to return themselves from
            // `__iter__`, so that e.g. `iter(iter(mapping))` and `for` over the iterator work.
            fn __iter__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
                slf
            }

            fn __next__(&mut self) -> Option<<$ty as PyMutableMapping>::Key> {
                self.0.next()
            }
        }

        #[pyo3::pymethods]
        impl $ty {
            // Trait methods are called via `<$ty as PyMutableMapping>::...` so that an inherent
            // method with the same name on `$ty` can never shadow them.

            // -- collections.abc.Sized

            fn __len__(&self) -> pyo3::PyResult<usize> {
                <$ty as PyMutableMapping>::len(self)
            }

            // -- collections.abc.Container

            fn __contains__(&self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<bool> {
                <$ty as PyMutableMapping>::contains(self, key)
            }

            // -- collections.abc.Iterable

            /// Returns an iterator over the keys of the dictionary.
            /// NOTE: This method currently causes all keys to be cloned.
            fn __iter__(&self) -> pyo3::PyResult<$keys_iter> {
                Ok($keys_iter(<$ty as PyMutableMapping>::keys(self)?.into_iter()))
            }

            // -- collections.abc.Mapping

            /// Returns the value stored under `key`, or `None` if the key is absent.
            // NOTE(review): `collections.abc.Mapping` specifies that `__getitem__` raises
            // `KeyError` for a missing key; this returns `None` instead. Kept unchanged to
            // avoid breaking existing callers - confirm whether that is intended.
            fn __getitem__(
                &self,
                key: <$ty as PyMutableMapping>::Key,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::get(self, key)
            }

            /// Returns the value stored under `key`, or `default` (`None` if omitted) when the
            /// key is absent.
            fn get(
                &self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                Ok(<$ty as PyMutableMapping>::get(self, key)?.or(default))
            }

            /// Returns keys of the dictionary.
            /// NOTE: This method currently causes all keys to be cloned.
            fn keys(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Key>> {
                <$ty as PyMutableMapping>::keys(self)
            }

            /// Returns values of the dictionary.
            /// NOTE: This method currently causes all values to be cloned.
            fn values(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::values(self)
            }

            /// Returns items (key, value) of the dictionary.
            /// NOTE: This method currently causes all keys and values to be cloned.
            fn items(
                &self,
            ) -> pyo3::PyResult<
                Vec<(
                    <$ty as PyMutableMapping>::Key,
                    <$ty as PyMutableMapping>::Value,
                )>,
            > {
                // Look up each key's own value instead of zipping `keys()` with `values()`:
                // `PyMutableMapping` does not require those two lists to be in the same order
                // or of the same length, so zipping them could pair keys with wrong values.
                let keys = <$ty as PyMutableMapping>::keys(self)?;
                let mut items = Vec::with_capacity(keys.len());
                for key in keys {
                    // Skip a key that `get` does not resolve (the implementation's `keys` and
                    // `get` disagree) rather than inventing a pairing for it.
                    if let Some(value) = <$ty as PyMutableMapping>::get(self, key.clone())? {
                        items.push((key, value));
                    }
                }
                Ok(items)
            }

            // -- collections.abc.MutableMapping

            fn __setitem__(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                value: <$ty as PyMutableMapping>::Value,
            ) -> pyo3::PyResult<()> {
                <$ty as PyMutableMapping>::set(self, key, value)
            }

            fn __delitem__(&mut self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<()> {
                <$ty as PyMutableMapping>::del(self, key)
            }

            /// Removes `key` and returns its value. If the key is absent, returns `default`
            /// when given and raises `KeyError` otherwise.
            fn pop(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<<$ty as PyMutableMapping>::Value> {
                match <$ty as PyMutableMapping>::get(self, key.clone())? {
                    Some(value) => {
                        <$ty as PyMutableMapping>::del(self, key)?;
                        Ok(value)
                    }
                    None => {
                        default.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("unknown key"))
                    }
                }
            }

            /// Removes and returns the first `(key, value)` pair reported by `keys()`.
            /// Raises `KeyError` if the mapping is empty.
            fn popitem(
                &mut self,
            ) -> pyo3::PyResult<(
                <$ty as PyMutableMapping>::Key,
                <$ty as PyMutableMapping>::Value,
            )> {
                // Move the first key out of the owned `Vec` instead of cloning it again.
                let key = <$ty as PyMutableMapping>::keys(self)?
                    .into_iter()
                    .next()
                    .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("no key"))?;
                let value = self.pop(key.clone(), None)?;
                Ok((key, value))
            }

            /// Removes every entry from the mapping.
            fn clear(&mut self, _py: pyo3::Python) -> pyo3::PyResult<()> {
                // Take a single snapshot of the keys and delete each one. The previous
                // `popitem` loop re-fetched and cloned every key once per removed entry
                // (quadratic), silently stopped on any `KeyError`, and never terminated if
                // `del` failed to remove a key.
                let keys = <$ty as PyMutableMapping>::keys(self)?;
                for key in keys {
                    <$ty as PyMutableMapping>::del(self, key)?;
                }
                Ok(())
            }

            /// Returns the value for `key` if present. Otherwise stores `default` (when given)
            /// under `key` and returns `default`.
            fn setdefault(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                if let Some(value) = <$ty as PyMutableMapping>::get(self, key.clone())? {
                    return Ok(Some(value));
                }
                // `Value` cannot represent Python `None`, so nothing is stored when no default
                // is given.
                if let Some(value) = default.clone() {
                    <$ty as PyMutableMapping>::set(self, key, value)?;
                }
                Ok(default)
            }
        }
    };
}
211
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use pyo3::{prelude::*, py_run};

    use super::*;

    /// `HashMap`-backed `PyMutableMapping`, used to exercise from Python the methods generated
    /// by `mutable_mapping_pymethods!`.
    #[pyclass(mapping)]
    struct Map(HashMap<String, String>);

    impl PyMutableMapping for Map {
        type Key = String;
        type Value = String;

        fn len(&self) -> PyResult<usize> {
            Ok(self.0.len())
        }

        fn contains(&self, key: Self::Key) -> PyResult<bool> {
            Ok(self.0.contains_key(&key))
        }

        fn keys(&self) -> PyResult<Vec<Self::Key>> {
            Ok(self.0.keys().cloned().collect())
        }

        fn values(&self) -> PyResult<Vec<Self::Value>> {
            Ok(self.0.values().cloned().collect())
        }

        fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
            Ok(self.0.get(&key).cloned())
        }

        fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
            self.0.insert(key, value);
            Ok(())
        }

        fn del(&mut self, key: Self::Key) -> PyResult<()> {
            self.0.remove(&key);
            Ok(())
        }
    }

    mutable_mapping_pymethods!(Map, keys_iter: MapKeys);

    #[test]
    fn mutable_mapping() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let map = Map({
            let mut hash_map = HashMap::new();
            hash_map.insert("foo".to_string(), "bar".to_string());
            hash_map.insert("baz".to_string(), "qux".to_string());
            hash_map
        });

        Python::with_gil(|py| {
            let map = PyCell::new(py, map)?;
            py_run!(
                py,
                map,
                r#"
# collections.abc.Sized
assert len(map) == 2

# collections.abc.Container
assert "foo" in map
assert "foobar" not in map

# collections.abc.Iterable
elems = ["foo", "baz"]

for elem in map:
    assert elem in elems

it = iter(map)
assert next(it) in elems
assert next(it) in elems
try:
    next(it)
    assert False, "should stop iteration"
except StopIteration:
    pass

assert set(list(map)) == set(["foo", "baz"])

# collections.abc.Mapping
assert map["foo"] == "bar"
assert map.get("baz") == "qux"
assert map.get("foobar") is None
assert map.get("foobar", "default") == "default"

assert set(list(map.keys())) == set(["foo", "baz"])
assert set(list(map.values())) == set(["bar", "qux"])
assert set(list(map.items())) == set([("foo", "bar"), ("baz", "qux")])

# collections.abc.MutableMapping
map["foobar"] = "bazqux"
del map["foo"]

try:
    map.pop("not_exist")
    assert False, "should throw KeyError"
except KeyError:
    pass
assert map.pop("not_exist", "default") == "default"
assert map.pop("foobar") == "bazqux"
assert "foobar" not in map

# at this point there is only `baz => qux` in `map`
assert map.popitem() == ("baz", "qux")
assert len(map) == 0
try:
    map.popitem()
    assert False, "should throw KeyError"
except KeyError:
    pass

map["foo"] = "bar"
assert len(map) == 1
map.clear()
assert len(map) == 0
assert "foo" not in map

assert map.setdefault("foo", "bar") == "bar"
assert map["foo"] == "bar"
assert map.setdefault("foo", "baz") == "bar"

# TODO(MissingImpl): Add tests for map.update(...)
"#
            );
            Ok(())
        })
    }
}