// aws_smithy_http/event_stream/sender.rs
use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_types::error::ErrorMetadata;
use aws_smithy_types::event_stream::Message;
use bytes::Bytes;
use futures_core::Stream;
use std::error::Error as StdError;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::trace;
20
/// An item that can be sent on an event stream: either a modeled event or an already-built
/// initial message.
///
/// Hidden from the generated documentation.
#[doc(hidden)]
#[derive(Debug)]
pub enum EventOrInitial<T> {
    /// A modeled event that still needs to be marshalled into a [`Message`].
    Event(T),
    /// A pre-built [`Message`]; `EventOrInitialMarshaller` passes it through unchanged.
    InitialMessage(Message),
}
31
/// Input for an event stream operation.
///
/// Holds a boxed, pinned stream whose `Ok` items are events and whose `Err` items are
/// modeled errors; `into_body_stream` turns it into a [`MessageStreamAdapter`].
pub struct EventStreamSender<T, E> {
    // The source of events (`Ok`) and modeled errors (`Err`) to send.
    input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
}
36
37impl<T, E> Debug for EventStreamSender<T, E> {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 let name_t = std::any::type_name::<T>();
40 let name_e = std::any::type_name::<E>();
41 write!(f, "EventStreamSender<{name_t}, {name_e}>")
42 }
43}
44
45impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
46 pub fn once(item: Result<T, E>) -> Self {
48 Self::from(futures_util::stream::once(async move { item }))
49 }
50}
51
52impl<T: Send + Sync, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
53 #[doc(hidden)]
54 pub fn into_body_stream(
55 self,
56 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
57 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
58 signer: impl SignMessage + Send + Sync + 'static,
59 ) -> MessageStreamAdapter<T, E> {
60 MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
61 }
62
63 #[doc(hidden)]
65 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
66 self.input_stream
67 }
68}
69
70impl<T, E, S> From<S> for EventStreamSender<T, E>
71where
72 S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
73{
74 fn from(stream: S) -> Self {
75 EventStreamSender {
76 input_stream: Box::pin(stream),
77 }
78 }
79}
80
/// Error type for event stream message streams.
///
/// Wraps a boxed source error together with [`ErrorMetadata`].
#[derive(Debug)]
pub struct MessageStreamError {
    // The wrapped cause, exposed through `StdError::source`.
    kind: MessageStreamErrorKind,
    // Metadata returned by `meta()`.
    pub(crate) meta: ErrorMetadata,
}
87
/// The underlying cause of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An arbitrary boxed error.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
}
92
93impl MessageStreamError {
94 pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
96 Self {
97 meta: Default::default(),
98 kind: MessageStreamErrorKind::Unhandled(err.into()),
99 }
100 }
101
102 pub fn generic(err: ErrorMetadata) -> Self {
104 Self {
105 meta: err.clone(),
106 kind: MessageStreamErrorKind::Unhandled(err.into()),
107 }
108 }
109
110 pub fn meta(&self) -> &ErrorMetadata {
113 &self.meta
114 }
115}
116
117impl StdError for MessageStreamError {
118 fn source(&self) -> Option<&(dyn StdError + 'static)> {
119 match &self.kind {
120 MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
121 }
122 }
123}
124
125impl fmt::Display for MessageStreamError {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match &self.kind {
128 MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
129 }
130 }
131}
132
/// Adapts a stream of events/errors into a stream of HTTP body frames.
///
/// Each item is marshalled into a [`Message`], signed, and encoded to bytes. When the input
/// stream ends, the signer's empty "end of stream" message, if any, is emitted last.
// `Debug` can't be derived: the boxed marshallers, signer, and stream are trait objects that
// are not required to implement `Debug`.
#[allow(missing_debug_implementations)]
pub struct MessageStreamAdapter<T: Send + Sync, E: StdError + Send + Sync + 'static> {
    // Marshals `Ok` items of `stream`.
    marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
    // Marshals `Err` items of `stream`.
    error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
    // Signs every marshalled message and produces the terminating empty message.
    signer: Box<dyn SignMessage + Send + Sync>,
    // Source of events (`Ok`) and modeled errors (`Err`).
    stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
    // Set to `true` once `stream` has ended, whether or not the signer produced an
    // end-of-stream message.
    end_signal_sent: bool,
    _phantom: PhantomData<E>,
}
147
// Every field is a `Box`, a `Pin<Box<_>>`, a `bool`, or `PhantomData<E>`. Only the
// `PhantomData<E>` would make the auto-derived `Unpin` depend on `E: Unpin`, so `Unpin` is
// implemented unconditionally. This lets `poll_next` mutate fields through `Pin<&mut Self>`.
impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
149
150impl<T: Send + Sync, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
151 pub fn new(
153 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
154 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
155 signer: impl SignMessage + Send + Sync + 'static,
156 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
157 ) -> Self {
158 MessageStreamAdapter {
159 marshaller: Box::new(marshaller),
160 error_marshaller: Box::new(error_marshaller),
161 signer: Box::new(signer),
162 stream,
163 end_signal_sent: false,
164 _phantom: Default::default(),
165 }
166 }
167}
168
169impl<T: Send + Sync, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
170 type Item = Result<
171 http_body_1x::Frame<Bytes>,
172 SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>,
173 >;
174
175 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
176 match self.stream.as_mut().poll_next(cx) {
177 Poll::Ready(message_option) => {
178 if let Some(message_result) = message_option {
179 let message = match message_result {
180 Ok(message) => self
181 .marshaller
182 .marshall(message)
183 .map_err(SdkError::construction_failure)?,
184 Err(message) => self
185 .error_marshaller
186 .marshall(message)
187 .map_err(SdkError::construction_failure)?,
188 };
189
190 trace!(unsigned_message = ?message, "signing event stream message");
191 let message = self
192 .signer
193 .sign(message)
194 .map_err(SdkError::construction_failure)?;
195
196 let mut buffer = Vec::with_capacity(message.size_hint());
197 write_message_to(&message, &mut buffer)
198 .map_err(SdkError::construction_failure)?;
199 trace!(signed_message = ?buffer, "sending signed event stream message");
200 Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
201 } else if !self.end_signal_sent {
202 self.end_signal_sent = true;
203 match self.signer.sign_empty() {
204 Some(sign) => {
205 let message = sign.map_err(SdkError::construction_failure)?;
206 let mut buffer = Vec::with_capacity(message.size_hint());
207 write_message_to(&message, &mut buffer)
208 .map_err(SdkError::construction_failure)?;
209 trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
210 Poll::Ready(Some(Ok(http_body_1x::Frame::data(Bytes::from(buffer)))))
211 }
212 None => Poll::Ready(None),
213 }
214 } else {
215 Poll::Ready(None)
216 }
217 }
218 Poll::Pending => Poll::Pending,
219 }
220 }
221}
222
/// A marshaller for [`EventOrInitial`] items.
///
/// Events are delegated to `inner`; initial messages are passed through unchanged.
#[doc(hidden)]
#[derive(Debug)]
pub struct EventOrInitialMarshaller<M> {
    // Marshaller used for `EventOrInitial::Event` items.
    inner: M,
}
230
231impl<M> EventOrInitialMarshaller<M> {
232 #[doc(hidden)]
233 pub fn new(inner: M) -> Self {
234 Self { inner }
235 }
236}
237
238impl<M, T> MarshallMessage for EventOrInitialMarshaller<M>
239where
240 M: MarshallMessage<Input = T>,
241{
242 type Input = EventOrInitial<T>;
243
244 fn marshall(
245 &self,
246 input: Self::Input,
247 ) -> Result<Message, aws_smithy_eventstream::error::Error> {
248 match input {
249 EventOrInitial::Event(event) => self.inner.marshall(event),
250 EventOrInitial::InitialMessage(message) => Ok(message),
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    // Marshals a `TestMessage` into a message whose payload is the string's bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            Ok(Message::new(input.0.as_bytes().to_vec()))
        }
    }
    // Always fails. The error is obtained by reading a message from an empty buffer.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            Err(read_message_from(&b""[..]).expect_err("this should always fail"))
        }
    }

    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    // "Signs" by encoding the whole message as the payload of a new message with a
    // `signed: true` header. The empty end-of-stream message carries the same header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut buffer = Vec::new();
            write_message_to(&message, &mut buffer).unwrap();
            Ok(Message::new(buffer).add_header(Header::new("signed", HeaderValue::Bool(true))))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            Some(Ok(
                Message::new(&b""[..]).add_header(Header::new("signed", HeaderValue::Bool(true)))
            ))
        }
    }

    // Compile-time assertion helper: only accepts values that are `Send + Sync`.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    // One event produces a signed frame wrapping the marshalled event, then a signed empty
    // end-of-stream frame.
    #[tokio::test]
    async fn message_stream_adapter_success() {
        let stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(stream),
        );

        let sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes.into_data().unwrap()).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent.headers()[0].value());
        // The signed payload is the encoded unsigned message.
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        let end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes.into_data().unwrap()).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_signal.headers()[0].value());
        assert_eq!(0, end_signal.payload().len());
    }

    // A marshalling failure surfaces as `SdkError::ConstructionFailure`.
    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let stream = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(stream),
        );

        let result = adapter.next().await.unwrap();
        assert!(result.is_err());
        assert!(matches!(
            result.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    // `EventStreamSender::once` yields one event, then the end signal, then ends.
    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        let sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes.into_data().unwrap()).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // The stream is exhausted, so the next item is the signed empty end signal.
        let end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes.into_data().unwrap()).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(0, end_signal.payload().len());

        // After the end signal the adapter is finished.
        assert!(adapter.next().await.is_none());
    }

    // Compile-only check (never called): streams of `Ok` and `Err` items both convert into
    // `EventStreamSender` via `Into`.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}