aws_smithy_http/event_stream/
receiver.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{
7    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21/// Wrapper around SegmentedBuf that tracks the state of the stream.
22#[derive(Debug)]
23enum RecvBuf {
24    /// Nothing has been buffered yet.
25    Empty,
26    /// Some data has been buffered.
27    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
28    Partial(SegmentedBuf<Bytes>),
29    /// The end of the stream has been reached, but there may still be some buffered data.
30    EosPartial(SegmentedBuf<Bytes>),
31    /// An exception terminated this stream.
32    Terminated,
33}
34
35impl RecvBuf {
36    /// Returns true if there's more buffered data.
37    fn has_data(&self) -> bool {
38        match self {
39            RecvBuf::Empty | RecvBuf::Terminated => false,
40            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41        }
42    }
43
44    /// Returns true if the stream has ended.
45    fn is_eos(&self) -> bool {
46        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47    }
48
49    /// Returns a mutable reference to the underlying buffered data.
50    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51        match self {
52            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53            RecvBuf::Partial(segmented) => segmented,
54            RecvBuf::EosPartial(segmented) => segmented,
55            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56        }
57    }
58
59    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
60    /// if the `RecvBuf` was previously empty.
61    fn with_partial(self, partial: Bytes) -> Self {
62        match self {
63            RecvBuf::Empty => {
64                let mut segmented = SegmentedBuf::new();
65                segmented.push(partial);
66                RecvBuf::Partial(segmented)
67            }
68            RecvBuf::Partial(mut segmented) => {
69                segmented.push(partial);
70                RecvBuf::Partial(segmented)
71            }
72            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74            }
75        }
76    }
77
78    /// Returns a `RecvBuf` that has reached end of stream.
79    fn ended(self) -> Self {
80        match self {
81            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85        }
86    }
87}
88
/// The distinct failure modes of an event stream receiver.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The stream ended before a complete message frame was received.
    UnexpectedEndOfStream,
}

/// An error that occurs within an event stream receiver.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111/// Receives Smithy-modeled messages out of an Event Stream.
112#[derive(Debug)]
113pub struct Receiver<T, E> {
114    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115    decoder: MessageFrameDecoder,
116    buffer: RecvBuf,
117    body: SdkBody,
118    /// Event Stream has optional initial response frames an with `:message-type` of
119    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an
120    /// initial response, then the message will be stored in `buffered_message` so that it can
121    /// be returned with the next call of `recv()`.
122    buffered_message: Option<Message>,
123    _phantom: PhantomData<E>,
124}
125
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
/// The kind of initial message a caller expects at the start of an event stream.
#[doc(hidden)]
#[non_exhaustive]
#[derive(Debug)]
pub enum InitialMessageType {
    /// Matches an `:event-type` of `initial-request`.
    Request,
    /// Matches an `:event-type` of `initial-response`.
    Response,
}

impl InitialMessageType {
    /// Returns the `:event-type` header value this variant is compared against.
    fn as_str(&self) -> &'static str {
        match self {
            InitialMessageType::Request => "initial-request",
            InitialMessageType::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
145    pub fn new(
146        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147        body: SdkBody,
148    ) -> Self {
149        Receiver {
150            unmarshaller: Box::new(unmarshaller),
151            decoder: MessageFrameDecoder::new(),
152            buffer: RecvBuf::Empty,
153            body,
154            buffered_message: None,
155            _phantom: Default::default(),
156        }
157    }
158
159    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160        match self.unmarshaller.unmarshall(&message) {
161            Ok(unmarshalled) => match unmarshalled {
162                UnmarshalledMessage::Event(event) => Ok(Some(event)),
163                UnmarshalledMessage::Error(err) => {
164                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165                }
166            },
167            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168        }
169    }
170
171    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172        use http_body_util::BodyExt;
173
174        if !self.buffer.is_eos() {
175            let next_chunk = self
176                .body
177                .frame()
178                .await
179                .transpose()
180                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182            if let Some(chunk) = next_chunk {
183                // Ignoring the possibility of trailers here since event_streams don't have them
184                if let Ok(data) = chunk.into_data() {
185                    self.buffer = buffer.with_partial(data);
186                }
187            } else {
188                self.buffer = buffer.ended();
189            }
190        }
191        Ok(())
192    }
193
194    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
195        while !self.buffer.is_eos() {
196            if self.buffer.has_data() {
197                if let DecodedFrame::Complete(message) = self
198                    .decoder
199                    .decode_frame(self.buffer.buffered())
200                    .map_err(|err| {
201                        SdkError::response_error(
202                            err,
203                            // the buffer has been consumed
204                            RawMessage::Invalid(None),
205                        )
206                    })?
207                {
208                    trace!(message = ?message, "received complete event stream message");
209                    return Ok(Some(message));
210                }
211            }
212
213            self.buffer_next_chunk().await?;
214        }
215        if self.buffer.has_data() {
216            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
217            let buf = self.buffer.buffered();
218            return Err(SdkError::response_error(
219                ReceiverError {
220                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
221                },
222                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
223            ));
224        }
225        Ok(None)
226    }
227
228    /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
229    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
230    #[doc(hidden)]
231    pub async fn try_recv_initial(
232        &mut self,
233        message_type: InitialMessageType,
234    ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
235        self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
236            .await
237            .map(|opt| opt.map(|(msg, _)| msg))
238    }
239
240    /// Tries to receive the initial response message with preprocessing support.
241    ///
242    /// The preprocessor function can transform the raw message (e.g., unwrap envelopes)
243    /// and return metadata alongside the transformed message. If the transformed message
244    /// matches the expected `message_type`, both the message and metadata are returned.
245    /// Otherwise, the transformed message is buffered and `Ok(None)` is returned.
246    #[doc(hidden)]
247    pub async fn try_recv_initial_with_preprocessor<F, M>(
248        &mut self,
249        message_type: InitialMessageType,
250        preprocessor: F,
251    ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
252    where
253        F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
254    {
255        if let Some(message) = self.next_message().await? {
256            let (processed_message, metadata) =
257                preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
258
259            if let Some(event_type) = processed_message
260                .headers()
261                .iter()
262                .find(|h| h.name().as_str() == ":event-type")
263            {
264                if event_type
265                    .value()
266                    .as_string()
267                    .map(|s| s.as_str() == message_type.as_str())
268                    .unwrap_or(false)
269                {
270                    return Ok(Some((processed_message, metadata)));
271                }
272            }
273            // Buffer the processed message so that it can be returned by the next call to `recv()`
274            self.buffered_message = Some(message);
275        }
276        Ok(None)
277    }
278
279    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
280    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
281    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
282    /// messages.
283    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
284        if let Some(buffered) = self.buffered_message.take() {
285            return match self.unmarshall(buffered) {
286                Ok(message) => Ok(message),
287                Err(error) => {
288                    self.buffer = RecvBuf::Terminated;
289                    Err(error)
290                }
291            };
292        }
293        if let Some(message) = self.next_message().await? {
294            match self.unmarshall(message) {
295                Ok(message) => Ok(message),
296                Err(error) => {
297                    self.buffer = RecvBuf::Terminated;
298                    Err(error)
299                }
300            }
301        } else {
302            Ok(None)
303        }
304    }
305}
306
307#[cfg(test)]
308mod tests {
309    use super::{InitialMessageType, Receiver, UnmarshallMessage};
310    use aws_smithy_eventstream::error::Error as EventStreamError;
311    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
312    use aws_smithy_runtime_api::client::result::SdkError;
313    use aws_smithy_types::body::SdkBody;
314    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
315    use bytes::Bytes;
316    use http_body_1x::Frame;
317    use std::error::Error as StdError;
318    use std::io::{Error as IOError, ErrorKind};
319
320    fn encode_initial_response() -> Bytes {
321        let mut buffer = Vec::new();
322        let message = Message::new(Bytes::new())
323            .add_header(Header::new(
324                ":message-type",
325                HeaderValue::String("event".into()),
326            ))
327            .add_header(Header::new(
328                ":event-type",
329                HeaderValue::String("initial-response".into()),
330            ));
331        write_message_to(&message, &mut buffer).unwrap();
332        buffer.into()
333    }
334
335    fn encode_message(message: &str) -> Bytes {
336        let mut buffer = Vec::new();
337        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
338        write_message_to(&message, &mut buffer).unwrap();
339        buffer.into()
340    }
341
342    fn map_to_frame(stream: Vec<Result<Bytes, IOError>>) -> Vec<Result<Frame<Bytes>, IOError>> {
343        stream
344            .into_iter()
345            .map(|chunk| chunk.map(Frame::data))
346            .collect()
347    }
348
349    #[derive(Debug)]
350    struct FakeError;
351    impl std::fmt::Display for FakeError {
352        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353            write!(f, "FakeError")
354        }
355    }
356    impl StdError for FakeError {}
357
358    #[derive(Debug, Eq, PartialEq)]
359    struct TestMessage(String);
360
361    #[derive(Debug)]
362    struct Unmarshaller;
363    impl UnmarshallMessage for Unmarshaller {
364        type Output = TestMessage;
365        type Error = EventStreamError;
366
367        fn unmarshall(
368            &self,
369            message: &Message,
370        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
371            Ok(UnmarshalledMessage::Event(TestMessage(
372                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
373            )))
374        }
375    }
376
377    #[tokio::test]
378    async fn receive_success() {
379        let chunks: Vec<Result<_, IOError>> =
380            map_to_frame(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
381        let chunk_stream = futures_util::stream::iter(chunks);
382        let stream_body = http_body_util::StreamBody::new(chunk_stream);
383        let body = SdkBody::from_body_1_x(stream_body);
384
385        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
386
387        assert_eq!(
388            TestMessage("one".into()),
389            receiver.recv().await.unwrap().unwrap()
390        );
391        assert_eq!(
392            TestMessage("two".into()),
393            receiver.recv().await.unwrap().unwrap()
394        );
395        assert_eq!(None, receiver.recv().await.unwrap());
396    }
397
398    #[tokio::test]
399    async fn receive_last_chunk_empty() {
400        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
401            Ok(encode_message("one")),
402            Ok(encode_message("two")),
403            Ok(Bytes::from_static(&[])),
404        ]);
405        let chunk_stream = futures_util::stream::iter(chunks);
406        let stream_body = http_body_util::StreamBody::new(chunk_stream);
407        let body = SdkBody::from_body_1_x(stream_body);
408        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
409        assert_eq!(
410            TestMessage("one".into()),
411            receiver.recv().await.unwrap().unwrap()
412        );
413        assert_eq!(
414            TestMessage("two".into()),
415            receiver.recv().await.unwrap().unwrap()
416        );
417        assert_eq!(None, receiver.recv().await.unwrap());
418    }
419
420    #[tokio::test]
421    async fn receive_last_chunk_not_full_message() {
422        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
423            Ok(encode_message("one")),
424            Ok(encode_message("two")),
425            Ok(encode_message("three").split_to(10)),
426        ]);
427        let chunk_stream = futures_util::stream::iter(chunks);
428        let stream_body = http_body_util::StreamBody::new(chunk_stream);
429        let body = SdkBody::from_body_1_x(stream_body);
430        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
431        assert_eq!(
432            TestMessage("one".into()),
433            receiver.recv().await.unwrap().unwrap()
434        );
435        assert_eq!(
436            TestMessage("two".into()),
437            receiver.recv().await.unwrap().unwrap()
438        );
439        assert!(matches!(
440            receiver.recv().await,
441            Err(SdkError::ResponseError { .. }),
442        ));
443    }
444
445    #[tokio::test]
446    async fn receive_last_chunk_has_multiple_messages() {
447        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
448            Ok(encode_message("one")),
449            Ok(encode_message("two")),
450            Ok(Bytes::from(
451                [encode_message("three"), encode_message("four")].concat(),
452            )),
453        ]);
454        let chunk_stream = futures_util::stream::iter(chunks);
455        let stream_body = http_body_util::StreamBody::new(chunk_stream);
456        let body = SdkBody::from_body_1_x(stream_body);
457        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
458        assert_eq!(
459            TestMessage("one".into()),
460            receiver.recv().await.unwrap().unwrap()
461        );
462        assert_eq!(
463            TestMessage("two".into()),
464            receiver.recv().await.unwrap().unwrap()
465        );
466        assert_eq!(
467            TestMessage("three".into()),
468            receiver.recv().await.unwrap().unwrap()
469        );
470        assert_eq!(
471            TestMessage("four".into()),
472            receiver.recv().await.unwrap().unwrap()
473        );
474        assert_eq!(None, receiver.recv().await.unwrap());
475    }
476
477    proptest::proptest! {
478        #[test]
479        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
480            let combined = Bytes::from([
481                encode_message("one"),
482                encode_message("two"),
483                encode_message("three"),
484                encode_message("four"),
485                encode_message("five"),
486                encode_message("six"),
487                encode_message("seven"),
488                encode_message("eight"),
489            ].concat());
490
491            let midpoint = combined.len() / 2;
492            let (start, boundary1, boundary2, end) = (
493                0,
494                b1 % midpoint,
495                midpoint + b2 % midpoint,
496                combined.len()
497            );
498            println!("[{start}, {boundary1}], [{boundary1}, {boundary2}], [{boundary2}, {end}]");
499
500            let rt = tokio::runtime::Runtime::new().unwrap();
501            rt.block_on(async move {
502                let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
503                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
504                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
505                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
506                ]);
507
508                let chunk_stream = futures_util::stream::iter(chunks);
509                let stream_body = http_body_util::StreamBody::new(chunk_stream);
510                let body = SdkBody::from_body_1_x(stream_body);
511                let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
512                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
513                    assert_eq!(
514                        TestMessage((*payload).into()),
515                        receiver.recv().await.unwrap().unwrap()
516                    );
517                }
518                assert_eq!(None, receiver.recv().await.unwrap());
519            });
520        }
521    }
522
523    #[tokio::test]
524    async fn receive_network_failure() {
525        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
526            Ok(encode_message("one")),
527            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
528        ]);
529        let chunk_stream = futures_util::stream::iter(chunks);
530        let stream_body = http_body_util::StreamBody::new(chunk_stream);
531        let body = SdkBody::from_body_1_x(stream_body);
532        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
533        assert_eq!(
534            TestMessage("one".into()),
535            receiver.recv().await.unwrap().unwrap()
536        );
537        assert!(matches!(
538            receiver.recv().await,
539            Err(SdkError::DispatchFailure(_))
540        ));
541    }
542
543    #[tokio::test]
544    async fn receive_message_parse_failure() {
545        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
546            Ok(encode_message("one")),
547            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
548            // for the MessageFrameDecoder to actually start parsing it.
549            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
550        ]);
551        let chunk_stream = futures_util::stream::iter(chunks);
552        let stream_body = http_body_util::StreamBody::new(chunk_stream);
553        let body = SdkBody::from_body_1_x(stream_body);
554        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
555        assert_eq!(
556            TestMessage("one".into()),
557            receiver.recv().await.unwrap().unwrap()
558        );
559        assert!(matches!(
560            receiver.recv().await,
561            Err(SdkError::ResponseError { .. })
562        ));
563    }
564
565    #[tokio::test]
566    async fn receive_initial_response() {
567        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
568            Ok(encode_initial_response()),
569            Ok(encode_message("one")),
570        ]);
571        let chunk_stream = futures_util::stream::iter(chunks);
572        let stream_body = http_body_util::StreamBody::new(chunk_stream);
573        let body = SdkBody::from_body_1_x(stream_body);
574        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
575        assert!(receiver
576            .try_recv_initial(InitialMessageType::Response)
577            .await
578            .unwrap()
579            .is_some());
580        assert_eq!(
581            TestMessage("one".into()),
582            receiver.recv().await.unwrap().unwrap()
583        );
584    }
585
586    #[tokio::test]
587    async fn receive_no_initial_response() {
588        let chunks: Vec<Result<_, IOError>> =
589            map_to_frame(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
590        let chunk_stream = futures_util::stream::iter(chunks);
591        let stream_body = http_body_util::StreamBody::new(chunk_stream);
592
593        let body = SdkBody::from_body_1_x(stream_body);
594        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
595        assert!(receiver
596            .try_recv_initial(InitialMessageType::Response)
597            .await
598            .unwrap()
599            .is_none());
600        assert_eq!(
601            TestMessage("one".into()),
602            receiver.recv().await.unwrap().unwrap()
603        );
604        assert_eq!(
605            TestMessage("two".into()),
606            receiver.recv().await.unwrap().unwrap()
607        );
608    }
609
610    fn assert_send_and_sync<T: Send + Sync>() {}
611
612    #[tokio::test]
613    async fn receiver_is_send_and_sync() {
614        assert_send_and_sync::<Receiver<(), ()>>();
615    }
616}