1 1 | /*
|
2 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 3 | * SPDX-License-Identifier: Apache-2.0
|
4 4 | */
|
5 5 |
|
6 6 | use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
|
7 7 | use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
|
8 8 | use aws_smithy_runtime_api::client::result::SdkError;
|
9 9 | use aws_smithy_types::error::ErrorMetadata;
|
10 + | use aws_smithy_types::event_stream::Message;
|
10 11 | use bytes::Bytes;
|
11 12 | use futures_core::Stream;
|
12 13 | use std::error::Error as StdError;
|
13 14 | use std::fmt;
|
14 15 | use std::fmt::Debug;
|
15 16 | use std::marker::PhantomData;
|
16 17 | use std::pin::Pin;
|
17 18 | use std::task::{Context, Poll};
|
18 19 | use tracing::trace;
|
19 20 |
|
21 + | /// Wrapper for event stream items that may include an initial-request message.
|
22 + | /// This is used internally to allow initial messages to flow through the signing pipeline.
|
23 + | #[doc(hidden)]
|
24 + | #[derive(Debug)]
|
25 + | pub enum EventOrInitial<T> {
|
26 + | /// A regular event that needs marshalling and signing
|
27 + | Event(T),
|
28 + | /// An initial-request message that's already marshalled, just needs signing
|
29 + | InitialMessage(Message),
|
30 + | }
|
31 + |
|
20 32 | /// Input type for Event Streams.
|
21 33 | pub struct EventStreamSender<T, E> {
|
22 34 | input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
|
23 35 | }
|
24 36 |
|
25 37 | impl<T, E> Debug for EventStreamSender<T, E> {
|
26 38 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
27 39 | let name_t = std::any::type_name::<T>();
|
28 40 | let name_e = std::any::type_name::<E>();
|
29 41 | write!(f, "EventStreamSender<{name_t}, {name_e}>")
|
30 42 | }
|
31 43 | }
|
32 44 |
|
33 45 | impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
|
34 46 | /// Creates an `EventStreamSender` from a single item.
|
35 47 | pub fn once(item: Result<T, E>) -> Self {
|
36 48 | Self::from(futures_util::stream::once(async move { item }))
|
37 49 | }
|
38 50 | }
|
39 51 |
|
40 52 | impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
|
41 53 | #[doc(hidden)]
|
42 54 | pub fn into_body_stream(
|
43 55 | self,
|
44 56 | marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
|
45 57 | error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
|
46 58 | signer: impl SignMessage + Send + Sync + 'static,
|
47 59 | ) -> MessageStreamAdapter<T, E> {
|
48 60 | MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
|
49 61 | }
|
62 + |
|
63 + | /// Extract the inner stream. This is used internally for composing streams.
|
64 + | #[doc(hidden)]
|
65 + | pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
|
66 + | self.input_stream
|
67 + | }
|
50 68 | }
|
51 69 |
|
52 70 | impl<T, E, S> From<S> for EventStreamSender<T, E>
|
53 71 | where
|
54 72 | S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
|
55 73 | {
|
56 74 | fn from(stream: S) -> Self {
|
57 75 | EventStreamSender {
|
58 76 | input_stream: Box::pin(stream),
|
59 77 | }
|
60 78 | }
|
61 79 | }
|
62 80 |
|
63 81 | /// An error that occurs within a message stream.
|
64 82 | #[derive(Debug)]
|
65 83 | pub struct MessageStreamError {
|
66 84 | kind: MessageStreamErrorKind,
|
67 85 | pub(crate) meta: ErrorMetadata,
|
68 86 | }
|
69 87 |
|
/// Internal classification of a `MessageStreamError`.
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// Wraps an arbitrary boxed error (the boxed trait object is implicitly `'static`).
    Unhandled(Box<dyn StdError + Send + Sync>),
}
|
74 92 |
|
75 93 | impl MessageStreamError {
|
76 94 | /// Creates the `MessageStreamError::Unhandled` variant from any error type.
|
77 95 | pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
|
78 96 | Self {
|
79 97 | meta: Default::default(),
|
173 191 | .sign(message)
|
174 192 | .map_err(SdkError::construction_failure)?;
|
175 193 |
|
176 194 | let mut buffer = Vec::with_capacity(message.size_hint());
|
177 195 | write_message_to(&message, &mut buffer)
|
178 196 | .map_err(SdkError::construction_failure)?;
|
179 197 | trace!(signed_message = ?buffer, "sending signed event stream message");
|
180 198 | Poll::Ready(Some(Ok(Bytes::from(buffer))))
|
181 199 | } else if !self.end_signal_sent {
|
182 200 | self.end_signal_sent = true;
|
183 201 | match self.signer.sign_empty() {
|
184 202 | Some(sign) => {
|
185 203 | let message = sign.map_err(SdkError::construction_failure)?;
|
186 204 | let mut buffer = Vec::with_capacity(message.size_hint());
|
187 205 | write_message_to(&message, &mut buffer)
|
188 206 | .map_err(SdkError::construction_failure)?;
|
189 207 | trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
|
190 208 | Poll::Ready(Some(Ok(Bytes::from(buffer))))
|
191 209 | }
|
192 210 | None => Poll::Ready(None),
|
193 211 | }
|
194 212 | } else {
|
195 213 | Poll::Ready(None)
|
196 214 | }
|
197 215 | }
|
198 216 | Poll::Pending => Poll::Pending,
|
199 217 | }
|
200 218 | }
|
201 219 | }
|
202 220 |
|
/// Marshaller adapter that accepts either a regular event or an
/// already-marshalled initial-request message.
///
/// Used internally to support initial-request messages in event streams.
#[derive(Debug)]
#[doc(hidden)]
pub struct EventOrInitialMarshaller<M> {
    // The marshaller applied to regular events.
    inner: M,
}

impl<M> EventOrInitialMarshaller<M> {
    /// Wraps `inner`, which will be used to marshal regular events.
    #[doc(hidden)]
    pub fn new(inner: M) -> Self {
        EventOrInitialMarshaller { inner }
    }
}
|
235 + |
|
236 + | impl<M, T> MarshallMessage for EventOrInitialMarshaller<M>
|
237 + | where
|
238 + | M: MarshallMessage<Input = T>,
|
239 + | {
|
240 + | type Input = EventOrInitial<T>;
|
241 + |
|
242 + | fn marshall(
|
243 + | &self,
|
244 + | input: Self::Input,
|
245 + | ) -> Result<Message, aws_smithy_eventstream::error::Error> {
|
246 + | match input {
|
247 + | EventOrInitial::Event(event) => self.inner.marshall(event),
|
248 + | EventOrInitial::InitialMessage(message) => Ok(message),
|
249 + | }
|
250 + | }
|
251 + | }
|
252 + |
|
203 253 | #[cfg(test)]
|
204 254 | mod tests {
|
205 255 | use super::MarshallMessage;
|
206 256 | use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
|
207 257 | use async_stream::stream;
|
208 258 | use aws_smithy_eventstream::error::Error as EventStreamError;
|
209 259 | use aws_smithy_eventstream::frame::{
|
210 260 | read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
|
211 261 | };
|
212 262 | use aws_smithy_runtime_api::client::result::SdkError;
|
213 263 | use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
|
214 264 | use bytes::Bytes;
|
215 265 | use futures_core::Stream;
|
216 266 | use futures_util::stream::StreamExt;
|
217 267 | use std::error::Error as StdError;
|
218 268 |
|
219 269 | #[derive(Debug, Eq, PartialEq)]
|
220 270 | struct TestMessage(String);
|
221 271 |
|
222 272 | #[derive(Debug)]
|
223 273 | struct Marshaller;
|
224 274 | impl MarshallMessage for Marshaller {
|
225 275 | type Input = TestMessage;
|
226 276 |
|
227 277 | fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
|
228 278 | Ok(Message::new(input.0.as_bytes().to_vec()))
|
229 279 | }
|
230 280 | }
|
231 281 | #[derive(Debug)]
|
232 282 | struct ErrorMarshaller;
|