aws_smithy_async/future/
rendezvous.rs1use std::future::poll_fn;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use tokio::sync::Semaphore;
19
20pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
26    let (tx, rx) = tokio::sync::mpsc::channel(1);
27    let semaphore = Arc::new(Semaphore::new(0));
28    (
29        Sender {
30            semaphore: semaphore.clone(),
31            chan: tx,
32        },
33        Receiver {
34            semaphore,
35            chan: rx,
36            needs_permit: false,
37        },
38    )
39}
40
41pub mod error {
43    use std::fmt;
44    use tokio::sync::mpsc::error::SendError as TokioSendError;
45
46    #[derive(Debug)]
48    pub struct SendError<T> {
49        source: TokioSendError<T>,
50    }
51
52    impl<T> SendError<T> {
53        pub(crate) fn tokio_send_error(source: TokioSendError<T>) -> Self {
54            Self { source }
55        }
56    }
57
58    impl<T> fmt::Display for SendError<T> {
59        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60            write!(f, "failed to send value to the receiver")
61        }
62    }
63
64    impl<T: fmt::Debug + 'static> std::error::Error for SendError<T> {
65        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66            Some(&self.source)
67        }
68    }
69}
70
71#[derive(Debug)]
72pub struct Sender<T> {
74    semaphore: Arc<Semaphore>,
75    chan: tokio::sync::mpsc::Sender<T>,
76}
77
78impl<T> Sender<T> {
79    pub async fn send(&self, item: T) -> Result<(), error::SendError<T>> {
84        let result = self.chan.send(item).await;
85        if result.is_ok() {
87            self.semaphore
89                .acquire()
90                .await
91                .expect("semaphore is never closed")
92                .forget();
93        }
94        result.map_err(error::SendError::tokio_send_error)
95    }
96}
97
98#[derive(Debug)]
99pub struct Receiver<T> {
101    semaphore: Arc<Semaphore>,
102    chan: tokio::sync::mpsc::Receiver<T>,
103    needs_permit: bool,
104}
105
106impl<T> Receiver<T> {
107    pub async fn recv(&mut self) -> Option<T> {
109        poll_fn(|cx| self.poll_recv(cx)).await
110    }
111
112    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
113        let resp = self.chan.poll_recv(cx);
116        if self.needs_permit && matches!(resp, Poll::Pending) {
118            self.needs_permit = false;
119            self.semaphore.add_permits(1);
120        }
121
122        if matches!(resp, Poll::Ready(_)) {
123            self.needs_permit = true;
125        }
126        resp
127    }
128}
129
#[cfg(test)]
mod test {
    use crate::future::rendezvous::channel;
    use std::sync::{Arc, Mutex};

    /// The sending task must not get past a `send` until the receiver asks for the next item.
    #[tokio::test]
    async fn send_blocks_caller() {
        let (tx, mut rx) = channel::<u8>();
        // How far the sending task has progressed.
        let progress = Arc::new(Mutex::new(0));
        let task_progress = Arc::clone(&progress);
        let send = tokio::spawn(async move {
            let mark = |step| *task_progress.lock().unwrap() = step;
            mark(1);
            tx.send(0).await.unwrap();
            mark(2);
            tx.send(1).await.unwrap();
            mark(3);
        });
        let reached = || *progress.lock().unwrap();

        assert_eq!(reached(), 0);
        assert_eq!(rx.recv().await, Some(0));
        assert_eq!(reached(), 1);
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(reached(), 2);
        assert_eq!(rx.recv().await, None);
        assert_eq!(reached(), 3);
        let _ = send.await;
    }

    /// Sending into a channel whose receiver is gone reports an error.
    #[tokio::test]
    async fn send_errors_when_rx_dropped() {
        let (tx, rx) = channel::<u8>();
        drop(rx);
        let outcome = tx.send(0).await;
        assert!(outcome.is_err(), "rx half dropped");
    }
}