AWS SDK

AWS SDK

rev. 742c376d0200593c4954dcd8fb54e3d41691067a (ignoring whitespace)

Files changed:

tmp-codegen-diff/aws-sdk/sdk/aws-smithy-legacy-http/src/event_stream/receiver.rs

@@ -1,0 +569,0 @@
    1         -
/*
    2         -
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    3         -
 * SPDX-License-Identifier: Apache-2.0
    4         -
 */
    5         -
    6         -
use aws_smithy_eventstream::frame::{
    7         -
    DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
    8         -
};
    9         -
use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
   10         -
use aws_smithy_types::body::SdkBody;
   11         -
use aws_smithy_types::event_stream::{Message, RawMessage};
   12         -
use bytes::Buf;
   13         -
use bytes::Bytes;
   14         -
use bytes_utils::SegmentedBuf;
   15         -
use std::error::Error as StdError;
   16         -
use std::fmt;
   17         -
use std::marker::PhantomData;
   18         -
use std::mem;
   19         -
use tracing::trace;
   20         -
   21         -
/// Wrapper around SegmentedBuf that tracks the state of the stream.
   22         -
#[derive(Debug)]
   23         -
enum RecvBuf {
   24         -
    /// Nothing has been buffered yet.
   25         -
    Empty,
   26         -
    /// Some data has been buffered.
   27         -
    /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
   28         -
    Partial(SegmentedBuf<Bytes>),
   29         -
    /// The end of the stream has been reached, but there may still be some buffered data.
   30         -
    EosPartial(SegmentedBuf<Bytes>),
   31         -
    /// An exception terminated this stream.
   32         -
    Terminated,
   33         -
}
   34         -
   35         -
impl RecvBuf {
   36         -
    /// Returns true if there's more buffered data.
   37         -
    fn has_data(&self) -> bool {
   38         -
        match self {
   39         -
            RecvBuf::Empty | RecvBuf::Terminated => false,
   40         -
            RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
   41         -
        }
   42         -
    }
   43         -
   44         -
    /// Returns true if the stream has ended.
   45         -
    fn is_eos(&self) -> bool {
   46         -
        matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
   47         -
    }
   48         -
   49         -
    /// Returns a mutable reference to the underlying buffered data.
   50         -
    fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
   51         -
        match self {
   52         -
            RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
   53         -
            RecvBuf::Partial(segmented) => segmented,
   54         -
            RecvBuf::EosPartial(segmented) => segmented,
   55         -
            RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
   56         -
        }
   57         -
    }
   58         -
   59         -
    /// Returns a new `RecvBuf` with additional data buffered. This will only allocate
   60         -
    /// if the `RecvBuf` was previously empty.
   61         -
    fn with_partial(self, partial: Bytes) -> Self {
   62         -
        match self {
   63         -
            RecvBuf::Empty => {
   64         -
                let mut segmented = SegmentedBuf::new();
   65         -
                segmented.push(partial);
   66         -
                RecvBuf::Partial(segmented)
   67         -
            }
   68         -
            RecvBuf::Partial(mut segmented) => {
   69         -
                segmented.push(partial);
   70         -
                RecvBuf::Partial(segmented)
   71         -
            }
   72         -
            RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
   73         -
                panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
   74         -
            }
   75         -
        }
   76         -
    }
   77         -
   78         -
    /// Returns a `RecvBuf` that has reached end of stream.
   79         -
    fn ended(self) -> Self {
   80         -
        match self {
   81         -
            RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
   82         -
            RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
   83         -
            RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
   84         -
            RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
   85         -
        }
   86         -
    }
   87         -
}
   88         -
   89         -
#[derive(Debug)]
   90         -
enum ReceiverErrorKind {
   91         -
    /// The stream ended before a complete message frame was received.
   92         -
    UnexpectedEndOfStream,
   93         -
}
   94         -
   95         -
/// An error that occurs within an event stream receiver.
   96         -
#[derive(Debug)]
   97         -
pub struct ReceiverError {
   98         -
    kind: ReceiverErrorKind,
   99         -
}
  100         -
  101         -
impl fmt::Display for ReceiverError {
  102         -
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  103         -
        match self.kind {
  104         -
            ReceiverErrorKind::UnexpectedEndOfStream => write!(f, "unexpected end of stream"),
  105         -
        }
  106         -
    }
  107         -
}
  108         -
  109         -
impl StdError for ReceiverError {}
  110         -
  111         -
/// Receives Smithy-modeled messages out of an Event Stream.
  112         -
#[derive(Debug)]
  113         -
pub struct Receiver<T, E> {
  114         -
    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
  115         -
    decoder: MessageFrameDecoder,
  116         -
    buffer: RecvBuf,
  117         -
    body: SdkBody,
  118         -
    /// Event Stream has optional initial response frames with an `:event-type` of
  119         -
    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't an
  120         -
    /// initial response, then the message will be stored in `buffered_message` so that it can
  121         -
    /// be returned with the next call of `recv()`.
  122         -
    buffered_message: Option<Message>,
  123         -
    _phantom: PhantomData<E>,
  124         -
}
  125         -
  126         -
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
  127         -
#[doc(hidden)]
  128         -
#[non_exhaustive]
  129         -
pub enum InitialMessageType {
  130         -
    Request,
  131         -
    Response,
  132         -
}
  133         -
  134         -
impl InitialMessageType {
  135         -
    fn as_str(&self) -> &'static str {
  136         -
        match self {
  137         -
            InitialMessageType::Request => "initial-request",
  138         -
            InitialMessageType::Response => "initial-response",
  139         -
        }
  140         -
    }
  141         -
}
  142         -
  143         -
impl<T, E> Receiver<T, E> {
  144         -
    /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
  145         -
    pub fn new(
  146         -
        unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
  147         -
        body: SdkBody,
  148         -
    ) -> Self {
  149         -
        Receiver {
  150         -
            unmarshaller: Box::new(unmarshaller),
  151         -
            decoder: MessageFrameDecoder::new(),
  152         -
            buffer: RecvBuf::Empty,
  153         -
            body,
  154         -
            buffered_message: None,
  155         -
            _phantom: Default::default(),
  156         -
        }
  157         -
    }
  158         -
  159         -
    fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
  160         -
        match self.unmarshaller.unmarshall(&message) {
  161         -
            Ok(unmarshalled) => match unmarshalled {
  162         -
                UnmarshalledMessage::Event(event) => Ok(Some(event)),
  163         -
                UnmarshalledMessage::Error(err) => {
  164         -
                    Err(SdkError::service_error(err, RawMessage::Decoded(message)))
  165         -
                }
  166         -
            },
  167         -
            Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
  168         -
        }
  169         -
    }
  170         -
  171         -
    async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
  172         -
        use http_body_04x::Body;
  173         -
  174         -
        if !self.buffer.is_eos() {
  175         -
            let next_chunk = self
  176         -
                .body
  177         -
                .data()
  178         -
                .await
  179         -
                .transpose()
  180         -
                .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
  181         -
            let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
  182         -
            if let Some(chunk) = next_chunk {
  183         -
                self.buffer = buffer.with_partial(chunk);
  184         -
            } else {
  185         -
                self.buffer = buffer.ended();
  186         -
            }
  187         -
        }
  188         -
        Ok(())
  189         -
    }
  190         -
  191         -
    async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
  192         -
        while !self.buffer.is_eos() {
  193         -
            if self.buffer.has_data() {
  194         -
                if let DecodedFrame::Complete(message) = self
  195         -
                    .decoder
  196         -
                    .decode_frame(self.buffer.buffered())
  197         -
                    .map_err(|err| {
  198         -
                        SdkError::response_error(
  199         -
                            err,
  200         -
                            // the buffer has been consumed
  201         -
                            RawMessage::Invalid(None),
  202         -
                        )
  203         -
                    })?
  204         -
                {
  205         -
                    trace!(message = ?message, "received complete event stream message");
  206         -
                    return Ok(Some(message));
  207         -
                }
  208         -
            }
  209         -
  210         -
            self.buffer_next_chunk().await?;
  211         -
        }
  212         -
        if self.buffer.has_data() {
  213         -
            trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
  214         -
            let buf = self.buffer.buffered();
  215         -
            return Err(SdkError::response_error(
  216         -
                ReceiverError {
  217         -
                    kind: ReceiverErrorKind::UnexpectedEndOfStream,
  218         -
                },
  219         -
                RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
  220         -
            ));
  221         -
        }
  222         -
        Ok(None)
  223         -
    }
  224         -
  225         -
    /// Tries to receive the initial message whose `:event-type` header matches the given `message_type`.
  226         -
    /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
  227         -
    #[doc(hidden)]
  228         -
    pub async fn try_recv_initial(
  229         -
        &mut self,
  230         -
        message_type: InitialMessageType,
  231         -
    ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
  232         -
        if let Some(message) = self.next_message().await? {
  233         -
            if let Some(event_type) = message
  234         -
                .headers()
  235         -
                .iter()
  236         -
                .find(|h| h.name().as_str() == ":event-type")
  237         -
            {
  238         -
                if event_type
  239         -
                    .value()
  240         -
                    .as_string()
  241         -
                    .map(|s| s.as_str() == message_type.as_str())
  242         -
                    .unwrap_or(false)
  243         -
                {
  244         -
                    return Ok(Some(message));
  245         -
                }
  246         -
            }
  247         -
            // Buffer the message so that it can be returned by the next call to `recv()`
  248         -
            self.buffered_message = Some(message);
  249         -
        }
  250         -
        Ok(None)
  251         -
    }
  252         -
  253         -
    /// Asynchronously tries to receive a message from the stream. If the stream has ended,
  254         -
    /// it returns an `Ok(None)`. If there is a transport layer error, it will return
  255         -
    /// `Err(SdkError::DispatchFailure)`. Service-modeled errors are returned as an `Err` built with
  255         -
    /// `SdkError::service_error`, after which the receiver's buffer is marked as terminated.
  257         -
    pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
  258         -
        if let Some(buffered) = self.buffered_message.take() {
  259         -
            return match self.unmarshall(buffered) {
  260         -
                Ok(message) => Ok(message),
  261         -
                Err(error) => {
  262         -
                    self.buffer = RecvBuf::Terminated;
  263         -
                    Err(error)
  264         -
                }
  265         -
            };
  266         -
        }
  267         -
        if let Some(message) = self.next_message().await? {
  268         -
            match self.unmarshall(message) {
  269         -
                Ok(message) => Ok(message),
  270         -
                Err(error) => {
  271         -
                    self.buffer = RecvBuf::Terminated;
  272         -
                    Err(error)
  273         -
                }
  274         -
            }
  275         -
        } else {
  276         -
            Ok(None)
  277         -
        }
  278         -
    }
  279         -
}
  280         -
  281         -
#[cfg(test)]
  282         -
mod tests {
  283         -
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
  284         -
    use aws_smithy_eventstream::error::Error as EventStreamError;
  285         -
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
  286         -
    use aws_smithy_runtime_api::client::result::SdkError;
  287         -
    use aws_smithy_types::body::SdkBody;
  288         -
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
  289         -
    use bytes::Bytes;
  290         -
    use hyper::body::Body;
  291         -
    use std::error::Error as StdError;
  292         -
    use std::io::{Error as IOError, ErrorKind};
  293         -
  294         -
    fn encode_initial_response() -> Bytes {
  295         -
        let mut buffer = Vec::new();
  296         -
        let message = Message::new(Bytes::new())
  297         -
            .add_header(Header::new(
  298         -
                ":message-type",
  299         -
                HeaderValue::String("event".into()),
  300         -
            ))
  301         -
            .add_header(Header::new(
  302         -
                ":event-type",
  303         -
                HeaderValue::String("initial-response".into()),
  304         -
            ));
  305         -
        write_message_to(&message, &mut buffer).unwrap();
  306         -
        buffer.into()
  307         -
    }
  308         -
  309         -
    fn encode_message(message: &str) -> Bytes {
  310         -
        let mut buffer = Vec::new();
  311         -
        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
  312         -
        write_message_to(&message, &mut buffer).unwrap();
  313         -
        buffer.into()
  314         -
    }
  315         -
  316         -
    #[derive(Debug)]
  317         -
    struct FakeError;
  318         -
    impl std::fmt::Display for FakeError {
  319         -
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  320         -
            write!(f, "FakeError")
  321         -
        }
  322         -
    }
  323         -
    impl StdError for FakeError {}
  324         -
  325         -
    #[derive(Debug, Eq, PartialEq)]
  326         -
    struct TestMessage(String);
  327         -
  328         -
    #[derive(Debug)]
  329         -
    struct Unmarshaller;
  330         -
    impl UnmarshallMessage for Unmarshaller {
  331         -
        type Output = TestMessage;
  332         -
        type Error = EventStreamError;
  333         -
  334         -
        fn unmarshall(
  335         -
            &self,
  336         -
            message: &Message,
  337         -
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
  338         -
            Ok(UnmarshalledMessage::Event(TestMessage(
  339         -
                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
  340         -
            )))
  341         -
        }
  342         -
    }
  343         -
  344         -
    #[tokio::test]
  345         -
    async fn receive_success() {
  346         -
        let chunks: Vec<Result<_, IOError>> =
  347         -
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
  348         -
        let chunk_stream = futures_util::stream::iter(chunks);
  349         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  350         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  351         -
        assert_eq!(
  352         -
            TestMessage("one".into()),
  353         -
            receiver.recv().await.unwrap().unwrap()
  354         -
        );
  355         -
        assert_eq!(
  356         -
            TestMessage("two".into()),
  357         -
            receiver.recv().await.unwrap().unwrap()
  358         -
        );
  359         -
        assert_eq!(None, receiver.recv().await.unwrap());
  360         -
    }
  361         -
  362         -
    #[tokio::test]
  363         -
    async fn receive_last_chunk_empty() {
  364         -
        let chunks: Vec<Result<_, IOError>> = vec![
  365         -
            Ok(encode_message("one")),
  366         -
            Ok(encode_message("two")),
  367         -
            Ok(Bytes::from_static(&[])),
  368         -
        ];
  369         -
        let chunk_stream = futures_util::stream::iter(chunks);
  370         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  371         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  372         -
        assert_eq!(
  373         -
            TestMessage("one".into()),
  374         -
            receiver.recv().await.unwrap().unwrap()
  375         -
        );
  376         -
        assert_eq!(
  377         -
            TestMessage("two".into()),
  378         -
            receiver.recv().await.unwrap().unwrap()
  379         -
        );
  380         -
        assert_eq!(None, receiver.recv().await.unwrap());
  381         -
    }
  382         -
  383         -
    #[tokio::test]
  384         -
    async fn receive_last_chunk_not_full_message() {
  385         -
        let chunks: Vec<Result<_, IOError>> = vec![
  386         -
            Ok(encode_message("one")),
  387         -
            Ok(encode_message("two")),
  388         -
            Ok(encode_message("three").split_to(10)),
  389         -
        ];
  390         -
        let chunk_stream = futures_util::stream::iter(chunks);
  391         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  392         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  393         -
        assert_eq!(
  394         -
            TestMessage("one".into()),
  395         -
            receiver.recv().await.unwrap().unwrap()
  396         -
        );
  397         -
        assert_eq!(
  398         -
            TestMessage("two".into()),
  399         -
            receiver.recv().await.unwrap().unwrap()
  400         -
        );
  401         -
        assert!(matches!(
  402         -
            receiver.recv().await,
  403         -
            Err(SdkError::ResponseError { .. }),
  404         -
        ));
  405         -
    }
  406         -
  407         -
    #[tokio::test]
  408         -
    async fn receive_last_chunk_has_multiple_messages() {
  409         -
        let chunks: Vec<Result<_, IOError>> = vec![
  410         -
            Ok(encode_message("one")),
  411         -
            Ok(encode_message("two")),
  412         -
            Ok(Bytes::from(
  413         -
                [encode_message("three"), encode_message("four")].concat(),
  414         -
            )),
  415         -
        ];
  416         -
        let chunk_stream = futures_util::stream::iter(chunks);
  417         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  418         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  419         -
        assert_eq!(
  420         -
            TestMessage("one".into()),
  421         -
            receiver.recv().await.unwrap().unwrap()
  422         -
        );
  423         -
        assert_eq!(
  424         -
            TestMessage("two".into()),
  425         -
            receiver.recv().await.unwrap().unwrap()
  426         -
        );
  427         -
        assert_eq!(
  428         -
            TestMessage("three".into()),
  429         -
            receiver.recv().await.unwrap().unwrap()
  430         -
        );
  431         -
        assert_eq!(
  432         -
            TestMessage("four".into()),
  433         -
            receiver.recv().await.unwrap().unwrap()
  434         -
        );
  435         -
        assert_eq!(None, receiver.recv().await.unwrap());
  436         -
    }
  437         -
  438         -
    proptest::proptest! {
  439         -
        #[test]
  440         -
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
  441         -
            let combined = Bytes::from([
  442         -
                encode_message("one"),
  443         -
                encode_message("two"),
  444         -
                encode_message("three"),
  445         -
                encode_message("four"),
  446         -
                encode_message("five"),
  447         -
                encode_message("six"),
  448         -
                encode_message("seven"),
  449         -
                encode_message("eight"),
  450         -
            ].concat());
  451         -
  452         -
            let midpoint = combined.len() / 2;
  453         -
            let (start, boundary1, boundary2, end) = (
  454         -
                0,
  455         -
                b1 % midpoint,
  456         -
                midpoint + b2 % midpoint,
  457         -
                combined.len()
  458         -
            );
  459         -
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);
  460         -
  461         -
            let rt = tokio::runtime::Runtime::new().unwrap();
  462         -
            rt.block_on(async move {
  463         -
                let chunks: Vec<Result<_, IOError>> = vec![
  464         -
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
  465         -
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
  466         -
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
  467         -
                ];
  468         -
  469         -
                let chunk_stream = futures_util::stream::iter(chunks);
  470         -
                let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  471         -
                let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  472         -
                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
  473         -
                    assert_eq!(
  474         -
                        TestMessage((*payload).into()),
  475         -
                        receiver.recv().await.unwrap().unwrap()
  476         -
                    );
  477         -
                }
  478         -
                assert_eq!(None, receiver.recv().await.unwrap());
  479         -
            });
  480         -
        }
  481         -
    }
  482         -
  483         -
    #[tokio::test]
  484         -
    async fn receive_network_failure() {
  485         -
        let chunks: Vec<Result<_, IOError>> = vec![
  486         -
            Ok(encode_message("one")),
  487         -
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
  488         -
        ];
  489         -
        let chunk_stream = futures_util::stream::iter(chunks);
  490         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  491         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  492         -
        assert_eq!(
  493         -
            TestMessage("one".into()),
  494         -
            receiver.recv().await.unwrap().unwrap()
  495         -
        );
  496         -
        assert!(matches!(
  497         -
            receiver.recv().await,
  498         -
            Err(SdkError::DispatchFailure(_))
  499         -
        ));
  500         -
    }
  501         -
  502         -
    #[tokio::test]
  503         -
    async fn receive_message_parse_failure() {
  504         -
        let chunks: Vec<Result<_, IOError>> = vec![
  505         -
            Ok(encode_message("one")),
  506         -
            // A zero length message will be invalid. We need to provide a minimum of 12 bytes
  507         -
            // for the MessageFrameDecoder to actually start parsing it.
  508         -
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
  509         -
        ];
  510         -
        let chunk_stream = futures_util::stream::iter(chunks);
  511         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  512         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  513         -
        assert_eq!(
  514         -
            TestMessage("one".into()),
  515         -
            receiver.recv().await.unwrap().unwrap()
  516         -
        );
  517         -
        assert!(matches!(
  518         -
            receiver.recv().await,
  519         -
            Err(SdkError::ResponseError { .. })
  520         -
        ));
  521         -
    }
  522         -
  523         -
    #[tokio::test]
  524         -
    async fn receive_initial_response() {
  525         -
        let chunks: Vec<Result<_, IOError>> =
  526         -
            vec![Ok(encode_initial_response()), Ok(encode_message("one"))];
  527         -
        let chunk_stream = futures_util::stream::iter(chunks);
  528         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  529         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  530         -
        assert!(receiver
  531         -
            .try_recv_initial(InitialMessageType::Response)
  532         -
            .await
  533         -
            .unwrap()
  534         -
            .is_some());
  535         -
        assert_eq!(
  536         -
            TestMessage("one".into()),
  537         -
            receiver.recv().await.unwrap().unwrap()
  538         -
        );
  539         -
    }
  540         -
  541         -
    #[tokio::test]
  542         -
    async fn receive_no_initial_response() {
  543         -
        let chunks: Vec<Result<_, IOError>> =
  544         -
            vec![Ok(encode_message("one")), Ok(encode_message("two"))];
  545         -
        let chunk_stream = futures_util::stream::iter(chunks);
  546         -
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
  547         -
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
  548         -
        assert!(receiver
  549         -
            .try_recv_initial(InitialMessageType::Response)
  550         -
            .await
  551         -
            .unwrap()
  552         -
            .is_none());
  553         -
        assert_eq!(
  554         -
            TestMessage("one".into()),
  555         -
            receiver.recv().await.unwrap().unwrap()
  556         -
        );
  557         -
        assert_eq!(
  558         -
            TestMessage("two".into()),
  559         -
            receiver.recv().await.unwrap().unwrap()
  560         -
        );
  561         -
    }
  562         -
  563         -
    fn assert_send_and_sync<T: Send + Sync>() {}
  564         -
  565         -
    #[tokio::test]
  566         -
    async fn receiver_is_send_and_sync() {
  567         -
        assert_send_and_sync::<Receiver<(), ()>>();
  568         -
    }
  569         -
}

tmp-codegen-diff/aws-sdk/sdk/aws-smithy-legacy-http/src/event_stream/sender.rs

@@ -1,0 +377,0 @@
    1         -
/*
    2         -
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    3         -
 * SPDX-License-Identifier: Apache-2.0
    4         -
 */
    5         -
    6         -
use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
    7         -
use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
    8         -
use aws_smithy_runtime_api::client::result::SdkError;
    9         -
use aws_smithy_types::error::ErrorMetadata;
   10         -
use bytes::Bytes;
   11         -
use futures_core::Stream;
   12         -
use std::error::Error as StdError;
   13         -
use std::fmt;
   14         -
use std::fmt::Debug;
   15         -
use std::marker::PhantomData;
   16         -
use std::pin::Pin;
   17         -
use std::task::{Context, Poll};
   18         -
use tracing::trace;
   19         -
   20         -
/// Input type for Event Streams.
///
/// Owns a pinned, boxed stream of `Result<T, E>` items. It is built from any
/// `Stream + Send + Sync + 'static` via `From`, or from a single item via `once`.
   21         -
pub struct EventStreamSender<T, E> {
   22         -
    // The caller-supplied stream; handed to `MessageStreamAdapter::new` by `into_body_stream`.
    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
   23         -
}
   24         -
   25         -
impl<T, E> Debug for EventStreamSender<T, E> {
   26         -
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
   27         -
        let name_t = std::any::type_name::<T>();
   28         -
        let name_e = std::any::type_name::<E>();
   29         -
        write!(f, "EventStreamSender<{name_t}, {name_e}>")
   30         -
    }
   31         -
}
   32         -
   33         -
impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
   34         -
    /// Creates an `EventStreamSender` from a single item.
   35         -
    pub fn once(item: Result<T, E>) -> Self {
   36         -
        Self::from(futures_util::stream::once(async move { item }))
   37         -
    }
   38         -
}
   39         -
   40         -
impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
   41         -
    #[doc(hidden)]
   42         -
    pub fn into_body_stream(
   43         -
        self,
   44         -
        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
   45         -
        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
   46         -
        signer: impl SignMessage + Send + Sync + 'static,
   47         -
    ) -> MessageStreamAdapter<T, E> {
   48         -
        MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
   49         -
    }
   50         -
}
   51         -
   52         -
impl<T, E, S> From<S> for EventStreamSender<T, E>
   53         -
where
   54         -
    S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
   55         -
{
   56         -
    fn from(stream: S) -> Self {
   57         -
        EventStreamSender {
   58         -
            input_stream: Box::pin(stream),
   59         -
        }
   60         -
    }
   61         -
}
   62         -
   63         -
/// An error that occurs within a message stream.
///
/// Pairs an internal error kind with the [`ErrorMetadata`] exposed through `meta()`.
   64         -
#[derive(Debug)]
   65         -
pub struct MessageStreamError {
   66         -
    // What went wrong; also provides the value returned by `source()`.
    kind: MessageStreamErrorKind,
   67         -
    // Default-constructed by `unhandled`, copied from the argument by `generic`.
    pub(crate) meta: ErrorMetadata,
   68         -
}
   69         -
   70         -
/// Internal classification of a `MessageStreamError`.
#[derive(Debug)]
   71         -
enum MessageStreamErrorKind {
   72         -
    /// Any boxed error; produced by `MessageStreamError::unhandled` and
    /// `MessageStreamError::generic`.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
   73         -
}
   74         -
   75         -
impl MessageStreamError {
   76         -
    /// Creates the `MessageStreamError::Unhandled` variant from any error type.
   77         -
    pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
   78         -
        Self {
   79         -
            meta: Default::default(),
   80         -
            kind: MessageStreamErrorKind::Unhandled(err.into()),
   81         -
        }
   82         -
    }
   83         -
   84         -
    /// Creates the `MessageStreamError::Unhandled` variant from an [`ErrorMetadata`].
   85         -
    pub fn generic(err: ErrorMetadata) -> Self {
   86         -
        Self {
   87         -
            meta: err.clone(),
   88         -
            kind: MessageStreamErrorKind::Unhandled(err.into()),
   89         -
        }
   90         -
    }
   91         -
   92         -
    /// Returns error metadata, which includes the error code, message,
   93         -
    /// request ID, and potentially additional information.
   94         -
    pub fn meta(&self) -> &ErrorMetadata {
   95         -
        &self.meta
   96         -
    }
   97         -
}
   98         -
   99         -
impl StdError for MessageStreamError {
  100         -
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
  101         -
        match &self.kind {
  102         -
            MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
  103         -
        }
  104         -
    }
  105         -
}
  106         -
  107         -
impl fmt::Display for MessageStreamError {
  108         -
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  109         -
        match &self.kind {
  110         -
            MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
  111         -
        }
  112         -
    }
  113         -
}
  114         -
  115         -
/// Adapts a `Stream<SmithyMessageType>` to a signed `Stream<Bytes>` by using the provided
  116         -
/// message marshaller and signer implementations.
  117         -
///
  118         -
/// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be
  119         -
/// marshalled into an Event Stream frame (e.g., if the message payload was too large).
  120         -
#[allow(missing_debug_implementations)]
  121         -
pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
  122         -
    // Marshals `Ok` items of the input stream.
    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
  123         -
    // Marshals `Err` items of the input stream.
    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
  124         -
    // Signs every marshalled message; also asked for an empty message when the input ends.
    signer: Box<dyn SignMessage + Send + Sync>,
  125         -
    // The input stream being adapted.
    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
  126         -
    // Set to `true` by `poll_next` the first time the input stream yields `None`.
    end_signal_sent: bool,
  127         -
    // NOTE(review): `E` already appears in `error_marshaller` and `stream`, so this marker
    // looks redundant — confirm before removing.
    _phantom: PhantomData<E>,
  128         -
}
  129         -
  130         -
// Every field is a `Box`, a `Pin<Box<_>>`, a `bool`, or a `PhantomData<E>`. The auto-derived
// `Unpin` would therefore only hold when `E: Unpin` (because of `PhantomData<E>`); this impl
// makes the adapter `Unpin` unconditionally, which `poll_next` relies on to reach its fields
// through `Pin<&mut Self>`.
impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
  131         -
  132         -
impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
  133         -
    /// Create a new `MessageStreamAdapter`.
  134         -
    pub fn new(
  135         -
        marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
  136         -
        error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
  137         -
        signer: impl SignMessage + Send + Sync + 'static,
  138         -
        stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
  139         -
    ) -> Self {
  140         -
        MessageStreamAdapter {
  141         -
            marshaller: Box::new(marshaller),
  142         -
            error_marshaller: Box::new(error_marshaller),
  143         -
            signer: Box::new(signer),
  144         -
            stream,
  145         -
            end_signal_sent: false,
  146         -
            _phantom: Default::default(),
  147         -
        }
  148         -
    }
  149         -
}
  150         -
  151         -
impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
  152         -
    type Item =
  153         -
        Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
  154         -
  155         -
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
  156         -
        match self.stream.as_mut().poll_next(cx) {
  157         -
            Poll::Ready(message_option) => {
  158         -
                if let Some(message_result) = message_option {
  159         -
                    let message = match message_result {
  160         -
                        Ok(message) => self
  161         -
                            .marshaller
  162         -
                            .marshall(message)
  163         -
                            .map_err(SdkError::construction_failure)?,
  164         -
                        Err(message) => self
  165         -
                            .error_marshaller
  166         -
                            .marshall(message)
  167         -
                            .map_err(SdkError::construction_failure)?,
  168         -
                    };
  169         -
  170         -
                    trace!(unsigned_message = ?message, "signing event stream message");
  171         -
                    let message = self
  172         -
                        .signer
  173         -
                        .sign(message)
  174         -
                        .map_err(SdkError::construction_failure)?;
  175         -
  176         -
                    let mut buffer = Vec::with_capacity(message.size_hint());
  177         -
                    write_message_to(&message, &mut buffer)
  178         -
                        .map_err(SdkError::construction_failure)?;
  179         -
                    trace!(signed_message = ?buffer, "sending signed event stream message");
  180         -
                    Poll::Ready(Some(Ok(Bytes::from(buffer))))
  181         -
                } else if !self.end_signal_sent {
  182         -
                    self.end_signal_sent = true;
  183         -
                    match self.signer.sign_empty() {
  184         -
                        Some(sign) => {
  185         -
                            let message = sign.map_err(SdkError::construction_failure)?;
  186         -
                            let mut buffer = Vec::with_capacity(message.size_hint());
  187         -
                            write_message_to(&message, &mut buffer)
  188         -
                                .map_err(SdkError::construction_failure)?;
  189         -
                            trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
  190         -
                            Poll::Ready(Some(Ok(Bytes::from(buffer))))
  191         -
                        }
  192         -
                        None => Poll::Ready(None),
  193         -
                    }
  194         -
                } else {
  195         -
                    Poll::Ready(None)
  196         -
                }
  197         -
            }
  198         -
            Poll::Pending => Poll::Pending,
  199         -
        }
  200         -
    }
  201         -
}
  202         -
  203         -
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event payload used by these tests.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a [`TestMessage`] into a header-less message holding the string's bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Reading a frame from an empty buffer is a convenient way to obtain an error value.
            let empty_frame: &[u8] = b"";
            let failure = read_message_from(empty_frame).expect_err("this should always fail");
            Err(failure)
        }
    }

    /// Stand-in for a modeled service error.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Header attached by [`TestSigner`] to everything it signs.
    fn signed_marker() -> Header {
        Header::new("signed", HeaderValue::Bool(true))
    }

    /// Signer that nests the serialized input message inside a new message tagged with a
    /// `signed: true` header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            let wrapped = Message::new(encoded).add_header(signed_marker());
            Ok(wrapped)
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let empty = Message::new(&b""[..]).add_header(signed_marker());
            Some(Ok(empty))
        }
    }

    /// Identity function that only compiles for `Send + Sync` values.
    fn check_send_sync<T>(value: T) -> T
    where
        T: Send + Sync,
    {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let events = stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        };
        let sender = EventStreamSender::from(events);
        check_send_sync(sender);
    }

    /// Identity function carrying the bounds a stream must meet to be wrapped as a hyper body.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        // First frame: the signed wrapper around the marshalled test message.
        let mut first_frame = adapter.next().await.unwrap().unwrap();
        let outer = read_message_from(&mut first_frame).unwrap();
        let outer_header = &outer.headers()[0];
        assert_eq!("signed", outer_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), outer_header.value());
        let mut nested: &[u8] = &outer.payload()[..];
        let inner = read_message_from(&mut nested).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed empty message that terminates the stream.
        let mut last_frame = adapter.next().await.unwrap().unwrap();
        let terminator = read_message_from(&mut last_frame).unwrap();
        let terminator_header = &terminator.headers()[0];
        assert_eq!("signed", terminator_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), terminator_header.value());
        assert_eq!(0, terminator.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let events = stream! {
            yield Err(TestServiceError);
        };
        let adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(adapter);

        let outcome = adapter.next().await.unwrap();
        assert!(outcome.is_err());
        let failure = outcome.err().unwrap();
        assert!(matches!(failure, SdkError::ConstructionFailure(_)));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let mut first_frame = adapter.next().await.unwrap().unwrap();
        let outer = read_message_from(&mut first_frame).unwrap();
        assert_eq!("signed", outer.headers()[0].name().as_str());
        let mut nested: &[u8] = &outer.payload()[..];
        let inner = read_message_from(&mut nested).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Should get end signal next
        let mut last_frame = adapter.next().await.unwrap().unwrap();
        let terminator = read_message_from(&mut last_frame).unwrap();
        assert_eq!("signed", terminator.headers()[0].name().as_str());
        assert_eq!(0, terminator.payload().len());

        // Stream should be exhausted
        let after_end = adapter.next().await;
        assert!(after_end.is_none());
    }

    // Verify the developer experience for this compiles
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }

        let ok_events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        check(ok_events);

        let err_events = stream! {
            yield Err(TestServiceError);
        };
        check(err_events);
    }
}

tmp-codegen-diff/aws-sdk/sdk/aws-smithy-legacy-http/src/futures_stream_adapter.rs

@@ -1,0 +62,0 @@
    1         -
/*
    2         -
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    3         -
 * SPDX-License-Identifier: Apache-2.0
    4         -
 */
    5         -
    6         -
use aws_smithy_types::body::SdkBody;
    7         -
use aws_smithy_types::byte_stream::error::Error as ByteStreamError;
    8         -
use aws_smithy_types::byte_stream::ByteStream;
    9         -
use bytes::Bytes;
   10         -
use futures_core::stream::Stream;
   11         -
use std::pin::Pin;
   12         -
use std::task::{Context, Poll};
   13         -
   14         -
/// A new-type wrapper that implements the `futures_core::stream::Stream` trait for
/// [`ByteStream`].
   15         -
///
   16         -
/// [`ByteStream`] no longer implements `futures_core::stream::Stream` so we wrap it in the
   17         -
/// new-type to enable the trait when it is required.
   18         -
///
   19         -
/// This is meant to be used by codegen code, and users should not need to use it directly.
   20         -
#[derive(Debug)]
   21         -
pub struct FuturesStreamCompatByteStream(ByteStream);
   22         -
   23         -
impl FuturesStreamCompatByteStream {
   24         -
    /// Creates a new `FuturesStreamCompatByteStream` by wrapping `stream`.
   25         -
    pub fn new(stream: ByteStream) -> Self {
   26         -
        Self(stream)
   27         -
    }
   28         -
   29         -
    /// Returns [`SdkBody`] of the wrapped [`ByteStream`].
   30         -
    pub fn into_inner(self) -> SdkBody {
   31         -
        self.0.into_inner()
   32         -
    }
   33         -
}
   34         -
   35         -
impl Stream for FuturesStreamCompatByteStream {
   36         -
    type Item = Result<Bytes, ByteStreamError>;
   37         -
   38         -
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
   39         -
        Pin::new(&mut self.0).poll_next(cx)
   40         -
    }
   41         -
}
   42         -
   43         -
#[cfg(test)]
mod tests {
    use super::*;
    use futures_core::stream::Stream;

    /// Identity function carrying the bounds a stream must meet to be wrapped as a hyper body.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[test]
    fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() {
        // Only the type check matters; the wrapped stream is never polled.
        let byte_stream = ByteStream::from_static(b"Hello world");
        let compat = FuturesStreamCompatByteStream::new(byte_stream);
        check_compatible_with_hyper_wrap_stream(compat);
    }
}