aws_smithy_legacy_http/event_stream/
sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
/// Input type for Event Streams.
///
/// Wraps a type-erased stream of `Result<T, E>` items: `Ok(T)` is an event to send and
/// `Err(E)` is an error to be marshalled and sent in its place.
pub struct EventStreamSender<T, E> {
    // Boxed and pinned so that any `Send + Sync` stream implementation can be stored
    // behind a single concrete type.
    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        let name_t = std::any::type_name::<T>();
28        let name_e = std::any::type_name::<E>();
29        write!(f, "EventStreamSender<{name_t}, {name_e}>")
30    }
31}
32
33impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34    /// Creates an `EventStreamSender` from a single item.
35    pub fn once(item: Result<T, E>) -> Self {
36        Self::from(futures_util::stream::once(async move { item }))
37    }
38}
39
40impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
41    #[doc(hidden)]
42    pub fn into_body_stream(
43        self,
44        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
45        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
46        signer: impl SignMessage + Send + Sync + 'static,
47    ) -> MessageStreamAdapter<T, E> {
48        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
49    }
50
51    /// Extract the inner stream. This is used internally for composing streams.
52    #[doc(hidden)]
53    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
54        self.input_stream
55    }
56}
57
58impl<T, E, S> From<S> for EventStreamSender<T, E>
59where
60    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
61{
62    fn from(stream: S) -> Self {
63        EventStreamSender {
64            input_stream: Box::pin(stream),
65        }
66    }
67}
68
/// An error that occurs within a message stream.
///
/// Carries the underlying cause (exposed through `std::error::Error::source`) together
/// with an [`ErrorMetadata`] value describing it.
#[derive(Debug)]
pub struct MessageStreamError {
    // What went wrong; currently always wraps an opaque boxed error.
    kind: MessageStreamErrorKind,
    // Metadata returned by `meta()`; default-constructed when the error is built from an
    // arbitrary error rather than from an `ErrorMetadata`.
    pub(crate) meta: ErrorMetadata,
}
75
/// Private classification of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An error that is not modelled more specifically; the boxed value is the cause.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
}
80
81impl MessageStreamError {
82    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
83    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
84        Self {
85            meta: Default::default(),
86            kind: MessageStreamErrorKind::Unhandled(err.into()),
87        }
88    }
89
90    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
91    pub fn generic(err: ErrorMetadata) -> Self {
92        Self {
93            meta: err.clone(),
94            kind: MessageStreamErrorKind::Unhandled(err.into()),
95        }
96    }
97
98    /// Returns error metadata, which includes the error code, message,
99    /// request ID, and potentially additional information.
100    pub fn meta(&self) -> &ErrorMetadata {
101        &self.meta
102    }
103}
104
105impl StdError for MessageStreamError {
106    fn source(&self) -> Option<&(dyn StdError + 'static)> {
107        match &self.kind {
108            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
109        }
110    }
111}
112
113impl fmt::Display for MessageStreamError {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        match &self.kind {
116            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
117        }
118    }
119}
120
/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
/// message marshaller and signer implementations.
///
/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
/// marshalled into an Event Stream frame (e.g., if the message payload was too large).
///
/// `Ok` items from the input stream are marshalled with `marshaller` and `Err` items with
/// `error_marshaller`; each marshalled message is then signed and serialized into bytes.
/// When the input stream ends, the adapter emits whatever `SignMessage::sign_empty`
/// returns (if anything) as a final frame.
#[allow(missing_debug_implementations)]
pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
    // Converts successfully produced events into event stream messages.
    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
    // Converts errors produced by the input stream into event stream messages.
    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
    // Signs every outgoing message, and optionally produces the terminating message.
    signer: Box<dyn SignMessage + Send + Sync>,
    // The caller-provided stream of events (or errors) to send.
    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
    // Set once the input stream has ended and the end-of-stream handling has run.
    end_signal_sent: bool,
    _phantom: PhantomData<E>,
}
135
// Without this impl, the `PhantomData<E>` field would make the automatically derived
// `Unpin` conditional on `E: Unpin`. All data-bearing fields are boxes (or a `bool`), so the
// adapter can be `Unpin` unconditionally; `poll_next` relies on that to reach its fields
// mutably through `Pin<&mut Self>`.
impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
137
138impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
139    /// Create a new `MessageStreamAdapter`.
140    pub fn new(
141        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
142        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
143        signer: impl SignMessage + Send + Sync + 'static,
144        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
145    ) -> Self {
146        MessageStreamAdapter {
147            marshaller: Box::new(marshaller),
148            error_marshaller: Box::new(error_marshaller),
149            signer: Box::new(signer),
150            stream,
151            end_signal_sent: false,
152            _phantom: Default::default(),
153        }
154    }
155}
156
157impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
158    type Item =
159        Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
160
161    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162        match self.stream.as_mut().poll_next(cx) {
163            Poll::Ready(message_option) => {
164                if let Some(message_result) = message_option {
165                    let message = match message_result {
166                        Ok(message) => self
167                            .marshaller
168                            .marshall(message)
169                            .map_err(SdkError::construction_failure)?,
170                        Err(message) => self
171                            .error_marshaller
172                            .marshall(message)
173                            .map_err(SdkError::construction_failure)?,
174                    };
175
176                    trace!(unsigned_message = ?message, "signing event stream message");
177                    let message = self
178                        .signer
179                        .sign(message)
180                        .map_err(SdkError::construction_failure)?;
181
182                    let mut buffer = Vec::with_capacity(message.size_hint());
183                    write_message_to(&message, &mut buffer)
184                        .map_err(SdkError::construction_failure)?;
185                    trace!(signed_message = ?buffer, "sending signed event stream message");
186                    Poll::Ready(Some(Ok(Bytes::from(buffer))))
187                } else if !self.end_signal_sent {
188                    self.end_signal_sent = true;
189                    match self.signer.sign_empty() {
190                        Some(sign) => {
191                            let message = sign.map_err(SdkError::construction_failure)?;
192                            let mut buffer = Vec::with_capacity(message.size_hint());
193                            write_message_to(&message, &mut buffer)
194                                .map_err(SdkError::construction_failure)?;
195                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
196                            Poll::Ready(Some(Ok(Bytes::from(buffer))))
197                        }
198                        None => Poll::Ready(None),
199                    }
200                } else {
201                    Poll::Ready(None)
202                }
203            }
204            Poll::Pending => Poll::Pending,
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event payload used by the tests; the string becomes the message body.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshaller that always succeeds, using the message text as the payload.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Marshaller that always fails, used to provoke a construction failure.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Decoding an empty buffer is a convenient way to obtain a real error value.
            let decode_failure =
                read_message_from(&b""[..]).expect_err("this should always fail");
            Err(decode_failure)
        }
    }

    /// Stand-in for a modelled service error.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// The marker header that `TestSigner` attaches to everything it signs.
    fn signed_header() -> Header {
        Header::new("signed", HeaderValue::Bool(true))
    }

    /// Signer that wraps the serialized input message in a new message carrying the
    /// `signed` header, and terminates streams with an empty "signed" message.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            Ok(Message::new(encoded).add_header(signed_header()))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let terminator = Message::new(&b""[..]).add_header(signed_header());
            Some(Ok(terminator))
        }
    }

    /// Compile-time check that `T` is `Send + Sync`.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let sender = EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        });
        check_send_sync(sender);
    }

    /// Compile-time check that a stream satisfies the bounds hyper's `wrap_stream` needs.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        // First frame: the signed wrapper around the marshalled test message.
        let mut frame_bytes = adapter.next().await.unwrap().unwrap();
        let frame = read_message_from(&mut frame_bytes).unwrap();
        let frame_header = &frame.headers()[0];
        assert_eq!("signed", frame_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), frame_header.value());
        let inner = read_message_from(&mut (&frame.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed, empty end-of-stream message.
        let mut terminator_bytes = adapter.next().await.unwrap().unwrap();
        let terminator = read_message_from(&mut terminator_bytes).unwrap();
        let terminator_header = &terminator.headers()[0];
        assert_eq!("signed", terminator_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), terminator_header.value());
        assert_eq!(0, terminator.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let events = stream! {
            yield Err(TestServiceError);
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        let outcome = adapter.next().await.unwrap();
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(SdkError::ConstructionFailure(_))));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let mut frame_bytes = adapter.next().await.unwrap().unwrap();
        let frame = read_message_from(&mut frame_bytes).unwrap();
        assert_eq!("signed", frame.headers()[0].name().as_str());
        let inner = read_message_from(&mut (&frame.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Should get end signal next
        let mut terminator_bytes = adapter.next().await.unwrap().unwrap();
        let terminator = read_message_from(&mut terminator_bytes).unwrap();
        assert_eq!("signed", terminator.headers()[0].name().as_str());
        assert_eq!(0, terminator.payload().len());

        // Stream should be exhausted
        let after_end = adapter.next().await;
        assert!(after_end.is_none());
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn accepts(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        accepts(stream! {
            yield Ok(TestMessage("test".into()));
        });
        accepts(stream! {
            yield Err(TestServiceError);
        });
    }
}