aws_smithy_async/future/
rendezvous.rs1use std::future::poll_fn;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use tokio::sync::Semaphore;
19
20pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
26 let (tx, rx) = tokio::sync::mpsc::channel(1);
27 let semaphore = Arc::new(Semaphore::new(0));
28 (
29 Sender {
30 semaphore: semaphore.clone(),
31 chan: tx,
32 },
33 Receiver {
34 semaphore,
35 chan: rx,
36 needs_permit: false,
37 },
38 )
39}
40
41pub mod error {
43 use std::fmt;
44 use tokio::sync::mpsc::error::SendError as TokioSendError;
45
46 #[derive(Debug)]
48 pub struct SendError<T> {
49 source: TokioSendError<T>,
50 }
51
52 impl<T> SendError<T> {
53 pub(crate) fn tokio_send_error(source: TokioSendError<T>) -> Self {
54 Self { source }
55 }
56 }
57
58 impl<T> fmt::Display for SendError<T> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 write!(f, "failed to send value to the receiver")
61 }
62 }
63
64 impl<T: fmt::Debug + 'static> std::error::Error for SendError<T> {
65 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
66 Some(&self.source)
67 }
68 }
69}
70
71#[derive(Debug)]
72pub struct Sender<T> {
74 semaphore: Arc<Semaphore>,
75 chan: tokio::sync::mpsc::Sender<T>,
76}
77
78impl<T> Sender<T> {
79 pub async fn send(&self, item: T) -> Result<(), error::SendError<T>> {
84 let result = self.chan.send(item).await;
85 if result.is_ok() {
87 self.semaphore
89 .acquire()
90 .await
91 .expect("semaphore is never closed")
92 .forget();
93 }
94 result.map_err(error::SendError::tokio_send_error)
95 }
96}
97
98#[derive(Debug)]
99pub struct Receiver<T> {
101 semaphore: Arc<Semaphore>,
102 chan: tokio::sync::mpsc::Receiver<T>,
103 needs_permit: bool,
104}
105
106impl<T> Receiver<T> {
107 pub async fn recv(&mut self) -> Option<T> {
109 poll_fn(|cx| self.poll_recv(cx)).await
110 }
111
112 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
113 let resp = self.chan.poll_recv(cx);
116 if self.needs_permit && matches!(resp, Poll::Pending) {
118 self.needs_permit = false;
119 self.semaphore.add_permits(1);
120 }
121
122 if matches!(resp, Poll::Ready(_)) {
123 self.needs_permit = true;
125 }
126 resp
127 }
128}
129
#[cfg(test)]
mod test {
    use crate::future::rendezvous::channel;
    use std::sync::{Arc, Mutex};

    /// `send` must not return until the receiver has taken the value and asked for the next one.
    ///
    /// `done` records how far the sending task has progressed; each check after a `recv`
    /// verifies the sender is still parked inside the `send` for the value just received.
    #[tokio::test]
    async fn send_blocks_caller() {
        let (tx, mut rx) = channel::<u8>();
        let done = Arc::new(Mutex::new(0));
        let idone = done.clone();
        let send = tokio::spawn(async move {
            *idone.lock().unwrap() = 1;
            tx.send(0).await.unwrap();
            *idone.lock().unwrap() = 2;
            tx.send(1).await.unwrap();
            *idone.lock().unwrap() = 3;
        });
        // Relies on the spawned task not having run yet
        // (assumes the default single-threaded `#[tokio::test]` runtime).
        assert_eq!(*done.lock().unwrap(), 0);
        assert_eq!(rx.recv().await, Some(0));
        // Value 0 was delivered, but the sender is still blocked inside `send(0)`.
        assert_eq!(*done.lock().unwrap(), 1);
        assert_eq!(rx.recv().await, Some(1));
        // Asking for more let `send(0)` finish; the sender is now blocked inside `send(1)`.
        assert_eq!(*done.lock().unwrap(), 2);
        // The task finishes and drops `tx`, so the channel reports end-of-stream.
        assert_eq!(rx.recv().await, None);
        assert_eq!(*done.lock().unwrap(), 3);
        let _ = send.await;
    }

    /// Sending after the receiver is gone must surface an error rather than block.
    #[tokio::test]
    async fn send_errors_when_rx_dropped() {
        let (tx, rx) = channel::<u8>();
        drop(rx);
        tx.send(0).await.expect_err("rx half dropped");
    }
}
163}