aws_smithy_http_server_python/util/
collection.rs1#![allow(non_local_definitions)]
26
27use pyo3::PyResult;
28
29pub trait PyMutableMapping {
33 type Key;
34 type Value;
35
36 fn len(&self) -> PyResult<usize>;
37 fn contains(&self, key: Self::Key) -> PyResult<bool>;
38 fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>>;
39 fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()>;
40 fn del(&mut self, key: Self::Key) -> PyResult<()>;
41
42 fn keys(&self) -> PyResult<Vec<Self::Key>>;
44 fn values(&self) -> PyResult<Vec<Self::Value>>;
45}
46
/// Implements the Python `collections.abc.MutableMapping` protocol for `$ty`.
///
/// `$ty` must be a `#[pyclass]` that implements [`PyMutableMapping`] (checked at compile
/// time), and `PyMutableMapping` must be in scope where the macro is invoked. `$keys_iter` is
/// the name of a new `#[pyclass]` the macro defines to iterate over the mapping's keys.
///
/// The generated methods clone keys (`pop`, `popitem`, `items`, `setdefault`) and the
/// `setdefault` default value, so both `<$ty as PyMutableMapping>::Key` and `::Value` must be
/// `Clone`.
///
/// Note: `__getitem__` returns Python `None` for a missing key instead of raising `KeyError`.
#[macro_export]
macro_rules! mutable_mapping_pymethods {
    ($ty:ident, keys_iter: $keys_iter: ident) => {
        // Compile-time check that `$ty` implements `PyMutableMapping`.
        const _: fn() = || {
            fn assert_impl<T: PyMutableMapping>() {}
            assert_impl::<$ty>();
        };

        /// Iterator over a snapshot of the mapping's keys.
        #[pyo3::pyclass]
        struct $keys_iter(std::vec::IntoIter<<$ty as PyMutableMapping>::Key>);

        #[pyo3::pymethods]
        impl $keys_iter {
            // Python iterators must themselves be iterable (return `self` from `__iter__`),
            // otherwise `iter(it)` / `list(iter(mapping))` raise `TypeError`.
            fn __iter__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
                slf
            }

            fn __next__(&mut self) -> Option<<$ty as PyMutableMapping>::Key> {
                self.0.next()
            }
        }

        #[pyo3::pymethods]
        impl $ty {
            fn __len__(&self) -> pyo3::PyResult<usize> {
                self.len()
            }

            fn __contains__(&self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<bool> {
                self.contains(key)
            }

            // Iterates over a snapshot of the keys taken when iteration starts.
            fn __iter__(&self) -> pyo3::PyResult<$keys_iter> {
                Ok($keys_iter(
                    <$ty as PyMutableMapping>::keys(self)?.into_iter(),
                ))
            }

            // NOTE(review): a missing key yields `None` rather than raising `KeyError` as
            // `collections.abc.Mapping` specifies; `pop` and `setdefault` rely on this.
            fn __getitem__(
                &self,
                key: <$ty as PyMutableMapping>::Key,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::get(self, key)
            }

            // `mapping.get(key, default=None)`.
            fn get(
                &self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                Ok(<$ty as PyMutableMapping>::get(self, key)?.or(default))
            }

            fn keys(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Key>> {
                <$ty as PyMutableMapping>::keys(self)
            }

            fn values(&self) -> pyo3::PyResult<Vec<<$ty as PyMutableMapping>::Value>> {
                <$ty as PyMutableMapping>::values(self)
            }

            // Pairs every key with the value looked up by that key, like
            // `collections.abc.ItemsView`. Zipping `keys()` with `values()` would only be
            // correct if the implementation returned both in the same order and length.
            // Keys for which `get` returns `None` are skipped.
            fn items(
                &self,
            ) -> pyo3::PyResult<
                Vec<(
                    <$ty as PyMutableMapping>::Key,
                    <$ty as PyMutableMapping>::Value,
                )>,
            > {
                let keys = <$ty as PyMutableMapping>::keys(self)?;
                let mut items = Vec::with_capacity(keys.len());
                for key in keys {
                    if let Some(value) = <$ty as PyMutableMapping>::get(self, key.clone())? {
                        items.push((key, value));
                    }
                }
                Ok(items)
            }

            fn __setitem__(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                value: <$ty as PyMutableMapping>::Value,
            ) -> pyo3::PyResult<()> {
                self.set(key, value)
            }

            fn __delitem__(&mut self, key: <$ty as PyMutableMapping>::Key) -> pyo3::PyResult<()> {
                self.del(key)
            }

            // Removes `key` and returns its value. A missing key returns `default` if one was
            // given, otherwise raises `KeyError`.
            fn pop(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<<$ty as PyMutableMapping>::Value> {
                match <$ty as PyMutableMapping>::get(self, key.clone())? {
                    Some(value) => {
                        self.del(key)?;
                        Ok(value)
                    }
                    None => {
                        default.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("unknown key"))
                    }
                }
            }

            // Removes and returns the entry for the first key reported by `keys()`;
            // raises `KeyError` when the mapping is empty.
            fn popitem(
                &mut self,
            ) -> pyo3::PyResult<(
                <$ty as PyMutableMapping>::Key,
                <$ty as PyMutableMapping>::Value,
            )> {
                let key = <$ty as PyMutableMapping>::keys(self)?
                    .into_iter()
                    .next()
                    .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("no key"))?;
                let value = self.pop(key.clone(), None)?;
                Ok((key, value))
            }

            // Deletes every key from a single snapshot. Looping on `popitem` re-listed all
            // keys for each removal (quadratic) and never terminated if `del` left a key in
            // place. `KeyError` from `del` is tolerated, as in `MutableMapping.clear`.
            fn clear(&mut self, py: pyo3::Python) -> pyo3::PyResult<()> {
                for key in <$ty as PyMutableMapping>::keys(self)? {
                    match self.del(key) {
                        Ok(()) => {}
                        Err(err) if err.is_instance_of::<pyo3::exceptions::PyKeyError>(py) => {}
                        Err(err) => return Err(err),
                    }
                }
                Ok(())
            }

            // Returns the existing value for `key`. Otherwise stores `default` (when given)
            // and returns it.
            fn setdefault(
                &mut self,
                key: <$ty as PyMutableMapping>::Key,
                default: Option<<$ty as PyMutableMapping>::Value>,
            ) -> pyo3::PyResult<Option<<$ty as PyMutableMapping>::Value>> {
                if let Some(value) = <$ty as PyMutableMapping>::get(self, key.clone())? {
                    return Ok(Some(value));
                }
                if let Some(value) = &default {
                    self.set(key, value.clone())?;
                }
                Ok(default)
            }
        }
    };
}
213
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use pyo3::{prelude::*, py_run};

    use super::*;

    /// Minimal `HashMap`-backed mapping used to exercise the generated Python methods.
    #[pyclass(mapping)]
    struct Map(HashMap<String, String>);

    impl PyMutableMapping for Map {
        type Key = String;
        type Value = String;

        fn len(&self) -> PyResult<usize> {
            Ok(self.0.len())
        }

        fn contains(&self, key: Self::Key) -> PyResult<bool> {
            Ok(self.0.contains_key(&key))
        }

        fn keys(&self) -> PyResult<Vec<Self::Key>> {
            Ok(self.0.keys().cloned().collect())
        }

        fn values(&self) -> PyResult<Vec<Self::Value>> {
            Ok(self.0.values().cloned().collect())
        }

        fn get(&self, key: Self::Key) -> PyResult<Option<Self::Value>> {
            Ok(self.0.get(&key).cloned())
        }

        fn set(&mut self, key: Self::Key, value: Self::Value) -> PyResult<()> {
            self.0.insert(key, value);
            Ok(())
        }

        fn del(&mut self, key: Self::Key) -> PyResult<()> {
            self.0.remove(&key);
            Ok(())
        }
    }

    mutable_mapping_pymethods!(Map, keys_iter: MapKeys);

    #[test]
    fn mutable_mapping() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let map = Map({
            let mut hash_map = HashMap::new();
            hash_map.insert("foo".to_string(), "bar".to_string());
            hash_map.insert("baz".to_string(), "qux".to_string());
            hash_map
        });

        Python::with_gil(|py| {
            let map = PyCell::new(py, map)?;
            py_run!(
                py,
                map,
                r#"
# collections.abc.Sized
assert len(map) == 2

# collections.abc.Container
assert "foo" in map
assert "foobar" not in map

# collections.abc.Iterable
elems = ["foo", "baz"]

for elem in map:
    assert elem in elems

it = iter(map)
assert next(it) in elems
assert next(it) in elems
try:
    next(it)
    assert False, "should stop iteration"
except StopIteration:
    pass

assert set(list(map)) == set(["foo", "baz"])

# collections.abc.Mapping
assert map["foo"] == "bar"
assert map.get("baz") == "qux"
assert map.get("foobar") == None
assert map.get("foobar", "default") == "default"

assert set(list(map.keys())) == set(["foo", "baz"])
assert set(list(map.values())) == set(["bar", "qux"])
assert set(list(map.items())) == set([("foo", "bar"), ("baz", "qux")])

# collections.abc.MutableMapping
map["foobar"] = "bazqux"
del map["foo"]

try:
    map.pop("not_exist")
    assert False, "should throw KeyError"
except KeyError:
    pass
assert map.pop("not_exist", "default") == "default"
assert map.pop("foobar") == "bazqux"
assert "foobar" not in map

# at this point there is only `baz => qux` in `map`
assert map.popitem() == ("baz", "qux")
assert len(map) == 0
try:
    map.popitem()
    assert False, "should throw KeyError"
except KeyError:
    pass

map["foo"] = "bar"
map["baz"] = "qux"
assert len(map) == 2
map.clear()
assert len(map) == 0
assert "foo" not in map
assert "baz" not in map

assert map.setdefault("foo", "bar") == "bar"
assert map["foo"] == "bar"
assert map.setdefault("foo", "baz") == "bar"

# TODO(MissingImpl): Add tests for map.update(...)
"#
            );
            Ok(())
        })
    }
}