// aws_smithy_http/event_stream/sender.rs
use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
20pub struct EventStreamSender<T, E> {
22 input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
23}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 let name_t = std::any::type_name::<T>();
28 let name_e = std::any::type_name::<E>();
29 write!(f, "EventStreamSender<{name_t}, {name_e}>")
30 }
31}
32
33impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34 pub fn once(item: Result<T, E>) -> Self {
36 Self::from(futures_util::stream::once(async move { item }))
37 }
38}
39
40impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
41 #[doc(hidden)]
42 pub fn into_body_stream(
43 self,
44 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
45 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
46 signer: impl SignMessage + Send + Sync + 'static,
47 ) -> MessageStreamAdapter<T, E> {
48 MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
49 }
50}
51
52impl<T, E, S> From<S> for EventStreamSender<T, E>
53where
54 S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
55{
56 fn from(stream: S) -> Self {
57 EventStreamSender {
58 input_stream: Box::pin(stream),
59 }
60 }
61}
62
63#[derive(Debug)]
65pub struct MessageStreamError {
66 kind: MessageStreamErrorKind,
67 pub(crate) meta: ErrorMetadata,
68}
69
/// The kinds of failure a [`MessageStreamError`] can represent.
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An arbitrary boxed error that has no dedicated variant.
    Unhandled(Box<dyn StdError + Send + Sync + 'static>),
}
74
75impl MessageStreamError {
76 pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
78 Self {
79 meta: Default::default(),
80 kind: MessageStreamErrorKind::Unhandled(err.into()),
81 }
82 }
83
84 pub fn generic(err: ErrorMetadata) -> Self {
86 Self {
87 meta: err.clone(),
88 kind: MessageStreamErrorKind::Unhandled(err.into()),
89 }
90 }
91
92 pub fn meta(&self) -> &ErrorMetadata {
95 &self.meta
96 }
97}
98
99impl StdError for MessageStreamError {
100 fn source(&self) -> Option<&(dyn StdError + 'static)> {
101 match &self.kind {
102 MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
103 }
104 }
105}
106
107impl fmt::Display for MessageStreamError {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 match &self.kind {
110 MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
111 }
112 }
113}
114
115#[allow(missing_debug_implementations)]
121pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
122 marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
123 error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
124 signer: Box<dyn SignMessage + Send + Sync>,
125 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
126 end_signal_sent: bool,
127 _phantom: PhantomData<E>,
128}
129
130impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
131
132impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
133 pub fn new(
135 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
136 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
137 signer: impl SignMessage + Send + Sync + 'static,
138 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
139 ) -> Self {
140 MessageStreamAdapter {
141 marshaller: Box::new(marshaller),
142 error_marshaller: Box::new(error_marshaller),
143 signer: Box::new(signer),
144 stream,
145 end_signal_sent: false,
146 _phantom: Default::default(),
147 }
148 }
149}
150
151impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
152 type Item =
153 Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
154
155 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
156 match self.stream.as_mut().poll_next(cx) {
157 Poll::Ready(message_option) => {
158 if let Some(message_result) = message_option {
159 let message = match message_result {
160 Ok(message) => self
161 .marshaller
162 .marshall(message)
163 .map_err(SdkError::construction_failure)?,
164 Err(message) => self
165 .error_marshaller
166 .marshall(message)
167 .map_err(SdkError::construction_failure)?,
168 };
169
170 trace!(unsigned_message = ?message, "signing event stream message");
171 let message = self
172 .signer
173 .sign(message)
174 .map_err(SdkError::construction_failure)?;
175
176 let mut buffer = Vec::with_capacity(message.size_hint());
177 write_message_to(&message, &mut buffer)
178 .map_err(SdkError::construction_failure)?;
179 trace!(signed_message = ?buffer, "sending signed event stream message");
180 Poll::Ready(Some(Ok(Bytes::from(buffer))))
181 } else if !self.end_signal_sent {
182 self.end_signal_sent = true;
183 match self.signer.sign_empty() {
184 Some(sign) => {
185 let message = sign.map_err(SdkError::construction_failure)?;
186 let mut buffer = Vec::with_capacity(message.size_hint());
187 write_message_to(&message, &mut buffer)
188 .map_err(SdkError::construction_failure)?;
189 trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
190 Poll::Ready(Some(Ok(Bytes::from(buffer))))
191 }
192 None => Poll::Ready(None),
193 }
194 } else {
195 Poll::Ready(None)
196 }
197 }
198 Poll::Pending => Poll::Pending,
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event type used by the tests; wraps the text to send.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshals a `TestMessage` into a message whose payload is the text bytes.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Error marshaller that always fails, to exercise the failure path.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            // Reading a message from an empty buffer is a convenient way to
            // obtain an `EventStreamError` value.
            let empty: &[u8] = &b""[..];
            Err(read_message_from(empty).expect_err("this should always fail"))
        }
    }

    /// Error item type used by the tests.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// Signer that wraps the serialized input message in a new message
    /// carrying a `signed: true` header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            let signed_header = Header::new("signed", HeaderValue::Bool(true));
            Ok(Message::new(encoded).add_header(signed_header))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let signed_header = Header::new("signed", HeaderValue::Bool(true));
            Some(Ok(Message::new(&b""[..]).add_header(signed_header)))
        }
    }

    /// Compile-time check that `T` is `Send + Sync`; returns the value unchanged.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        check_send_sync(EventStreamSender::from(stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        }));
    }

    /// Compile-time check that `S` satisfies the listed stream bounds; returns
    /// the stream unchanged.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(events),
        ));

        // First item: the signed wrapper around the marshalled test message.
        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        let sent_header = &sent.headers()[0];
        assert_eq!("signed", sent_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), sent_header.value());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second item: the signed, empty end-of-stream message.
        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        let end_header = &end_signal.headers()[0];
        assert_eq!("signed", end_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), end_header.value());
        assert_eq!(0, end_signal.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let events = stream! {
            yield Err(TestServiceError);
        };
        let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::<
            TestMessage,
            TestServiceError,
        >::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(events),
        ));

        // The error marshaller always fails, so the item must be an error.
        let outcome = adapter.next().await.unwrap();
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.err().unwrap(),
            SdkError::ConstructionFailure(_)
        ));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        // The single item is marshalled and signed.
        let mut sent_bytes = adapter.next().await.unwrap().unwrap();
        let sent = read_message_from(&mut sent_bytes).unwrap();
        assert_eq!("signed", sent.headers()[0].name().as_str());
        let inner = read_message_from(&mut (&sent.payload()[..])).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Then the signed, empty end-of-stream message follows.
        let mut end_signal_bytes = adapter.next().await.unwrap().unwrap();
        let end_signal = read_message_from(&mut end_signal_bytes).unwrap();
        assert_eq!("signed", end_signal.headers()[0].name().as_str());
        assert_eq!(0, end_signal.payload().len());

        // Nothing more is produced after the end signal.
        assert!(adapter.next().await.is_none());
    }

    // Not a `#[test]`: this function only has to compile, which shows that
    // `stream!` blocks convert into an `EventStreamSender`.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }
        check(stream! {
            yield Ok(TestMessage("test".into()));
        });
        check(stream! {
            yield Err(TestServiceError);
        });
    }
}