aws_smithy_legacy_http/event_stream/
sender.rs1use aws_smithy_eventstream::frame::{write_message_to, MarshallMessage, SignMessage};
7use aws_smithy_eventstream::message_size_hint::MessageSizeHint;
8use aws_smithy_runtime_api::client::result::SdkError;
9use aws_smithy_types::error::ErrorMetadata;
10use bytes::Bytes;
11use futures_core::Stream;
12use std::error::Error as StdError;
13use std::fmt;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18use tracing::trace;
19
20pub struct EventStreamSender<T, E> {
22 input_stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>>,
23}
24
25impl<T, E> Debug for EventStreamSender<T, E> {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 let name_t = std::any::type_name::<T>();
28 let name_e = std::any::type_name::<E>();
29 write!(f, "EventStreamSender<{name_t}, {name_e}>")
30 }
31}
32
33impl<T: Send + Sync + 'static, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
34 pub fn once(item: Result<T, E>) -> Self {
36 Self::from(futures_util::stream::once(async move { item }))
37 }
38}
39
40impl<T, E: StdError + Send + Sync + 'static> EventStreamSender<T, E> {
41 #[doc(hidden)]
42 pub fn into_body_stream(
43 self,
44 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
45 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
46 signer: impl SignMessage + Send + Sync + 'static,
47 ) -> MessageStreamAdapter<T, E> {
48 MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream)
49 }
50
51 #[doc(hidden)]
53 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<T, E>> + Send + Sync>> {
54 self.input_stream
55 }
56}
57
58impl<T, E, S> From<S> for EventStreamSender<T, E>
59where
60 S: Stream<Item = Result<T, E>> + Send + Sync + 'static,
61{
62 fn from(stream: S) -> Self {
63 EventStreamSender {
64 input_stream: Box::pin(stream),
65 }
66 }
67}
68
69#[derive(Debug)]
71pub struct MessageStreamError {
72 kind: MessageStreamErrorKind,
73 pub(crate) meta: ErrorMetadata,
74}
75
/// Private classification of a [`MessageStreamError`].
#[derive(Debug)]
enum MessageStreamErrorKind {
    /// An arbitrary boxed error that is not modelled more specifically.
    Unhandled(Box<dyn std::error::Error + Send + Sync + 'static>),
}
80
81impl MessageStreamError {
82 pub fn unhandled(err: impl Into<Box<dyn std::error::Error + Send + Sync + 'static>>) -> Self {
84 Self {
85 meta: Default::default(),
86 kind: MessageStreamErrorKind::Unhandled(err.into()),
87 }
88 }
89
90 pub fn generic(err: ErrorMetadata) -> Self {
92 Self {
93 meta: err.clone(),
94 kind: MessageStreamErrorKind::Unhandled(err.into()),
95 }
96 }
97
98 pub fn meta(&self) -> &ErrorMetadata {
101 &self.meta
102 }
103}
104
105impl StdError for MessageStreamError {
106 fn source(&self) -> Option<&(dyn StdError + 'static)> {
107 match &self.kind {
108 MessageStreamErrorKind::Unhandled(source) => Some(source.as_ref() as _),
109 }
110 }
111}
112
113impl fmt::Display for MessageStreamError {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 match &self.kind {
116 MessageStreamErrorKind::Unhandled(_) => write!(f, "message stream error"),
117 }
118 }
119}
120
121#[allow(missing_debug_implementations)]
127pub struct MessageStreamAdapter<T, E: StdError + Send + Sync + 'static> {
128 marshaller: Box<dyn MarshallMessage<Input = T> + Send + Sync>,
129 error_marshaller: Box<dyn MarshallMessage<Input = E> + Send + Sync>,
130 signer: Box<dyn SignMessage + Send + Sync>,
131 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
132 end_signal_sent: bool,
133 _phantom: PhantomData<E>,
134}
135
136impl<T, E: StdError + Send + Sync + 'static> Unpin for MessageStreamAdapter<T, E> {}
137
138impl<T, E: StdError + Send + Sync + 'static> MessageStreamAdapter<T, E> {
139 pub fn new(
141 marshaller: impl MarshallMessage<Input = T> + Send + Sync + 'static,
142 error_marshaller: impl MarshallMessage<Input = E> + Send + Sync + 'static,
143 signer: impl SignMessage + Send + Sync + 'static,
144 stream: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>>,
145 ) -> Self {
146 MessageStreamAdapter {
147 marshaller: Box::new(marshaller),
148 error_marshaller: Box::new(error_marshaller),
149 signer: Box::new(signer),
150 stream,
151 end_signal_sent: false,
152 _phantom: Default::default(),
153 }
154 }
155}
156
157impl<T, E: StdError + Send + Sync + 'static> Stream for MessageStreamAdapter<T, E> {
158 type Item =
159 Result<Bytes, SdkError<E, aws_smithy_runtime_api::client::orchestrator::HttpResponse>>;
160
161 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162 match self.stream.as_mut().poll_next(cx) {
163 Poll::Ready(message_option) => {
164 if let Some(message_result) = message_option {
165 let message = match message_result {
166 Ok(message) => self
167 .marshaller
168 .marshall(message)
169 .map_err(SdkError::construction_failure)?,
170 Err(message) => self
171 .error_marshaller
172 .marshall(message)
173 .map_err(SdkError::construction_failure)?,
174 };
175
176 trace!(unsigned_message = ?message, "signing event stream message");
177 let message = self
178 .signer
179 .sign(message)
180 .map_err(SdkError::construction_failure)?;
181
182 let mut buffer = Vec::with_capacity(message.size_hint());
183 write_message_to(&message, &mut buffer)
184 .map_err(SdkError::construction_failure)?;
185 trace!(signed_message = ?buffer, "sending signed event stream message");
186 Poll::Ready(Some(Ok(Bytes::from(buffer))))
187 } else if !self.end_signal_sent {
188 self.end_signal_sent = true;
189 match self.signer.sign_empty() {
190 Some(sign) => {
191 let message = sign.map_err(SdkError::construction_failure)?;
192 let mut buffer = Vec::with_capacity(message.size_hint());
193 write_message_to(&message, &mut buffer)
194 .map_err(SdkError::construction_failure)?;
195 trace!(signed_message = ?buffer, "sending signed empty message to terminate the event stream");
196 Poll::Ready(Some(Ok(Bytes::from(buffer))))
197 }
198 None => Poll::Ready(None),
199 }
200 } else {
201 Poll::Ready(None)
202 }
203 }
204 Poll::Pending => Poll::Pending,
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::MarshallMessage;
    use crate::event_stream::{EventStreamSender, MessageStreamAdapter};
    use async_stream::stream;
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{
        read_message_from, write_message_to, NoOpSigner, SignMessage, SignMessageError,
    };
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use futures_core::Stream;
    use futures_util::stream::StreamExt;
    use std::error::Error as StdError;

    /// Event payload used throughout these tests.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Marshaller that places the message text into the payload of a new message.
    #[derive(Debug)]
    struct Marshaller;
    impl MarshallMessage for Marshaller {
        type Input = TestMessage;

        fn marshall(&self, input: Self::Input) -> Result<Message, EventStreamError> {
            let TestMessage(text) = input;
            Ok(Message::new(text.into_bytes()))
        }
    }

    /// Marshaller for errors that always fails, using the error produced by
    /// reading a message from an empty buffer.
    #[derive(Debug)]
    struct ErrorMarshaller;
    impl MarshallMessage for ErrorMarshaller {
        type Input = TestServiceError;

        fn marshall(&self, _input: Self::Input) -> Result<Message, EventStreamError> {
            let empty: &[u8] = &b""[..];
            Err(read_message_from(empty).expect_err("this should always fail"))
        }
    }

    /// Minimal error type standing in for a service error.
    #[derive(Debug)]
    struct TestServiceError;
    impl std::fmt::Display for TestServiceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestServiceError")
        }
    }
    impl StdError for TestServiceError {}

    /// The header the test signer attaches to everything it signs.
    fn signed_marker() -> Header {
        Header::new("signed", HeaderValue::Bool(true))
    }

    /// Signer that wraps the serialized input message in a new message
    /// carrying a `signed: true` header.
    #[derive(Debug)]
    struct TestSigner;
    impl SignMessage for TestSigner {
        fn sign(&mut self, message: Message) -> Result<Message, SignMessageError> {
            let mut encoded = Vec::new();
            write_message_to(&message, &mut encoded).unwrap();
            Ok(Message::new(encoded).add_header(signed_marker()))
        }

        fn sign_empty(&mut self) -> Option<Result<Message, SignMessageError>> {
            let closing = Message::new(&b""[..]).add_header(signed_marker());
            Some(Ok(closing))
        }
    }

    /// Compiles only if `T` is both `Send` and `Sync`.
    fn check_send_sync<T: Send + Sync>(value: T) -> T {
        value
    }

    #[test]
    fn event_stream_sender_send_sync() {
        let source = stream! {
            yield Result::<_, SignMessageError>::Ok(TestMessage("test".into()));
        };
        check_send_sync(EventStreamSender::from(source));
    }

    /// Compiles only if `S` satisfies the listed stream, item, and error bounds.
    fn check_compatible_with_hyper_wrap_stream<S, O, E>(stream: S) -> S
    where
        S: Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<Box<dyn StdError + Send + Sync + 'static>> + 'static,
    {
        stream
    }

    #[tokio::test]
    async fn message_stream_adapter_success() {
        let events = stream! {
            yield Ok(TestMessage("test".into()));
        };
        let raw_adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(raw_adapter);

        // First frame: the signed wrapper around the marshalled test message.
        let mut first_frame = adapter.next().await.unwrap().unwrap();
        let outer = read_message_from(&mut first_frame).unwrap();
        let outer_header = &outer.headers()[0];
        assert_eq!("signed", outer_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), outer_header.value());
        let mut outer_payload: &[u8] = &outer.payload()[..];
        let inner = read_message_from(&mut outer_payload).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Second frame: the signed, empty end-of-stream message.
        let mut closing_frame = adapter.next().await.unwrap().unwrap();
        let closing = read_message_from(&mut closing_frame).unwrap();
        let closing_header = &closing.headers()[0];
        assert_eq!("signed", closing_header.name().as_str());
        assert_eq!(&HeaderValue::Bool(true), closing_header.value());
        assert_eq!(0, closing.payload().len());
    }

    #[tokio::test]
    async fn message_stream_adapter_construction_failure() {
        let events = stream! {
            yield Err(TestServiceError);
        };
        let raw_adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            NoOpSigner {},
            Box::pin(events),
        );
        let mut adapter = check_compatible_with_hyper_wrap_stream(raw_adapter);

        let outcome = adapter.next().await.unwrap();
        assert!(outcome.is_err());
        let failure = outcome.err().unwrap();
        assert!(matches!(failure, SdkError::ConstructionFailure(_)));
    }

    #[tokio::test]
    async fn event_stream_sender_once() {
        let sender = EventStreamSender::once(Ok(TestMessage("test".into())));
        let mut adapter = MessageStreamAdapter::<TestMessage, TestServiceError>::new(
            Marshaller,
            ErrorMarshaller,
            TestSigner,
            sender.input_stream,
        );

        // The single item arrives wrapped in a signed message.
        let mut first_frame = adapter.next().await.unwrap().unwrap();
        let outer = read_message_from(&mut first_frame).unwrap();
        assert_eq!("signed", outer.headers()[0].name().as_str());
        let mut outer_payload: &[u8] = &outer.payload()[..];
        let inner = read_message_from(&mut outer_payload).unwrap();
        assert_eq!(&b"test"[..], &inner.payload()[..]);

        // Then the signed, empty end-of-stream message.
        let mut closing_frame = adapter.next().await.unwrap().unwrap();
        let closing = read_message_from(&mut closing_frame).unwrap();
        assert_eq!("signed", closing.headers()[0].name().as_str());
        assert_eq!(0, closing.payload().len());

        // And finally the adapter reports that it is finished.
        assert!(adapter.next().await.is_none());
    }

    /// Compile-time check that `stream!` blocks convert into a sender.
    #[allow(unused)]
    fn event_stream_input_ergonomics() {
        fn check(input: impl Into<EventStreamSender<TestMessage, TestServiceError>>) {
            let _: EventStreamSender<TestMessage, TestServiceError> = input.into();
        }

        let ok_stream = stream! {
            yield Ok(TestMessage("test".into()));
        };
        check(ok_stream);

        let err_stream = stream! {
            yield Err(TestServiceError);
        };
        check(err_stream);
    }
}