aws_smithy_http/event_stream/
sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
20/// Input type for Event Streams.
21pub struct EventStreamSender<T, E> {
22    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
23}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        let name_t = std::any::type_name::<T>();
28        let name_e = std::any::type_name::<E>();
29        write!(f, "EventStreamSender<{name_t}, {name_e}>")
30    }
31}
32
33impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34    /// Creates an `EventStreamSender` from a single item.
35    pub fn once(item: Result<T, E>) -> Self {
36        Self::from(futures_util::stream::once(async move { item }))
37    }
38}
39
40impl<T: Send + Sync, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
41    #[doc(hidden)]
42    pub fn into_body_stream(
43        self,
44        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
45        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
46        signer: impl SignMessage + Send + Sync + 'static,
47    ) -> MessageStreamAdapter<T, E> {
48        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
49    }
50}
51
52impl<T, E, S> From<S> for EventStreamSender<T, E>
53where
54    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
55{
56    fn from(stream: S) -> Self {
57        EventStreamSender {
58            input_stream: Box::pin(stream),
59        }
60    }
61}
62
63/// An error that occurs within a message stream.
64#[derive(Debug)]
65pub struct MessageStreamError {
66    kind: MessageStreamErrorKind,
67    pub(crate) meta: ErrorMetadata,
68}
69
/// Internal classification of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// A failure with no more specific classification; holds the original error.
    Unhandled(Box<dyn StdError + Send + Sync + 'static>),
}
74
75impl MessageStreamError {
76    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
77    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
78        Self {
79            meta: Default::default(),
80            kind: MessageStreamErrorKind::Unhandled(err.into()),
81        }
82    }
83
84    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
85    pub fn generic(err: ErrorMetadata) -> Self {
86        Self {
87            meta: err.clone(),
88            kind: MessageStreamErrorKind::Unhandled(err.into()),
89        }
90    }
91
92    /// Returns error metadata, which includes the error code, message,
93    /// request ID, and potentially additional information.
94    pub fn meta(&self) -> &ErrorMetadata {
95        &self.meta
96    }
97}
98
99impl StdError for MessageStreamError {
100    fn source(&self) -> Option<&(dyn StdError + 'static)> {
101        match &self.kind {
102            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
103        }
104    }
105}
106
107impl fmt::Display for MessageStreamError {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match &self.kind {
110            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
111        }
112    }
113}
114
115/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
116/// message marshaller and signer implementations.
117///
118/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
119/// marshalled into an Event Stream frame, (e.g., if the message payload was too large).
120#[allow(missing_debug_implementations)]
121pub struct MessageStreamAdapter<T: Send + Sync, E: StdError + Send + Sync + 'static> {
122    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
123    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
124    signer: Box<dyn SignMessage + Send + Sync>,
125    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
126    end_signal_sent: bool,
127    _phantom: PhantomData<E>,
128}
129
130impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
131
132impl<T: Send + Sync, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
133    /// Create a new `MessageStreamAdapter`.
134    pub fn new(
135        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
136        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
137        signer: impl SignMessage + Send + Sync + 'static,
138        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
139    ) -> Self {
140        MessageStreamAdapter {
141            marshaller: Box::new(marshaller),
142            error_marshaller: Box::new(error_marshaller),
143            signer: Box::new(signer),
144            stream,
145            end_signal_sent: false,
146            _phantom: Default::default(),
147        }
148    }
149}
150
151impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
152    type Item = Result<
153        http_body_1x::Frame<Bytes>,
154        SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
155    >;
156
157    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158        match self.stream.as_mut().poll_next(cx) {
159            Poll::Ready(message_option) => {
160                if let Some(message_result) = message_option {
161                    let message = match message_result {
162                        Ok(message) => self
163                            .marshaller
164                            .marshall(message)
165                            .map_err(SdkError::construction_failure)?,
166                        Err(message) => self
167                            .error_marshaller
168                            .marshall(message)
169                            .map_err(SdkError::construction_failure)?,
170                    };
171
172                    trace!(unsigned_message = ?message, "signing event stream message");
173                    let message = self
174                        .signer
175                        .sign(message)
176                        .map_err(SdkError::construction_failure)?;
177
178                    let mut buffer = Vec::with_capacity(message.size_hint());
179                    write_message_to(&message, &mut buffer)
180                        .map_err(SdkError::construction_failure)?;
181                    trace!(signed_message = ?buffer, "sending signed event stream message");
182                    Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
183                } else if !self.end_signal_sent {
184                    self.end_signal_sent = true;
185                    match self.signer.sign_empty() {
186                        Some(sign) => {
187                            let message = sign.map_err(SdkError::construction_failure)?;
188                            let mut buffer = Vec::with_capacity(message.size_hint());
189                            write_message_to(&message, &mut buffer)
190                                .map_err(SdkError::construction_failure)?;
191                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
192                            Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
193                        }
194                        None => Poll::Ready(None),
195                    }
196                } else {
197                    Poll::Ready(None)
198                }
199            }
200            Poll::Pending => Poll::Pending,
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event type used as `T` throughout these tests.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a `TestMessage` into a message whose payload is the string's bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Error marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Decoding an empty buffer is a cheap way to obtain an `EventStreamError` value.
            let failure = read_message_from(&b""[..]).expect_err("this should always fail");
            Err(failure)
        }
    }

    /// Modeled error type used as `E` throughout these tests.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// The `signed: true` header that `TestSigner` attaches to every message.
    fn signed_header() -> Header {
        Header::new("signed", HeaderValue::Bool(true))
    }

    /// "Signs" a message by encoding it as the payload of a new message carrying
    /// `signed_header()`.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            Ok(Message::new(encoded).add_header(signed_header()))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let terminator = Message::new(&b""[..]).add_header(signed_header());
            Some(Ok(terminator))
        }
    }

    /// Decodes the Event Stream message carried by a data frame produced by the adapter.
    fn decode_frame(frame: http_body_1x::Frame<bytes::Bytes>) -> Message {
        let mut data = frame.into_data().unwrap();
        read_message_from(&mut data).unwrap()
    }

    /// Compile-time check that `value` is `Send + Sync`.
    fn check_send_sync<T>(value: T) -> T
    where
        T: Send + Sync,
    {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let sender = EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        });
        check_send_sync(sender);
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let input = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(input),
        );

        // First frame: the signed wrapper around the marshalled event.
        let signed = decode_frame(adapter.next().await.unwrap().unwrap());
        let header = &signed.headers()[0];
        assert_eq!("signed", header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), header.value());
        let event = read_message_from(&mut &signed.payload()[..]).unwrap();
        assert_eq!(&b"test"[..], &event.payload()[..]);

        // Second frame: the signed, empty end-of-stream message.
        let terminator = decode_frame(adapter.next().await.unwrap().unwrap());
        let header = &terminator.headers()[0];
        assert_eq!("signed", header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), header.value());
        assert!(terminator.payload().is_empty());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let input = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(input),
        );

        let outcome = adapter.next().await;
        assert!(matches!(
            outcome,
            Some(Err(SdkError::ConstructionFailure(_)))
        ));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let signed = decode_frame(adapter.next().await.unwrap().unwrap());
        assert_eq!("signed", signed.headers()[0].name().as_str());
        let event = read_message_from(&mut &signed.payload()[..]).unwrap();
        assert_eq!(&b"test"[..], &event.payload()[..]);

        // The end-of-stream signal follows the single event...
        let terminator = decode_frame(adapter.next().await.unwrap().unwrap());
        assert_eq!("signed", terminator.headers()[0].name().as_str());
        assert!(terminator.payload().is_empty());

        // ...after which the adapter is exhausted.
        assert!(adapter.next().await.is_none());
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _converted: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}