// aws_smithy_runtime/static_partition_map.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

/// A data structure for persisting and sharing state between multiple clients.
///
/// Some state should be shared between multiple clients. For example, when creating multiple clients
/// for the same service, it's desirable to share a client rate limiter. This way, when one client
/// receives a throttling response, the other clients will be aware of it as well.
///
/// Whether clients share state is dependent on their partition key `K`. Going back to the client
/// rate limiter example, `K` would be a struct containing the name of the service as well as the
/// client's configured region, since receiving throttling responses in `us-east-1` shouldn't
/// throttle requests to the same service made in other regions.
///
/// Values stored in a `StaticPartitionMap` are cloned whenever they are requested. In the example
/// below, the stored value wraps an `Arc<Mutex<_>>` so that every clone refers to the same state.
/// A value must be initialized before [`get`](Self::get) will return it. Use
/// [`get_or_init`](Self::get_or_init) or [`get_or_init_default`](Self::get_or_init_default) to
/// ensure this.
///
/// [`new`](Self::new) is a `const fn` and the underlying map is only allocated on first access, so a
/// `StaticPartitionMap` can be declared as a `static`.
///
/// # Example
///
/// ```
/// use std::sync::{Arc, Mutex};
/// use aws_smithy_runtime::static_partition_map::StaticPartitionMap;
///
/// // The shared state must be `Clone` and will be internally mutable. Deriving `Default` isn't
/// // necessary, but allows us to use the `StaticPartitionMap::get_or_init_default` method.
/// #[derive(Clone, Default)]
/// pub struct SomeSharedState {
///     inner: Arc<Mutex<Inner>>
/// }
///
/// #[derive(Default)]
/// struct Inner {
///     // Some shared state...
/// }
///
/// // `Clone`, `Hash`, and `Eq` are all required trait impls for partition keys
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// pub struct SharedStatePartition {
///     region: String,
///     service_name: String,
/// }
///
/// impl SharedStatePartition {
///     pub fn new(region: impl Into<String>, service_name: impl Into<String>) -> Self {
///         Self { region: region.into(), service_name: service_name.into() }
///     }
/// }
///
/// static SOME_SHARED_STATE: StaticPartitionMap<SharedStatePartition, SomeSharedState> = StaticPartitionMap::new();
///
/// struct Client {
///     shared_state: SomeSharedState,
/// }
///
/// impl Client {
///     pub fn new() -> Self {
///         let key = SharedStatePartition::new("us-east-1", "example_service_20230628");
///         Self {
///             // If the stored value implements `Default`, you can call the
///             // `StaticPartitionMap::get_or_init_default` convenience method.
///             shared_state: SOME_SHARED_STATE.get_or_init_default(key),
///         }
///     }
/// }
/// ```
#[derive(Debug)]
pub struct StaticPartitionMap<K, V> {
    inner: OnceLock<Mutex<HashMap<K, V>>>,
}

// Implemented by hand: `#[derive(Default)]` would require `K: Default` and `V: Default`,
// even though creating an empty map needs neither.
impl<K, V> Default for StaticPartitionMap<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K, V> StaticPartitionMap<K, V> {
    /// Creates a new, empty `StaticPartitionMap`.
    ///
    /// This is a `const fn`, so it can be used to initialize a `static`.
    pub const fn new() -> Self {
        Self {
            inner: OnceLock::new(),
        }
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
{
    /// Locks the underlying map, creating it on first access.
    ///
    /// A poisoned lock is recovered instead of propagated. The realistic source of poisoning is a
    /// panicking `init` closure passed to [`get_or_init`](Self::get_or_init). That closure runs
    /// before its value is inserted, so a panic there leaves the map unchanged. Propagating the
    /// poison would instead make every later access to this (typically `static`) map panic.
    fn lock_map(&self) -> MutexGuard<'_, HashMap<K, V>> {
        self.inner
            // Start with room for a single partition; the map grows as more keys are initialized.
            .get_or_init(|| Mutex::new(HashMap::with_capacity(1)))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
    V: Clone,
{
    /// Returns a clone of the value stored for `partition_key`.
    ///
    /// Returns `None` if that partition has not been initialized yet.
    #[must_use]
    pub fn get(&self, partition_key: K) -> Option<V> {
        self.lock_map().get(&partition_key).cloned()
    }

    /// Returns a clone of the value stored for `partition_key`, first storing the result of `init`
    /// if that partition has not been initialized yet.
    ///
    /// `init` is only called when the partition is empty. It runs while the map's internal lock is
    /// held, so other threads accessing this map wait until it returns. `init` must not access
    /// this same map: re-locking a `std::sync::Mutex` from the thread that holds it deadlocks or
    /// panics. If `init` panics, nothing is stored and the map remains usable.
    #[must_use]
    pub fn get_or_init<F>(&self, partition_key: K, init: F) -> V
    where
        F: FnOnce() -> V,
    {
        let mut map = self.lock_map();
        let value = map.entry(partition_key).or_insert_with(init);
        value.clone()
    }
}

impl<K, V> StaticPartitionMap<K, V>
where
    K: Eq + Hash,
    V: Clone + Default,
{
    /// Returns a clone of the value stored for `partition_key`, first storing `V::default()` if that
    /// partition has not been initialized yet.
    #[must_use]
    pub fn get_or_init_default(&self, partition_key: K) -> V {
        self.get_or_init(partition_key, V::default)
    }
}

#[cfg(test)]
mod tests {
    use super::StaticPartitionMap;

    #[test]
    fn test_keyed_partition_returns_same_value_for_same_key() {
        let partitions = StaticPartitionMap::new();
        // Seed partition "A"; the returned clone is not needed here.
        let _seeded = partitions.get_or_init("A", || String::from("A"));

        // A second initializer for an already-populated key must be ignored.
        let stored = partitions.get_or_init("A", || String::from("B"));
        assert_eq!(String::from("A"), stored);
    }

    #[test]
    fn test_keyed_partition_returns_different_value_for_different_key() {
        let partitions = StaticPartitionMap::new();
        let _seeded = partitions.get_or_init("A", || String::from("A"));

        // A new key receives the value produced by its own initializer...
        let fresh = partitions.get_or_init("B", || String::from("B"));
        assert_eq!(String::from("B"), fresh);

        // ...while the earlier partition keeps the value it was seeded with.
        let earlier = partitions.get("A").unwrap();
        assert_eq!(String::from("A"), earlier);
    }
}