1use aws_smithy_eventstream::frame::{
7 DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21#[derive(Debug)]
23enum RecvBuf {
24 Empty,
26 Partial(SegmentedBuf<Bytes>),
29 EosPartial(SegmentedBuf<Bytes>),
31 Terminated,
33}
34
35impl RecvBuf {
36 fn has_data(&self) -> bool {
38 match self {
39 RecvBuf::Empty | RecvBuf::Terminated => false,
40 RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41 }
42 }
43
44 fn is_eos(&self) -> bool {
46 matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47 }
48
49 fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51 match self {
52 RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53 RecvBuf::Partial(segmented) => segmented,
54 RecvBuf::EosPartial(segmented) => segmented,
55 RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56 }
57 }
58
59 fn with_partial(self, partial: Bytes) -> Self {
62 match self {
63 RecvBuf::Empty => {
64 let mut segmented = SegmentedBuf::new();
65 segmented.push(partial);
66 RecvBuf::Partial(segmented)
67 }
68 RecvBuf::Partial(mut segmented) => {
69 segmented.push(partial);
70 RecvBuf::Partial(segmented)
71 }
72 RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73 panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74 }
75 }
76 }
77
78 fn ended(self) -> Self {
80 match self {
81 RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82 RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83 RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84 RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85 }
86 }
87}
88
/// The specific reason a [`ReceiverError`] was raised.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while unread bytes (an incomplete message) were still buffered.
    UnexpectedEndOfStream,
}

/// An error raised by the event stream receiver itself, rather than by the service or the
/// connection.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
/// Reads event stream messages from an HTTP body and unmarshalls them into events of type `T`;
/// modeled error messages are reported as service errors of type `E`.
#[derive(Debug)]
pub struct Receiver<T, E> {
    /// Converts decoded messages into events or modeled errors.
    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
    /// Decodes complete message frames out of `buffer`.
    decoder: MessageFrameDecoder,
    /// Body bytes received but not yet decoded, plus end-of-stream/termination state.
    buffer: RecvBuf,
    /// The HTTP body the event stream is read from.
    body: SdkBody,
    /// A message read by `try_recv_initial` that was not the expected initial message;
    /// `recv` returns it before reading anything else from the body.
    buffered_message: Option<Message>,
    /// Marker for the error type `E`.
    /// NOTE(review): `E` already appears in `unmarshaller`'s type, so this looks redundant.
    _phantom: PhantomData<E>,
}
125
/// Which initial message `Receiver::try_recv_initial` should look for, identified by the
/// value of the message's `:event-type` header.
#[doc(hidden)]
#[non_exhaustive]
pub enum InitialMessageType {
    /// An `initial-request` message.
    Request,
    /// An `initial-response` message.
    Response,
}

impl InitialMessageType {
    /// Returns the `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Request => "initial-request",
            Self::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144 pub fn new(
146 unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147 body: SdkBody,
148 ) -> Self {
149 Receiver {
150 unmarshaller: Box::new(unmarshaller),
151 decoder: MessageFrameDecoder::new(),
152 buffer: RecvBuf::Empty,
153 body,
154 buffered_message: None,
155 _phantom: Default::default(),
156 }
157 }
158
159 fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160 match self.unmarshaller.unmarshall(&message) {
161 Ok(unmarshalled) => match unmarshalled {
162 UnmarshalledMessage::Event(event) => Ok(Some(event)),
163 UnmarshalledMessage::Error(err) => {
164 Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165 }
166 },
167 Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168 }
169 }
170
171 async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172 use http_body_util::BodyExt;
173
174 if !self.buffer.is_eos() {
175 let next_chunk = self
176 .body
177 .frame()
178 .await
179 .transpose()
180 .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181 let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182 if let Some(chunk) = next_chunk {
183 if let Ok(data) = chunk.into_data() {
185 self.buffer = buffer.with_partial(data);
186 }
187 } else {
188 self.buffer = buffer.ended();
189 }
190 }
191 Ok(())
192 }
193
194 async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
195 while !self.buffer.is_eos() {
196 if self.buffer.has_data() {
197 if let DecodedFrame::Complete(message) = self
198 .decoder
199 .decode_frame(self.buffer.buffered())
200 .map_err(|err| {
201 SdkError::response_error(
202 err,
203 RawMessage::Invalid(None),
205 )
206 })?
207 {
208 trace!(message = ?message, "received complete event stream message");
209 return Ok(Some(message));
210 }
211 }
212
213 self.buffer_next_chunk().await?;
214 }
215 if self.buffer.has_data() {
216 trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
217 let buf = self.buffer.buffered();
218 return Err(SdkError::response_error(
219 ReceiverError {
220 kind: ReceiverErrorKind::UnexpectedEndOfStream,
221 },
222 RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
223 ));
224 }
225 Ok(None)
226 }
227
228 #[doc(hidden)]
231 pub async fn try_recv_initial(
232 &mut self,
233 message_type: InitialMessageType,
234 ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
235 self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
236 .await
237 .map(|opt| opt.map(|(msg, _)| msg))
238 }
239
240 #[doc(hidden)]
247 pub async fn try_recv_initial_with_preprocessor<F, M>(
248 &mut self,
249 message_type: InitialMessageType,
250 preprocessor: F,
251 ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
252 where
253 F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
254 {
255 if let Some(message) = self.next_message().await? {
256 let (processed_message, metadata) =
257 preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
258
259 if let Some(event_type) = processed_message
260 .headers()
261 .iter()
262 .find(|h| h.name().as_str() == ":event-type")
263 {
264 if event_type
265 .value()
266 .as_string()
267 .map(|s| s.as_str() == message_type.as_str())
268 .unwrap_or(false)
269 {
270 return Ok(Some((processed_message, metadata)));
271 }
272 }
273 self.buffered_message = Some(message);
275 }
276 Ok(None)
277 }
278
279 pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
284 if let Some(buffered) = self.buffered_message.take() {
285 return match self.unmarshall(buffered) {
286 Ok(message) => Ok(message),
287 Err(error) => {
288 self.buffer = RecvBuf::Terminated;
289 Err(error)
290 }
291 };
292 }
293 if let Some(message) = self.next_message().await? {
294 match self.unmarshall(message) {
295 Ok(message) => Ok(message),
296 Err(error) => {
297 self.buffer = RecvBuf::Terminated;
298 Err(error)
299 }
300 }
301 } else {
302 Ok(None)
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use http_body_1x::Frame;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// Encodes an empty-payload message whose headers mark it as an `initial-response` event.
    fn encode_initial_response() -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::new())
            .add_header(Header::new(
                ":message-type",
                HeaderValue::String("event".into()),
            ))
            .add_header(Header::new(
                ":event-type",
                HeaderValue::String("initial-response".into()),
            ));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    /// Encodes a header-less message whose payload is the UTF-8 bytes of `message`.
    fn encode_message(message: &str) -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    /// Wraps each successful chunk in a data `Frame`, passing errors through unchanged.
    fn map_to_frame(stream: Vec<Result<Bytes, IOError>>) -> Vec<Result<Frame<Bytes>, IOError>> {
        stream
            .into_iter()
            .map(|chunk| chunk.map(Frame::data))
            .collect()
    }

    /// Placeholder error used as the inner cause of simulated I/O failures.
    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    /// Event type produced by the test unmarshaller: the message payload as a string.
    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshaller that turns every message into a `TestMessage` holding its payload.
    /// It never reports a modeled error; it panics (via `unwrap`) on a non-UTF-8 payload.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            Ok(UnmarshalledMessage::Event(TestMessage(
                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
            )))
        }
    }

    /// One message per chunk: each is received in order, then the stream ends cleanly.
    #[tokio::test]
    async fn receive_success() {
        let chunks: Vec<Result<_, IOError>> =
            map_to_frame(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);

        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);

        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    /// A trailing empty chunk must not be treated as leftover data.
    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    /// A truncated final message (first 10 bytes only) surfaces as a response error.
    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    /// Two complete messages packed into one chunk are both delivered.
    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("three".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("four".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        /// Eight messages concatenated and cut into three chunks at arbitrary points
        /// (one boundary in each half) are all decoded in order.
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let combined = Bytes::from([
                encode_message("one"),
                encode_message("two"),
                encode_message("three"),
                encode_message("four"),
                encode_message("five"),
                encode_message("six"),
                encode_message("seven"),
                encode_message("eight"),
            ].concat());

            // boundary1 falls in the first half and boundary2 in the second, so the three
            // chunk ranges are always ordered (the first may be empty).
            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) = (
                0,
                b1 % midpoint,
                midpoint + b2 % midpoint,
                combined.len()
            );
            println!("[{start}, {boundary1}], [{boundary1}, {boundary2}], [{boundary2}, {end}]");

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ]);

                let chunk_stream = futures_util::stream::iter(chunks);
                let stream_body = http_body_util::StreamBody::new(chunk_stream);
                let body = SdkBody::from_body_1_x(stream_body);
                let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
                    assert_eq!(
                        TestMessage((*payload).into()),
                        receiver.recv().await.unwrap().unwrap()
                    );
                }
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    /// An I/O error from the body is reported as a dispatch failure.
    #[tokio::test]
    async fn receive_network_failure() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    /// Bytes that do not form a valid frame are reported as a response error.
    #[tokio::test]
    async fn receive_message_parse_failure() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_message("one")),
            // Twelve zero bytes: not a valid event stream frame.
            Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    /// A leading `initial-response` message is returned by `try_recv_initial`, and normal
    /// events follow.
    #[tokio::test]
    async fn receive_initial_response() {
        let chunks: Vec<Result<_, IOError>> = map_to_frame(vec![
            Ok(encode_initial_response()),
            Ok(encode_message("one")),
        ]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);
        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_some());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    /// Without an initial response, the first message read by `try_recv_initial` is kept and
    /// still delivered by `recv`.
    #[tokio::test]
    async fn receive_no_initial_response() {
        let chunks: Vec<Result<_, IOError>> =
            map_to_frame(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        let chunk_stream = futures_util::stream::iter(chunks);
        let stream_body = http_body_util::StreamBody::new(chunk_stream);

        let body = SdkBody::from_body_1_x(stream_body);
        let mut receiver = Receiver::<TestMessage, EventStreamError>::new(Unmarshaller, body);
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_none());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    /// Compiles only if `T` is both `Send` and `Sync`.
    fn assert_send_and_sync<T: Send + Sync>() {}

    /// Compile-time check that `Receiver` can be moved and shared across threads.
    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}
616}