aws_smithy_http/event_stream/
receiver.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_eventstream::frame::{
7    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21/// Wrapper around SegmentedBuf that tracks the state of the stream.
22#[derive(Debug)]
23enum RecvBuf {
24    /// Nothing has been buffered yet.
25    Empty,
26    /// Some data has been buffered.
27    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
28    Partial(SegmentedBuf<Bytes>),
29    /// The end of the stream has been reached, but there may still be some buffered data.
30    EosPartial(SegmentedBuf<Bytes>),
31    /// An exception terminated this stream.
32    Terminated,
33}
34
35impl RecvBuf {
36    /// Returns true if there's more buffered data.
37    fn has_data(&self) -> bool {
38        match self {
39            RecvBuf::Empty | RecvBuf::Terminated => false,
40            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41        }
42    }
43
44    /// Returns true if the stream has ended.
45    fn is_eos(&self) -> bool {
46        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47    }
48
49    /// Returns a mutable reference to the underlying buffered data.
50    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51        match self {
52            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53            RecvBuf::Partial(segmented) => segmented,
54            RecvBuf::EosPartial(segmented) => segmented,
55            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56        }
57    }
58
59    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
60    /// if the `RecvBuf` was previously empty.
61    fn with_partial(self, partial: Bytes) -> Self {
62        match self {
63            RecvBuf::Empty => {
64                let mut segmented = SegmentedBuf::new();
65                segmented.push(partial);
66                RecvBuf::Partial(segmented)
67            }
68            RecvBuf::Partial(mut segmented) => {
69                segmented.push(partial);
70                RecvBuf::Partial(segmented)
71            }
72            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74            }
75        }
76    }
77
78    /// Returns a `RecvBuf` that has reached end of stream.
79    fn ended(self) -> Self {
80        match self {
81            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85        }
86    }
87}
88
/// The specific reason a `ReceiverError` was raised.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while bytes of an incomplete message frame were still buffered.
    UnexpectedEndOfStream,
}

/// An error that occurs within an event stream receiver.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111/// Receives Smithy-modeled messages out of an Event Stream.
112#[derive(Debug)]
113pub struct Receiver<T, E> {
114    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115    decoder: MessageFrameDecoder,
116    buffer: RecvBuf,
117    body: SdkBody,
118    /// Event Stream has optional initial response frames an with `:message-type` of
119    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an
120    /// initial response, then the message will be stored in `buffered_message` so that it can
121    /// be returned with the next call of `recv()`.
122    buffered_message: Option<Message>,
123    _phantom: PhantomData<E>,
124}
125
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
/// The kind of optional initial message that may start an event stream.
///
/// Selects the `:event-type` header value that `Receiver::try_recv_initial` looks for.
#[doc(hidden)]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialMessageType {
    /// A message whose `:event-type` is `initial-request`.
    Request,
    /// A message whose `:event-type` is `initial-response`.
    Response,
}

impl InitialMessageType {
    /// Returns the `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match self {
            InitialMessageType::Request => "initial-request",
            InitialMessageType::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
145    pub fn new(
146        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147        body: SdkBody,
148    ) -> Self {
149        Receiver {
150            unmarshaller: Box::new(unmarshaller),
151            decoder: MessageFrameDecoder::new(),
152            buffer: RecvBuf::Empty,
153            body,
154            buffered_message: None,
155            _phantom: Default::default(),
156        }
157    }
158
159    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160        match self.unmarshaller.unmarshall(&message) {
161            Ok(unmarshalled) => match unmarshalled {
162                UnmarshalledMessage::Event(event) => Ok(Some(event)),
163                UnmarshalledMessage::Error(err) => {
164                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165                }
166            },
167            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168        }
169    }
170
171    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172        use http_body_04x::Body;
173
174        if !self.buffer.is_eos() {
175            let next_chunk = self
176                .body
177                .data()
178                .await
179                .transpose()
180                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182            if let Some(chunk) = next_chunk {
183                self.buffer = buffer.with_partial(chunk);
184            } else {
185                self.buffer = buffer.ended();
186            }
187        }
188        Ok(())
189    }
190
191    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
192        while !self.buffer.is_eos() {
193            if self.buffer.has_data() {
194                if let DecodedFrame::Complete(message) = self
195                    .decoder
196                    .decode_frame(self.buffer.buffered())
197                    .map_err(|err| {
198                        SdkError::response_error(
199                            err,
200                            // the buffer has been consumed
201                            RawMessage::Invalid(None),
202                        )
203                    })?
204                {
205                    trace!(message = ?message, "received complete event stream message");
206                    return Ok(Some(message));
207                }
208            }
209
210            self.buffer_next_chunk().await?;
211        }
212        if self.buffer.has_data() {
213            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
214            let buf = self.buffer.buffered();
215            return Err(SdkError::response_error(
216                ReceiverError {
217                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
218                },
219                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
220            ));
221        }
222        Ok(None)
223    }
224
225    /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
226    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
227    #[doc(hidden)]
228    pub async fn try_recv_initial(
229        &mut self,
230        message_type: InitialMessageType,
231    ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
232        self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
233            .await
234            .map(|opt| opt.map(|(msg, _)| msg))
235    }
236
237    /// Tries to receive the initial response message with preprocessing support.
238    ///
239    /// The preprocessor function can transform the raw message (e.g., unwrap envelopes)
240    /// and return metadata alongside the transformed message. If the transformed message
241    /// matches the expected `message_type`, both the message and metadata are returned.
242    /// Otherwise, the transformed message is buffered and `Ok(None)` is returned.
243    #[doc(hidden)]
244    pub async fn try_recv_initial_with_preprocessor<F, M>(
245        &mut self,
246        message_type: InitialMessageType,
247        preprocessor: F,
248    ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
249    where
250        F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
251    {
252        if let Some(message) = self.next_message().await? {
253            let (processed_message, metadata) =
254                preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
255
256            if let Some(event_type) = processed_message
257                .headers()
258                .iter()
259                .find(|h| h.name().as_str() == ":event-type")
260            {
261                if event_type
262                    .value()
263                    .as_string()
264                    .map(|s| s.as_str() == message_type.as_str())
265                    .unwrap_or(false)
266                {
267                    return Ok(Some((processed_message, metadata)));
268                }
269            }
270            // Buffer the processed message so that it can be returned by the next call to `recv()`
271            self.buffered_message = Some(message);
272        }
273        Ok(None)
274    }
275
276    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
277    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
278    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
279    /// messages.
280    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
281        if let Some(buffered) = self.buffered_message.take() {
282            return match self.unmarshall(buffered) {
283                Ok(message) => Ok(message),
284                Err(error) => {
285                    self.buffer = RecvBuf::Terminated;
286                    Err(error)
287                }
288            };
289        }
290        if let Some(message) = self.next_message().await? {
291            match self.unmarshall(message) {
292                Ok(message) => Ok(message),
293                Err(error) => {
294                    self.buffer = RecvBuf::Terminated;
295                    Err(error)
296                }
297            }
298        } else {
299            Ok(None)
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// Encodes an `initial-response` event with an empty payload.
    fn encode_initial_response() -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::new())
            .add_header(Header::new(
                ":message-type",
                HeaderValue::String("event".into()),
            ))
            .add_header(Header::new(
                ":event-type",
                HeaderValue::String("initial-response".into()),
            ));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    /// Encodes a header-less message whose payload is `message`.
    fn encode_message(message: &str) -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshalls every message into a `TestMessage` holding its UTF-8 payload.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            Ok(UnmarshalledMessage::Event(TestMessage(
                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
            )))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("three".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("four".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let combined = Bytes::from([
                encode_message("one"),
                encode_message("two"),
                encode_message("three"),
                encode_message("four"),
                encode_message("five"),
                encode_message("six"),
                encode_message("seven"),
                encode_message("eight"),
            ].concat());

            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) = (
                0,
                b1 % midpoint,
                midpoint + b2 % midpoint,
                combined.len()
            );
            println!("[{start}, {boundary1}], [{boundary1}, {boundary2}], [{boundary2}, {end}]");

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let chunks: Vec<Result<_, IOError>> = vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ];

                let chunk_stream = futures_util::stream::iter(chunks);
                let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
                let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
                    assert_eq!(
                        TestMessage((*payload).into()),
                        receiver.recv().await.unwrap().unwrap()
                    );
                }
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let chunks: Vec<Result<_, IOError>> = vec![
            Ok(encode_message("one")),
            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
            // for the MessageFrameDecoder to actually start parsing it.
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_initial_response()), Ok(encode_message("one"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_some());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_none());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn receive_initial_response_with_preprocessor() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("envelope")), Ok(encode_message("one"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        // The preprocessor "unwraps" the envelope into an initial response and reports metadata.
        let (message, metadata) = receiver
            .try_recv_initial_with_preprocessor(InitialMessageType::Response, |_envelope| {
                let unwrapped = Message::new(Bytes::new()).add_header(Header::new(
                    ":event-type",
                    HeaderValue::String("initial-response".into()),
                ));
                Ok((unwrapped, 42u32))
            })
            .await
            .unwrap()
            .unwrap();
        assert!(message.payload().is_empty());
        assert_eq!(42, metadata);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn preprocessor_mismatch_buffers_original_message() {
        let chunks: Vec<Result<_, IOError>> =
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        // The processed message has a different payload and no `:event-type`, so it cannot
        // match; `recv()` must then yield the original payload, not the processed one.
        let result = receiver
            .try_recv_initial_with_preprocessor(InitialMessageType::Response, |_original| {
                Ok((Message::new(Bytes::from_static(b"processed")), ()))
            })
            .await
            .unwrap();
        assert!(result.is_none());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    // Compile-time check only; no async runtime is needed.
    #[test]
    fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}