aws_smithy_http/event_stream/
sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
/// Input type for Event Streams.
///
/// Wraps a stream of `Result<T, E>` items supplied by the caller (any `Send + Sync` stream can
/// be converted via `From`). When turned into a body stream with `into_body_stream`, `Ok(T)`
/// items are handed to the event marshaller and `Err(E)` items to the error marshaller.
pub struct EventStreamSender<T, E> {
    // Type-erased and pinned on the heap so that any concrete stream type can be stored.
    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        let name_t = std::any::type_name::<T>();
28        let name_e = std::any::type_name::<E>();
29        write!(f, "EventStreamSender<{name_t}, {name_e}>")
30    }
31}
32
33impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34    #[doc(hidden)]
35    pub fn into_body_stream(
36        self,
37        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
38        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
39        signer: impl SignMessage + Send + Sync + 'static,
40    ) -> MessageStreamAdapter<T, E> {
41        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
42    }
43}
44
45impl<T, E, S> From<S> for EventStreamSender<T, E>
46where
47    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
48{
49    fn from(stream: S) -> Self {
50        EventStreamSender {
51            input_stream: Box::pin(stream),
52        }
53    }
54}
55
/// An error that occurs within a message stream.
///
/// Carries an underlying source error (exposed through `std::error::Error::source`) together
/// with [`ErrorMetadata`] (exposed through `meta()`).
#[derive(Debug)]
pub struct MessageStreamError {
    // Classification of the failure; currently always an unhandled boxed source error.
    kind: MessageStreamErrorKind,
    // Error code/message/request-ID metadata; left at its default when built via `unhandled`.
    pub(crate) meta: ErrorMetadata,
}
62
/// Internal classification of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// Wraps an arbitrary error that has no more specific classification.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
}
67
68impl MessageStreamError {
69    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
70    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
71        Self {
72            meta: Default::default(),
73            kind: MessageStreamErrorKind::Unhandled(err.into()),
74        }
75    }
76
77    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
78    pub fn generic(err: ErrorMetadata) -> Self {
79        Self {
80            meta: err.clone(),
81            kind: MessageStreamErrorKind::Unhandled(err.into()),
82        }
83    }
84
85    /// Returns error metadata, which includes the error code, message,
86    /// request ID, and potentially additional information.
87    pub fn meta(&self) -> &ErrorMetadata {
88        &self.meta
89    }
90}
91
92impl StdError for MessageStreamError {
93    fn source(&self) -> Option<&(dyn StdError + 'static)> {
94        match &self.kind {
95            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
96        }
97    }
98}
99
100impl fmt::Display for MessageStreamError {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match &self.kind {
103            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
104        }
105    }
106}
107
/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
/// message marshaller and signer implementations.
///
/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
/// marshalled, signed, or encoded into an Event Stream frame (e.g., if the message payload
/// was too large).
// NOTE(review): presumably `Debug` is not implemented because the boxed marshaller/signer
// trait objects are not required to be `Debug` — confirm.
#[allow(missing_debug_implementations)]
pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
    // Encodes `Ok(T)` items from `stream` into event stream messages.
    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
    // Encodes `Err(E)` items from `stream` into event stream messages.
    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
    // Signs each encoded message and optionally produces the terminating empty message.
    signer: Box<dyn SignMessage + Send + Sync>,
    // The caller-supplied input items.
    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
    // Set once the input stream has ended and the terminating empty message was handled.
    end_signal_sent: bool,
    // NOTE(review): `E` already appears in `error_marshaller` and `stream`, so this marker
    // looks redundant — confirm before removing.
    _phantom: PhantomData<E>,
}

// Every field is boxed, a `bool`, or `PhantomData`, so nothing is structurally pinned. The
// auto-derived impl would require `E: Unpin` (through `PhantomData<E>`); implementing it by
// hand keeps the adapter `Unpin` for any `E`, which lets `poll_next` mutate fields through
// `Pin<&mut Self>`.
impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
124
125impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
126    /// Create a new `MessageStreamAdapter`.
127    pub fn new(
128        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
129        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
130        signer: impl SignMessage + Send + Sync + 'static,
131        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
132    ) -> Self {
133        MessageStreamAdapter {
134            marshaller: Box::new(marshaller),
135            error_marshaller: Box::new(error_marshaller),
136            signer: Box::new(signer),
137            stream,
138            end_signal_sent: false,
139            _phantom: Default::default(),
140        }
141    }
142}
143
144impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
145    type Item =
146        Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
147
148    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        match self.stream.as_mut().poll_next(cx) {
150            Poll::Ready(message_option) => {
151                if let Some(message_result) = message_option {
152                    let message = match message_result {
153                        Ok(message) => self
154                            .marshaller
155                            .marshall(message)
156                            .map_err(SdkError::construction_failure)?,
157                        Err(message) => self
158                            .error_marshaller
159                            .marshall(message)
160                            .map_err(SdkError::construction_failure)?,
161                    };
162
163                    trace!(unsigned_message = ?message, "signing event stream message");
164                    let message = self
165                        .signer
166                        .sign(message)
167                        .map_err(SdkError::construction_failure)?;
168
169                    let mut buffer = Vec::with_capacity(message.size_hint());
170                    write_message_to(&message, &mut buffer)
171                        .map_err(SdkError::construction_failure)?;
172                    trace!(signed_message = ?buffer, "sending signed event stream message");
173                    Poll::Ready(Some(Ok(Bytes::from(buffer))))
174                } else if !self.end_signal_sent {
175                    self.end_signal_sent = true;
176                    match self.signer.sign_empty() {
177                        Some(sign) => {
178                            let message = sign.map_err(SdkError::construction_failure)?;
179                            let mut buffer = Vec::with_capacity(message.size_hint());
180                            write_message_to(&message, &mut buffer)
181                                .map_err(SdkError::construction_failure)?;
182                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
183                            Poll::Ready(Some(Ok(Bytes::from(buffer))))
184                        }
185                        None => Poll::Ready(None),
186                    }
187                } else {
188                    Poll::Ready(None)
189                }
190            }
191            Poll::Pending => Poll::Pending,
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Plain-text event used as the `T` of the streams under test.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshalls a `TestMessage` into a message whose payload is its UTF-8 text.
    #[derive(Debug)]
    struct Marshaller;

    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Error marshaller that always fails, to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;

    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Decoding an empty buffer is a cheap way to obtain a genuine `EventStreamError`.
            let failure = read_message_from(&b""[..]).expect_err("this should always fail");
            Err(failure)
        }
    }

    /// The `E` of the streams under test.
    #[derive(Debug)]
    struct TestServiceError;

    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }

    impl StdError for TestServiceError {}

    /// Adds the `signed: true` header that `TestSigner` uses to mark its output.
    fn mark_signed(message: Message) -> Message {
        message.add_header(Header::new("signed", HeaderValue::Bool(true)))
    }

    /// Asserts that the first header of `message` is the `signed: true` marker.
    fn assert_signed(message: &Message) {
        let header = &message.headers()[0];
        assert_eq!("signed", header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), header.value());
    }

    /// Signer that wraps the encoded input frame in a new message tagged `signed: true`.
    #[derive(Debug)]
    struct TestSigner;

    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            Ok(mark_signed(Message::new(encoded)))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            Some(Ok(mark_signed(Message::new(&b""[..]))))
        }
    }

    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let input = stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        };
        check_send_sync(EventStreamSender::from(input));
    }

    /// Compile-time check that the adapter satisfies the bounds hyper's `wrap_stream` needs.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
        S: Stream<Item = Result<O, E>> + Send + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let input = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(input),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        // First frame: the marshalled event, wrapped by the signer.
        let mut event_frame = adapter.next().await.unwrap().unwrap();
        let signed_event = read_message_from(&mut event_frame).unwrap();
        assert_signed(&signed_event);
        let inner = read_message_from(&mut (&signed_event.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed empty message that terminates the stream.
        let mut end_frame = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_frame).unwrap();
        assert_signed(&end_signal);
        assert!(end_signal.payload().is_empty());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let input = stream! {
            yield Err(TestServiceError);
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(input),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        let result = adapter.next().await.unwrap();
        assert!(matches!(result, Err(SdkError::ConstructionFailure(_))));
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _sender: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}