aws_smithy_http/event_stream/
sender.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use aws_smithy_types::event_stream::Message;
11use bytes::Bytes;
12use futures_core::Stream;
13use std::error::Error as StdError;
14use std::fmt;
15use std::fmt::Debug;
16use std::marker::PhantomData;
17use std::pin::Pin;
18use std::task::{Context, Poll};
19use tracing::trace;
20
21/// Wrapper for event stream items that may include an initial-request message.
22/// This is used internally to allow initial messages to flow through the signing pipeline.
23#[doc(hidden)]
24#[derive(Debug)]
25pub enum EventOrInitial<T> {
26    /// A regular event that needs marshalling and signing
27    Event(T),
28    /// An initial-request message that's already marshalled, just needs signing
29    InitialMessage(Message),
30}
31
32/// Input type for Event Streams.
33pub struct EventStreamSender<T, E> {
34    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
35}
36
37impl<T, E> Debug for EventStreamSender<T, E> {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        let name_t = std::any::type_name::<T>();
40        let name_e = std::any::type_name::<E>();
41        write!(f, "EventStreamSender<{name_t}, {name_e}>")
42    }
43}
44
45impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
46    /// Creates an `EventStreamSender` from a single item.
47    pub fn once(item: Result<T, E>) -> Self {
48        Self::from(futures_util::stream::once(async move { item }))
49    }
50}
51
52impl<T: Send + Sync, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
53    #[doc(hidden)]
54    pub fn into_body_stream(
55        self,
56        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
57        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
58        signer: impl SignMessage + Send + Sync + 'static,
59    ) -> MessageStreamAdapter<T, E> {
60        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
61    }
62
63    /// Extract the inner stream. This is used internally for composing streams.
64    #[doc(hidden)]
65    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
66        self.input_stream
67    }
68}
69
70impl<T, E, S> From<S> for EventStreamSender<T, E>
71where
72    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
73{
74    fn from(stream: S) -> Self {
75        EventStreamSender {
76            input_stream: Box::pin(stream),
77        }
78    }
79}
80
81/// An error that occurs within a message stream.
82#[derive(Debug)]
83pub struct MessageStreamError {
84    kind: MessageStreamErrorKind,
85    pub(crate) meta: ErrorMetadata,
86}
87
88#[derive(Debug)]
89enum MessageStreamErrorKind {
90    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
91}
92
93impl MessageStreamError {
94    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
95    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
96        Self {
97            meta: Default::default(),
98            kind: MessageStreamErrorKind::Unhandled(err.into()),
99        }
100    }
101
102    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
103    pub fn generic(err: ErrorMetadata) -> Self {
104        Self {
105            meta: err.clone(),
106            kind: MessageStreamErrorKind::Unhandled(err.into()),
107        }
108    }
109
110    /// Returns error metadata, which includes the error code, message,
111    /// request ID, and potentially additional information.
112    pub fn meta(&self) -> &ErrorMetadata {
113        &self.meta
114    }
115}
116
117impl StdError for MessageStreamError {
118    fn source(&self) -> Option<&(dyn StdError + 'static)> {
119        match &self.kind {
120            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
121        }
122    }
123}
124
125impl fmt::Display for MessageStreamError {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        match &self.kind {
128            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
129        }
130    }
131}
132
133/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
134/// message marshaller and signer implementations.
135///
136/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
137/// marshalled into an Event Stream frame, (e.g., if the message payload was too large).
138#[allow(missing_debug_implementations)]
139pub struct MessageStreamAdapter<T: Send + Sync, E: StdError + Send + Sync + 'static> {
140    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
141    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
142    signer: Box<dyn SignMessage + Send + Sync>,
143    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
144    end_signal_sent: bool,
145    _phantom: PhantomData<E>,
146}
147
148impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
149
150impl<T: Send + Sync, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
151    /// Create a new `MessageStreamAdapter`.
152    pub fn new(
153        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
154        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
155        signer: impl SignMessage + Send + Sync + 'static,
156        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
157    ) -> Self {
158        MessageStreamAdapter {
159            marshaller: Box::new(marshaller),
160            error_marshaller: Box::new(error_marshaller),
161            signer: Box::new(signer),
162            stream,
163            end_signal_sent: false,
164            _phantom: Default::default(),
165        }
166    }
167}
168
169impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
170    type Item = Result<
171        http_body_1x::Frame<Bytes>,
172        SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
173    >;
174
175    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
176        match self.stream.as_mut().poll_next(cx) {
177            Poll::Ready(message_option) => {
178                if let Some(message_result) = message_option {
179                    let message = match message_result {
180                        Ok(message) => self
181                            .marshaller
182                            .marshall(message)
183                            .map_err(SdkError::construction_failure)?,
184                        Err(message) => self
185                            .error_marshaller
186                            .marshall(message)
187                            .map_err(SdkError::construction_failure)?,
188                    };
189
190                    trace!(unsigned_message = ?message, "signing event stream message");
191                    let message = self
192                        .signer
193                        .sign(message)
194                        .map_err(SdkError::construction_failure)?;
195
196                    let mut buffer = Vec::with_capacity(message.size_hint());
197                    write_message_to(&message, &mut buffer)
198                        .map_err(SdkError::construction_failure)?;
199                    trace!(signed_message = ?buffer, "sending signed event stream message");
200                    Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
201                } else if !self.end_signal_sent {
202                    self.end_signal_sent = true;
203                    match self.signer.sign_empty() {
204                        Some(sign) => {
205                            let message = sign.map_err(SdkError::construction_failure)?;
206                            let mut buffer = Vec::with_capacity(message.size_hint());
207                            write_message_to(&message, &mut buffer)
208                                .map_err(SdkError::construction_failure)?;
209                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
210                            Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
211                        }
212                        None => Poll::Ready(None),
213                    }
214                } else {
215                    Poll::Ready(None)
216                }
217            }
218            Poll::Pending => Poll::Pending,
219        }
220    }
221}
222
223/// Marshaller wrapper that handles both regular events and initial messages.
224/// This is used internally to support initial-request messages in event streams.
225#[doc(hidden)]
226#[derive(Debug)]
227pub struct EventOrInitialMarshaller<M> {
228    inner: M,
229}
230
231impl<M> EventOrInitialMarshaller<M> {
232    #[doc(hidden)]
233    pub fn new(inner: M) -> Self {
234        Self { inner }
235    }
236}
237
238impl<M, T> MarshallMessage for EventOrInitialMarshaller<M>
239where
240    M: MarshallMessage<Input = T>,
241{
242    type Input = EventOrInitial<T>;
243
244    fn marshall(
245        &self,
246        input: Self::Input,
247    ) -> Result<Message, aws_smithy_eventstream::error::Error> {
248        match input {
249            EventOrInitial::Event(event) => self.inner.marshall(event),
250            EventOrInitial::InitialMessage(message) => Ok(message),
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// A trivial event whose marshalled payload is its string contents.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a `TestMessage` into a message whose payload is the string's UTF-8 bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// An error marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Decoding an empty buffer is a convenient way to obtain an `EventStreamError`.
            let failure = read_message_from(&b""[..]).expect_err("this should always fail");
            Err(failure)
        }
    }

    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// A signer that wraps the encoded message in a new message carrying a `signed: true`
    /// header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            Ok(Message::new(encoded).add_header(signed_header()))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let terminator = Message::new(&b""[..]).add_header(signed_header());
            Some(Ok(terminator))
        }
    }

    /// The header `TestSigner` attaches to every message it signs.
    fn signed_header() -> Header {
        Header::new("signed", HeaderValue::Bool(true))
    }

    /// Decodes a frame yielded by the adapter back into an event stream message.
    fn decode(frame: http_body_1x::Frame<bytes::Bytes>) -> Message {
        let mut data = frame.into_data().unwrap();
        read_message_from(&mut data).unwrap()
    }

    /// Asserts that the first header of `message` is the one added by `TestSigner`.
    fn assert_signed(message: &Message) {
        let header = &message.headers()[0];
        assert_eq!("signed", header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), header.value());
    }

    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let sender = EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        });
        check_send_sync(sender);
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(events),
        );

        let signed = decode(adapter.next().await.unwrap().unwrap());
        assert_signed(&signed);
        let inner = read_message_from(&signed.payload()[..]).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        let terminator = decode(adapter.next().await.unwrap().unwrap());
        assert_signed(&terminator);
        assert_eq!(0, terminator.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let errors = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(errors),
        );

        let outcome = adapter.next().await.unwrap();
        assert!(matches!(outcome, Err(SdkError::ConstructionFailure(_))));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let signed = decode(adapter.next().await.unwrap().unwrap());
        assert_eq!("signed", signed.headers()[0].name().as_str());
        let inner = read_message_from(&signed.payload()[..]).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // The single event is followed by the signed, empty terminator message...
        let terminator = decode(adapter.next().await.unwrap().unwrap());
        assert_eq!("signed", terminator.headers()[0].name().as_str());
        assert_eq!(0, terminator.payload().len());

        // ...after which the adapter is exhausted.
        assert!(adapter.next().await.is_none());
    }

    // Compile-only check: streams of both `Ok` and `Err` items convert into an
    // `EventStreamSender`.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}