aws_smithy_http_server_python/util/
collection.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Provides Rust equivalents of [collections.abc] Python classes.
7//!
//! Creating a custom container is achieved in Python by subclassing a `collections.abc.*` class:
9//! ```python
10//! class MySeq(collections.abc.Sequence):
11//!     def __getitem__(self, index):  ...  # Required abstract method
12//!     def __len__(self):  ...             # Required abstract method
13//! ```
//! You only need to implement the required abstract methods, and you get the
//! extra mixin methods for free.
16//!
//! Ideally we would also just extend the Python abstract base classes directly, but
//! this is not supported yet: <https://github.com/PyO3/pyo3/issues/991>.
19//!
//! Until then, we provide traits that declare the required methods, and macros that
//! take types implementing those traits and provide the mixin methods for them.
22//!
23//! [collections.abc]: https://docs.python.org/3/library/collections.abc.html
24
// NOTE(review): presumably silences `non_local_definitions` warnings triggered by the code that
// the PyO3 `#[pyclass]` / `#[pymethods]` attributes (used by the macro in this module) expand
// to — confirm against the PyO3 version in use and drop once it no longer fires.
#![allow(non_local_definitions)]
26
27use pyo3::PyResult;
28
29/// Rust version of [collections.abc.MutableMapping].
30///
31/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping
32pub trait PyMutableMapping {
33    type Key;
34    type Value;
35
36    fn len(&self) -> PyResult<usize>;
37    fn contains(&self, key: Self::Key) -> PyResult<bool>;
38    fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>>;
39    fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()>;
40    fn del(&mut self, key: Self::Key) -> PyResult<()>;
41
42    // TODO(Perf): This methods should return iterators instead of `Vec`s.
43    fn keys(&self) -> PyResult<Vec<Self::Key>>;
44    fn values(&self) -> PyResult<Vec<Self::Value>>;
45}
46
/// Macro that provides mixin methods of [collections.abc.MutableMapping] to the implementing type.
///
/// Invoke it as `mutable_mapping_pymethods!(MyType, keys_iter: MyTypeKeys);` where `MyType` is a
/// `#[pyclass]` implementing [`PyMutableMapping`]. The expansion adds a `#[pymethods]` block to
/// `MyType` and declares a private `#[pyclass]` named `MyTypeKeys` that iterates over its keys.
///
/// Requirements at the call site:
/// - `PyMutableMapping` must be in scope: the expansion names the trait unqualified.
/// - `Key` and `Value` must implement `Clone`: several generated methods clone them.
///
/// [collections.abc.MutableMapping]: https://docs.python.org/3/library/collections.abc.html#collections.abc.MutableMapping
#[macro_export]
macro_rules! mutable_mapping_pymethods {
    ($ty:ident, keys_iter: $keys_iter: ident) => {
        // Compile-time assertion that `$ty` implements `PyMutableMapping`; never run.
        const _: fn() = || {
            fn assert_impl<T: PyMutableMapping>() {}
            assert_impl::<$ty>();
        };

        /// Iterator over a snapshot of the mapping's keys, returned by `__iter__`.
        #[pyo3::pyclass]
        struct $keys_iter(std::vec::IntoIter<<$ty as PyMutableMapping>::Key>);

        #[pyo3::pymethods]
        impl $keys_iter {
            /// Returns the iterator itself, as the Python iterator protocol requires, so that
            /// `iter(it)`, `list(it)` and `for x in it` also work on the iterator object.
            fn __iter__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
                slf
            }

            /// Yields the next key; returning `None` ends iteration (`StopIteration`).
            fn __next__(&mut self) -> Option<<$ty as PyMutableMapping>::Key> {
                self.0.next()
            }
        }

        #[pyo3::pymethods]
        impl $ty {
            // -- collections.abc.Sized

            fn __len__(&self) -> pyo3::PyResult<usize> {
                self.len()
            }

            // -- collections.abc.Container

            fn __contains__(&self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<bool> {
                self.contains(key)
            }

            // -- collections.abc.Iterable

            /// Returns an iterator over the keys of the dictionary.
            /// NOTE: This method currently causes all keys to be cloned.
            fn __iter__(&self) -> pyo3::PyResult<$keys_iter> {
                Ok($keys_iter(self.keys()?.into_iter()))
            }

            // -- collections.abc.Mapping

            /// Returns the value stored under `key`, or `None` if there is no such entry.
            // NOTE(review): unlike `dict` and the `collections.abc.Mapping` contract, a missing
            // key yields `None` instead of raising `KeyError`. Kept as-is for backward
            // compatibility with existing callers — confirm before changing.
            fn __getitem__(
                &self,
                key: <$ty as PyMutableMapping>::Key,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::get(self, key)
            }

            /// Returns the value stored under `key`, or `default` if there is no such entry.
            fn get(
                &self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                Ok(<$ty as PyMutableMapping>::get(self, key)?.or(default))
            }

            /// Returns keys of the dictionary.
            /// NOTE: This method currently causes all keys to be cloned.
            fn keys(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Key>> {
                <$ty as PyMutableMapping>::keys(self)
            }

            /// Returns values of the dictionary.
            /// NOTE: This method currently causes all values to be cloned.
            fn values(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::values(self)
            }

            /// Returns items (key, value) of the dictionary.
            /// NOTE: This method currently causes all keys and values to be cloned.
            fn items(
                &self,
            ) -> pyo3::PyResult<
                Vec<(
                    <$ty as PyMutableMapping>::Key,
                    <$ty as PyMutableMapping>::Value,
                )>,
            > {
                // Pair every key with its own looked-up value instead of zipping `keys()` with
                // `values()`: the trait does not promise those two snapshots share an order or
                // even a length, and a mismatch would silently pair keys with wrong values.
                let keys = <$ty as PyMutableMapping>::keys(self)?;
                let mut items = Vec::with_capacity(keys.len());
                for key in keys {
                    if let Some(value) = <$ty as PyMutableMapping>::get(self, key.clone())? {
                        items.push((key, value));
                    }
                }
                Ok(items)
            }

            // -- collections.abc.MutableMapping

            fn __setitem__(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                value: <$ty as PyMutableMapping>::Value,
            ) -> pyo3::PyResult<()> {
                self.set(key, value)
            }

            fn __delitem__(&mut self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<()> {
                self.del(key)
            }

            /// Removes `key` and returns its value; if absent, returns `default` or raises
            /// `KeyError` when no default was given.
            fn pop(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<<$ty as PyMutableMapping>::Value> {
                match self.__getitem__(key.clone())? {
                    Some(val) => {
                        self.del(key)?;
                        Ok(val)
                    }
                    None => {
                        default.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("unknown key"))
                    }
                }
            }

            /// Removes and returns some `(key, value)` pair; raises `KeyError` when empty.
            fn popitem(
                &mut self,
            ) -> pyo3::PyResult<(
                <$ty as PyMutableMapping>::Key,
                <$ty as PyMutableMapping>::Value,
            )> {
                // Move the first key out of the snapshot instead of cloning it.
                let key = self
                    .keys()?
                    .into_iter()
                    .next()
                    .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("no key"))?;
                let value = self.pop(key.clone(), None)?;
                Ok((key, value))
            }

            /// Removes every entry from the mapping; errors from the trait are propagated.
            ///
            /// The `Python` token is unused but kept so the Rust signature stays unchanged.
            fn clear(&mut self, _py: pyo3::Python) -> pyo3::PyResult<()> {
                // Snapshot the keys once and delete each: O(n) trait calls. Looping on
                // `popitem` re-collected every remaining key per removal (O(n^2)) and treated
                // any `KeyError` raised by `del` as "empty", silently stopping early.
                for key in self.keys()? {
                    self.del(key)?;
                }
                Ok(())
            }

            /// Returns the value under `key`; if absent, stores `default` (when given) and
            /// returns it.
            fn setdefault(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                match self.__getitem__(key.clone())? {
                    Some(value) => Ok(Some(value)),
                    None => {
                        if let Some(value) = default.clone() {
                            self.set(key, value)?;
                        }
                        Ok(default)
                    }
                }
            }
        }
    };
}
213
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Minimal `HashMap`-backed `PyMutableMapping` used to exercise the generated mixins.
    #[pyclass(mapping)]
    struct Map(HashMap<String, String>);

    impl PyMutableMapping for Map {
        type Key = String;
        type Value = String;

        fn len(&self) -> PyResult<usize> {
            Ok(self.0.len())
        }

        fn contains(&self, key: Self::Key) -> PyResult<bool> {
            Ok(self.0.contains_key(&key))
        }

        fn keys(&self) -> PyResult<Vec<Self::Key>> {
            Ok(self.0.keys().cloned().collect())
        }

        fn values(&self) -> PyResult<Vec<Self::Value>> {
            Ok(self.0.values().cloned().collect())
        }

        fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
            Ok(self.0.get(&key).cloned())
        }

        fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
            self.0.insert(key, value);
            Ok(())
        }

        fn del(&mut self, key: Self::Key) -> PyResult<()> {
            self.0.remove(&key);
            Ok(())
        }
    }

    mutable_mapping_pymethods!(Map, keys_iter: MapKeys);

    /// Drives the generated Python methods through an embedded interpreter.
    #[test]
    fn mutable_mapping() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let map = Map({
            let mut hash_map = HashMap::new();
            hash_map.insert("foo".to_string(), "bar".to_string());
            hash_map.insert("baz".to_string(), "qux".to_string());
            hash_map
        });

        Python::with_gil(|py| {
            let map = PyCell::new(py, map)?;
            py_run!(
                py,
                map,
                r#"
# collections.abc.Sized
assert len(map) == 2

# collections.abc.Container
assert "foo" in map
assert "foobar" not in map

# collections.abc.Iterable
elems = ["foo", "baz"]

for elem in map:
    assert elem in elems

it = iter(map)
assert next(it) in elems
assert next(it) in elems
try:
    next(it)
    assert False, "should stop iteration"
except StopIteration:
    pass

assert set(list(map)) == set(["foo", "baz"])

# collections.abc.Mapping
assert map["foo"] == "bar"
assert map.get("baz") == "qux"
assert map.get("foobar") == None
assert map.get("foobar", "default") == "default"

assert set(list(map.keys())) == set(["foo", "baz"])
assert set(list(map.values())) == set(["bar", "qux"])
assert set(list(map.items())) == set([("foo", "bar"), ("baz", "qux")])

# collections.abc.MutableMapping
map["foobar"] = "bazqux"
del map["foo"]

try:
    map.pop("not_exist")
    assert False, "should throw KeyError"
except KeyError:
    pass
assert map.pop("not_exist", "default") == "default"
assert map.pop("foobar") == "bazqux"
assert "foobar" not in map

# at this point there is only `baz => qux` in `map`
assert map.popitem() == ("baz", "qux")
assert len(map) == 0
try:
    map.popitem()
    assert False, "should throw KeyError"
except KeyError:
    pass

map["foo"] = "bar"
assert len(map) == 1
map.clear()
assert len(map) == 0
assert "foo" not in map

assert map.setdefault("foo", "bar") == "bar"
assert map["foo"] == "bar"
assert map.setdefault("foo", "baz") == "bar"

# TODO(MissingImpl): Add tests for map.update(...)
"#
            );
            Ok(())
        })
    }
}