aws_smithy_http/event_stream/
receiver.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{
7    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21/// Wrapper around SegmentedBuf that tracks the state of the stream.
22#[derive(Debug)]
23enum RecvBuf {
24    /// Nothing has been buffered yet.
25    Empty,
26    /// Some data has been buffered.
27    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
28    Partial(SegmentedBuf<Bytes>),
29    /// The end of the stream has been reached, but there may still be some buffered data.
30    EosPartial(SegmentedBuf<Bytes>),
31    /// An exception terminated this stream.
32    Terminated,
33}
34
35impl RecvBuf {
36    /// Returns true if there's more buffered data.
37    fn has_data(&self) -> bool {
38        match self {
39            RecvBuf::Empty | RecvBuf::Terminated => false,
40            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41        }
42    }
43
44    /// Returns true if the stream has ended.
45    fn is_eos(&self) -> bool {
46        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47    }
48
49    /// Returns a mutable reference to the underlying buffered data.
50    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51        match self {
52            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53            RecvBuf::Partial(segmented) => segmented,
54            RecvBuf::EosPartial(segmented) => segmented,
55            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56        }
57    }
58
59    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
60    /// if the `RecvBuf` was previously empty.
61    fn with_partial(self, partial: Bytes) -> Self {
62        match self {
63            RecvBuf::Empty => {
64                let mut segmented = SegmentedBuf::new();
65                segmented.push(partial);
66                RecvBuf::Partial(segmented)
67            }
68            RecvBuf::Partial(mut segmented) => {
69                segmented.push(partial);
70                RecvBuf::Partial(segmented)
71            }
72            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74            }
75        }
76    }
77
78    /// Returns a `RecvBuf` that has reached end of stream.
79    fn ended(self) -> Self {
80        match self {
81            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85        }
86    }
87}
88
/// The distinct failure conditions a `ReceiverError` can describe.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The stream ended while bytes of an incomplete message frame were still buffered.
    UnexpectedEndOfStream,
}

/// An error raised by the event stream receiver itself (as opposed to the transport or the
/// message unmarshaller).
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep the kind -> text mapping in one exhaustive match so a new kind cannot be
        // added without also giving it a message.
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

// No underlying cause is stored, so the default `source()` implementation is used.
impl StdError for ReceiverError {}
110
111/// Receives Smithy-modeled messages out of an Event Stream.
112#[derive(Debug)]
113pub struct Receiver<T, E> {
114    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115    decoder: MessageFrameDecoder,
116    buffer: RecvBuf,
117    body: SdkBody,
118    /// Event Stream has optional initial response frames an with `:message-type` of
119    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an
120    /// initial response, then the message will be stored in `buffered_message` so that it can
121    /// be returned with the next call of `recv()`.
122    buffered_message: Option<Message>,
123    _phantom: PhantomData<E>,
124}
125
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
/// Selects which kind of initial message `Receiver::try_recv_initial` should look for.
#[doc(hidden)]
#[non_exhaustive]
#[derive(Debug)]
pub enum InitialMessageType {
    /// An initial message whose `:event-type` is `initial-request`.
    Request,
    /// An initial message whose `:event-type` is `initial-response`.
    Response,
}

impl InitialMessageType {
    /// Returns the `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match self {
            InitialMessageType::Request => "initial-request",
            InitialMessageType::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
145    pub fn new(
146        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147        body: SdkBody,
148    ) -> Self {
149        Receiver {
150            unmarshaller: Box::new(unmarshaller),
151            decoder: MessageFrameDecoder::new(),
152            buffer: RecvBuf::Empty,
153            body,
154            buffered_message: None,
155            _phantom: Default::default(),
156        }
157    }
158
159    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160        match self.unmarshaller.unmarshall(&message) {
161            Ok(unmarshalled) => match unmarshalled {
162                UnmarshalledMessage::Event(event) => Ok(Some(event)),
163                UnmarshalledMessage::Error(err) => {
164                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165                }
166            },
167            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168        }
169    }
170
171    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172        use http_body_util::BodyExt;
173
174        if !self.buffer.is_eos() {
175            let next_chunk = self
176                .body
177                .frame()
178                .await
179                .transpose()
180                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182            if let Some(chunk) = next_chunk {
183                // Ignoring the possibility of trailers here since event_streams don't have them
184                if let Ok(data) = chunk.into_data() {
185                    self.buffer = buffer.with_partial(data);
186                }
187            } else {
188                self.buffer = buffer.ended();
189            }
190        }
191        Ok(())
192    }
193
194    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
195        while !self.buffer.is_eos() {
196            if self.buffer.has_data() {
197                if let DecodedFrame::Complete(message) = self
198                    .decoder
199                    .decode_frame(self.buffer.buffered())
200                    .map_err(|err| {
201                        SdkError::response_error(
202                            err,
203                            // the buffer has been consumed
204                            RawMessage::Invalid(None),
205                        )
206                    })?
207                {
208                    trace!(message = ?message, "received complete event stream message");
209                    return Ok(Some(message));
210                }
211            }
212
213            self.buffer_next_chunk().await?;
214        }
215        if self.buffer.has_data() {
216            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
217            let buf = self.buffer.buffered();
218            return Err(SdkError::response_error(
219                ReceiverError {
220                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
221                },
222                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
223            ));
224        }
225        Ok(None)
226    }
227
228    /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
229    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
230    #[doc(hidden)]
231    pub async fn try_recv_initial(
232        &mut self,
233        message_type: InitialMessageType,
234    ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
235        if let Some(message) = self.next_message().await? {
236            if let Some(event_type) = message
237                .headers()
238                .iter()
239                .find(|h| h.name().as_str() == ":event-type")
240            {
241                if event_type
242                    .value()
243                    .as_string()
244                    .map(|s| s.as_str() == message_type.as_str())
245                    .unwrap_or(false)
246                {
247                    return Ok(Some(message));
248                }
249            }
250            // Buffer the message so that it can be returned by the next call to `recv()`
251            self.buffered_message = Some(message);
252        }
253        Ok(None)
254    }
255
256    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
257    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
258    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
259    /// messages.
260    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
261        if let Some(buffered) = self.buffered_message.take() {
262            return match self.unmarshall(buffered) {
263                Ok(message) => Ok(message),
264                Err(error) => {
265                    self.buffer = RecvBuf::Terminated;
266                    Err(error)
267                }
268            };
269        }
270        if let Some(message) = self.next_message().await? {
271            match self.unmarshall(message) {
272                Ok(message) => Ok(message),
273                Err(error) => {
274                    self.buffer = RecvBuf::Terminated;
275                    Err(error)
276                }
277            }
278        } else {
279            Ok(None)
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use http_body_1x::Frame;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// The receiver type exercised by the tests in this module.
    type TestReceiver = Receiver<TestMessage, EventStreamError>;

    /// Encodes an empty-payload event whose `:event-type` is `initial-response`.
    fn encode_initial_response() -> Bytes {
        let message_type = Header::new(":message-type", HeaderValue::String("event".into()));
        let event_type = Header::new(
            ":event-type",
            HeaderValue::String("initial-response".into()),
        );
        let message = Message::new(Bytes::new())
            .add_header(message_type)
            .add_header(event_type);

        let mut encoded = Vec::new();
        write_message_to(&message, &mut encoded).unwrap();
        encoded.into()
    }

    /// Encodes a header-less message whose payload is the given string.
    fn encode_message(message: &str) -> Bytes {
        let payload = Bytes::copy_from_slice(message.as_bytes());
        let mut encoded = Vec::new();
        write_message_to(&Message::new(payload), &mut encoded).unwrap();
        encoded.into()
    }

    /// Wraps every successful chunk in a data frame, leaving errors as they are.
    fn map_to_frame(stream: Vec<Result<Bytes, IOError>>) -> Vec<Result<Frame<Bytes>, IOError>> {
        let mut frames = Vec::with_capacity(stream.len());
        for chunk in stream {
            frames.push(chunk.map(Frame::data));
        }
        frames
    }

    /// Builds a receiver whose body yields the given chunks (or I/O errors) in order.
    fn receiver_over(chunks: Vec<Result<Bytes, IOError>>) -> TestReceiver {
        let frames = map_to_frame(chunks);
        let chunk_stream = futures_util::stream::iter(frames);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body)
    }

    /// Asserts that the next received messages carry exactly `payloads`, in order.
    async fn assert_receives(receiver: &mut TestReceiver, payloads: &[&str]) {
        for payload in payloads {
            assert_eq!(
                TestMessage((*payload).into()),
                receiver.recv().await.unwrap().unwrap()
            );
        }
    }

    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshaller that turns every message into a `TestMessage` holding its UTF-8 payload.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let payload = std::str::from_utf8(&message.payload()[..]).unwrap();
            Ok(UnmarshalledMessage::Event(TestMessage(payload.into())))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
        ]);

        assert_receives(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);

        assert_receives(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);

        assert_receives(&mut receiver, &["one", "two"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ]);

        assert_receives(&mut receiver, &["one", "two", "three", "four"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let combined = Bytes::from([
                encode_message("one"),
                encode_message("two"),
                encode_message("three"),
                encode_message("four"),
                encode_message("five"),
                encode_message("six"),
                encode_message("seven"),
                encode_message("eight"),
            ].concat());

            // Pick one split point in each half so the three chunks are always in order.
            let midpoint = combined.len() / 2;
            let start = 0;
            let boundary1 = b1 % midpoint;
            let boundary2 = midpoint + b2 % midpoint;
            let end = combined.len();
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let mut receiver = receiver_over(vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ]);

                assert_receives(
                    &mut receiver,
                    &["one", "two", "three", "four", "five", "six", "seven", "eight"],
                )
                .await;
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);

        assert_receives(&mut receiver, &["one"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
            // for the MessageFrameDecoder to actually start parsing it.
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ]);

        assert_receives(&mut receiver, &["one"]).await;
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver = receiver_over(vec![
            Ok(encode_initial_response()),
            Ok(encode_message("one")),
        ]);

        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_some());
        assert_receives(&mut receiver, &["one"]).await;
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
        ]);

        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_none());
        // The non-initial message read above must still be delivered by `recv()`.
        assert_receives(&mut receiver, &["one", "two"]).await;
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}