aws_smithy_http/event_stream/sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
20/// Input type for Event Streams.
21pub struct EventStreamSender<T, E> {
22    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
23}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        let name_t = std::any::type_name::<T>();
28        let name_e = std::any::type_name::<E>();
29        write!(f, "EventStreamSender<{name_t}, {name_e}>")
30    }
31}
32
33impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34    /// Creates an `EventStreamSender` from a single item.
35    pub fn once(item: Result<T, E>) -> Self {
36        Self::from(futures_util::stream::once(async move { item }))
37    }
38}
39
40impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
41    #[doc(hidden)]
42    pub fn into_body_stream(
43        self,
44        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
45        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
46        signer: impl SignMessage + Send + Sync + 'static,
47    ) -> MessageStreamAdapter<T, E> {
48        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
49    }
50}
51
52impl<T, E, S> From<S> for EventStreamSender<T, E>
53where
54    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
55{
56    fn from(stream: S) -> Self {
57        EventStreamSender {
58            input_stream: Box::pin(stream),
59        }
60    }
61}
62
63/// An error that occurs within a message stream.
64#[derive(Debug)]
65pub struct MessageStreamError {
66    kind: MessageStreamErrorKind,
67    pub(crate) meta: ErrorMetadata,
68}
69
/// Internal classification of a `MessageStreamError`; currently every error is unhandled.
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An error with no more specific classification, boxed as a trait object.
    Unhandled(Box<dyn StdError + Send + Sync + 'static>),
}
74
75impl MessageStreamError {
76    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
77    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
78        Self {
79            meta: Default::default(),
80            kind: MessageStreamErrorKind::Unhandled(err.into()),
81        }
82    }
83
84    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
85    pub fn generic(err: ErrorMetadata) -> Self {
86        Self {
87            meta: err.clone(),
88            kind: MessageStreamErrorKind::Unhandled(err.into()),
89        }
90    }
91
92    /// Returns error metadata, which includes the error code, message,
93    /// request ID, and potentially additional information.
94    pub fn meta(&self) -> &ErrorMetadata {
95        &self.meta
96    }
97}
98
99impl StdError for MessageStreamError {
100    fn source(&self) -> Option<&(dyn StdError + 'static)> {
101        match &self.kind {
102            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
103        }
104    }
105}
106
107impl fmt::Display for MessageStreamError {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match &self.kind {
110            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
111        }
112    }
113}
114
115/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
116/// message marshaller and signer implementations.
117///
118/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
119/// marshalled into an Event Stream frame, (e.g., if the message payload was too large).
120#[allow(missing_debug_implementations)]
121pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
122    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
123    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
124    signer: Box<dyn SignMessage + Send + Sync>,
125    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
126    end_signal_sent: bool,
127    _phantom: PhantomData<E>,
128}
129
130impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
131
132impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
133    /// Create a new `MessageStreamAdapter`.
134    pub fn new(
135        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
136        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
137        signer: impl SignMessage + Send + Sync + 'static,
138        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
139    ) -> Self {
140        MessageStreamAdapter {
141            marshaller: Box::new(marshaller),
142            error_marshaller: Box::new(error_marshaller),
143            signer: Box::new(signer),
144            stream,
145            end_signal_sent: false,
146            _phantom: Default::default(),
147        }
148    }
149}
150
151impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
152    type Item =
153        Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
154
155    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156        match self.stream.as_mut().poll_next(cx) {
157            Poll::Ready(message_option) => {
158                if let Some(message_result) = message_option {
159                    let message = match message_result {
160                        Ok(message) => self
161                            .marshaller
162                            .marshall(message)
163                            .map_err(SdkError::construction_failure)?,
164                        Err(message) => self
165                            .error_marshaller
166                            .marshall(message)
167                            .map_err(SdkError::construction_failure)?,
168                    };
169
170                    trace!(unsigned_message = ?message, "signing event stream message");
171                    let message = self
172                        .signer
173                        .sign(message)
174                        .map_err(SdkError::construction_failure)?;
175
176                    let mut buffer = Vec::with_capacity(message.size_hint());
177                    write_message_to(&message, &mut buffer)
178                        .map_err(SdkError::construction_failure)?;
179                    trace!(signed_message = ?buffer, "sending signed event stream message");
180                    Poll::Ready(Some(Ok(Bytes::from(buffer))))
181                } else if !self.end_signal_sent {
182                    self.end_signal_sent = true;
183                    match self.signer.sign_empty() {
184                        Some(sign) => {
185                            let message = sign.map_err(SdkError::construction_failure)?;
186                            let mut buffer = Vec::with_capacity(message.size_hint());
187                            write_message_to(&message, &mut buffer)
188                                .map_err(SdkError::construction_failure)?;
189                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
190                            Poll::Ready(Some(Ok(Bytes::from(buffer))))
191                        }
192                        None => Poll::Ready(None),
193                    }
194                } else {
195                    Poll::Ready(None)
196                }
197            }
198            Poll::Pending => Poll::Pending,
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            Ok(Message::new(input.0.as_bytes().to_vec()))
        }
    }

    /// Error marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            Err(read_message_from(&b""[..]).expect_err("this should always fail"))
        }
    }

    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Signer that wraps each message in an envelope tagged with a `signed` header and
    /// emits a signed empty message as the end signal.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut buffer = Vec::new();
            write_message_to(&message, &mut buffer).unwrap();
            Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true))))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            Some(Ok(
                Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))
            ))
        }
    }

    /// Signer that passes messages through unchanged and never emits an end signal.
    #[derive(Debug)]
    struct UnterminatedSigner;
    impl SignMessage for UnterminatedSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            Ok(message)
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            None
        }
    }

    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(stream),
        ));

        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value());
        assert_eq!(0, end_signal.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let stream = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(stream),
        ));

        let result = adapter.next().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(
            result.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    #[tokio::test]
    async fn message_stream_adapter_without_end_signal() {
        let stream = stream! {
            yield Result::<_, TestServiceError>::Ok(TestMessage("test".into()));
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            UnterminatedSigner,
            Box::pin(stream),
        );

        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!(&b"test"[..], &sent.payload()[..]);

        // The signer declines to produce a terminating message, so the adapter ends
        // as soon as the input stream does.
        assert!(adapter.next().await.is_none());
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Should get end signal next
        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(0, end_signal.payload().len());

        // Stream should be exhausted
        assert!(adapter.next().await.is_none());
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}