aws_smithy_http/event_stream/
sender.rs1use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
20pub struct EventStreamSender<T, E> {
22 input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
23}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 let name_t = std::any::type_name::<T>();
28 let name_e = std::any::type_name::<E>();
29 write!(f, "EventStreamSender<{name_t}, {name_e}>")
30 }
31}
32
33impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34 #[doc(hidden)]
35 pub fn into_body_stream(
36 self,
37 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
38 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
39 signer: impl SignMessage + Send + Sync + 'static,
40 ) -> MessageStreamAdapter<T, E> {
41 MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
42 }
43}
44
45impl<T, E, S> From<S> for EventStreamSender<T, E>
46where
47 S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
48{
49 fn from(stream: S) -> Self {
50 EventStreamSender {
51 input_stream: Box::pin(stream),
52 }
53 }
54}
55
56#[derive(Debug)]
58pub struct MessageStreamError {
59 kind: MessageStreamErrorKind,
60 pub(crate) meta: ErrorMetadata,
61}
62
/// The underlying cause of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An arbitrary boxed error.
    Unhandled(Box<dyn StdError + Send + Sync + 'static>),
}
67
68impl MessageStreamError {
69 pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
71 Self {
72 meta: Default::default(),
73 kind: MessageStreamErrorKind::Unhandled(err.into()),
74 }
75 }
76
77 pub fn generic(err: ErrorMetadata) -> Self {
79 Self {
80 meta: err.clone(),
81 kind: MessageStreamErrorKind::Unhandled(err.into()),
82 }
83 }
84
85 pub fn meta(&self) -> &ErrorMetadata {
88 &self.meta
89 }
90}
91
92impl StdError for MessageStreamError {
93 fn source(&self) -> Option<&(dyn StdError + 'static)> {
94 match &self.kind {
95 MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
96 }
97 }
98}
99
100impl fmt::Display for MessageStreamError {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match &self.kind {
103 MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
104 }
105 }
106}
107
108#[allow(missing_debug_implementations)]
114pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
115 marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
116 error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
117 signer: Box<dyn SignMessage + Send + Sync>,
118 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
119 end_signal_sent: bool,
120 _phantom: PhantomData<E>,
121}
122
123impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
124
125impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
126 pub fn new(
128 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
129 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
130 signer: impl SignMessage + Send + Sync + 'static,
131 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
132 ) -> Self {
133 MessageStreamAdapter {
134 marshaller: Box::new(marshaller),
135 error_marshaller: Box::new(error_marshaller),
136 signer: Box::new(signer),
137 stream,
138 end_signal_sent: false,
139 _phantom: Default::default(),
140 }
141 }
142}
143
144impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
145 type Item =
146 Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
147
148 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149 match self.stream.as_mut().poll_next(cx) {
150 Poll::Ready(message_option) => {
151 if let Some(message_result) = message_option {
152 let message = match message_result {
153 Ok(message) => self
154 .marshaller
155 .marshall(message)
156 .map_err(SdkError::construction_failure)?,
157 Err(message) => self
158 .error_marshaller
159 .marshall(message)
160 .map_err(SdkError::construction_failure)?,
161 };
162
163 trace!(unsigned_message = ?message, "signing event stream message");
164 let message = self
165 .signer
166 .sign(message)
167 .map_err(SdkError::construction_failure)?;
168
169 let mut buffer = Vec::with_capacity(message.size_hint());
170 write_message_to(&message, &mut buffer)
171 .map_err(SdkError::construction_failure)?;
172 trace!(signed_message = ?buffer, "sending signed event stream message");
173 Poll::Ready(Some(Ok(Bytes::from(buffer))))
174 } else if !self.end_signal_sent {
175 self.end_signal_sent = true;
176 match self.signer.sign_empty() {
177 Some(sign) => {
178 let message = sign.map_err(SdkError::construction_failure)?;
179 let mut buffer = Vec::with_capacity(message.size_hint());
180 write_message_to(&message, &mut buffer)
181 .map_err(SdkError::construction_failure)?;
182 trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
183 Poll::Ready(Some(Ok(Bytes::from(buffer))))
184 }
185 None => Poll::Ready(None),
186 }
187 } else {
188 Poll::Ready(None)
189 }
190 }
191 Poll::Pending => Poll::Pending,
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Test input message; its string becomes the message payload.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a [`TestMessage`] into a message whose payload is the string's bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            Ok(Message::new(input.0.as_bytes().to_vec()))
        }
    }

    /// Error marshaller that always fails, used to exercise the construction-failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Decoding an empty buffer is a convenient way to obtain an `EventStreamError`.
            Err(read_message_from(&b""[..]).expect_err("this should always fail"))
        }
    }

    /// Error type carried on the `Err` side of the test input streams.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Signer with two behaviors:
    /// - Each message is encoded and becomes the payload of a new message with a
    ///   `signed: true` header.
    /// - The end signal is an empty message with the same header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut buffer = Vec::new();
            write_message_to(&message, &mut buffer).unwrap();
            Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true))))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            Some(Ok(
                Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))
            ))
        }
    }

    /// Compile-time check that `value` is `Send + Sync`.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    /// Compile-time check that `stream` meets these bounds:
    /// - it is a `Send + 'static` stream of `Result<O, E>`;
    /// - `O` converts into [`Bytes`];
    /// - `E` converts into a boxed `Send + Sync` error.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(stream),
        ));

        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value());
        assert_eq!(0, end_signal.payload().len());

        // After the end signal the adapter must be exhausted.
        assert!(adapter.next().await.is_none());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let stream = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(stream),
        ));

        let result = adapter.next().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(
            result.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    // Compile-only check that plain `stream!` blocks convert into an `EventStreamSender`,
    // for both success and error items.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}