1use aws_smithy_eventstream::frame::{
7 DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
/// Bytes read from the response body that have not yet been decoded into
/// complete event stream messages, together with where the body is in its
/// lifecycle.
#[derive(Debug)]
enum RecvBuf {
    /// Nothing has been buffered yet. Also used as a temporary placeholder while
    /// the receiver swaps the buffer out to update it.
    Empty,
    /// Data has been received and the body has not yet reported its end.
    Partial(SegmentedBuf<Bytes>),
    /// The body has ended. Bytes still held here when the receiver reaches this
    /// state are reported by it as an unexpected end of stream.
    EosPartial(SegmentedBuf<Bytes>),
    /// Reading was stopped by the receiver after an unmarshalling error; no
    /// further data is read.
    Terminated,
}
34
35impl RecvBuf {
36 fn has_data(&self) -> bool {
38 match self {
39 RecvBuf::Empty | RecvBuf::Terminated => false,
40 RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41 }
42 }
43
44 fn is_eos(&self) -> bool {
46 matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47 }
48
49 fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51 match self {
52 RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53 RecvBuf::Partial(segmented) => segmented,
54 RecvBuf::EosPartial(segmented) => segmented,
55 RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56 }
57 }
58
59 fn with_partial(self, partial: Bytes) -> Self {
62 match self {
63 RecvBuf::Empty => {
64 let mut segmented = SegmentedBuf::new();
65 segmented.push(partial);
66 RecvBuf::Partial(segmented)
67 }
68 RecvBuf::Partial(mut segmented) => {
69 segmented.push(partial);
70 RecvBuf::Partial(segmented)
71 }
72 RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73 panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74 }
75 }
76 }
77
78 fn ended(self) -> Self {
80 match self {
81 RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82 RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83 RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84 RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85 }
86 }
87}
88
/// Failures detected by the receiver itself, as opposed to failures reported by
/// the unmarshaller or by the transport.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while bytes of an incomplete message were still buffered.
    UnexpectedEndOfStream,
}

/// Error raised when the event stream cannot be read to completion.
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match &self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
/// Reads an event stream response body, decodes it into messages, and converts
/// each message into an event (`T`) or a modeled error (`E`) via an unmarshaller.
#[derive(Debug)]
pub struct Receiver<T, E> {
    /// Converts each decoded message into an event or a modeled error.
    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
    /// Decodes message frames out of `buffer`; a frame may be reported as
    /// incomplete until more bytes arrive.
    decoder: MessageFrameDecoder,
    /// Bytes read from `body` that have not yet been decoded, plus end-of-stream state.
    buffer: RecvBuf,
    /// The response body the event stream is read from.
    body: SdkBody,
    /// A message read by `try_recv_initial` that turned out not to be the
    /// expected initial message; `recv` returns it before reading new data.
    buffered_message: Option<Message>,
    // NOTE(review): `E` already appears in `unmarshaller`'s type; the reason for
    // this marker is not evident from this file — confirm before removing.
    _phantom: PhantomData<E>,
}
125
/// Which kind of initial message `Receiver::try_recv_initial` should look for.
#[doc(hidden)]
#[non_exhaustive]
pub enum InitialMessageType {
    /// Look for an `initial-request` message.
    Request,
    /// Look for an `initial-response` message.
    Response,
}

impl InitialMessageType {
    /// The `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Request => "initial-request",
            Self::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144 pub fn new(
146 unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147 body: SdkBody,
148 ) -> Self {
149 Receiver {
150 unmarshaller: Box::new(unmarshaller),
151 decoder: MessageFrameDecoder::new(),
152 buffer: RecvBuf::Empty,
153 body,
154 buffered_message: None,
155 _phantom: Default::default(),
156 }
157 }
158
159 fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160 match self.unmarshaller.unmarshall(&message) {
161 Ok(unmarshalled) => match unmarshalled {
162 UnmarshalledMessage::Event(event) => Ok(Some(event)),
163 UnmarshalledMessage::Error(err) => {
164 Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165 }
166 },
167 Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168 }
169 }
170
171 async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172 use http_body_util::BodyExt;
173
174 if !self.buffer.is_eos() {
175 let next_chunk = self
176 .body
177 .frame()
178 .await
179 .transpose()
180 .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181 let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182 if let Some(chunk) = next_chunk {
183 if let Ok(data) = chunk.into_data() {
185 self.buffer = buffer.with_partial(data);
186 }
187 } else {
188 self.buffer = buffer.ended();
189 }
190 }
191 Ok(())
192 }
193
194 async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
195 while !self.buffer.is_eos() {
196 if self.buffer.has_data() {
197 if let DecodedFrame::Complete(message) = self
198 .decoder
199 .decode_frame(self.buffer.buffered())
200 .map_err(|err| {
201 SdkError::response_error(
202 err,
203 RawMessage::Invalid(None),
205 )
206 })?
207 {
208 trace!(message = ?message, "received complete event stream message");
209 return Ok(Some(message));
210 }
211 }
212
213 self.buffer_next_chunk().await?;
214 }
215 if self.buffer.has_data() {
216 trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
217 let buf = self.buffer.buffered();
218 return Err(SdkError::response_error(
219 ReceiverError {
220 kind: ReceiverErrorKind::UnexpectedEndOfStream,
221 },
222 RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
223 ));
224 }
225 Ok(None)
226 }
227
228 #[doc(hidden)]
231 pub async fn try_recv_initial(
232 &mut self,
233 message_type: InitialMessageType,
234 ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
235 if let Some(message) = self.next_message().await? {
236 if let Some(event_type) = message
237 .headers()
238 .iter()
239 .find(|h| h.name().as_str() == ":event-type")
240 {
241 if event_type
242 .value()
243 .as_string()
244 .map(|s| s.as_str() == message_type.as_str())
245 .unwrap_or(false)
246 {
247 return Ok(Some(message));
248 }
249 }
250 self.buffered_message = Some(message);
252 }
253 Ok(None)
254 }
255
256 pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
261 if let Some(buffered) = self.buffered_message.take() {
262 return match self.unmarshall(buffered) {
263 Ok(message) => Ok(message),
264 Err(error) => {
265 self.buffer = RecvBuf::Terminated;
266 Err(error)
267 }
268 };
269 }
270 if let Some(message) = self.next_message().await? {
271 match self.unmarshall(message) {
272 Ok(message) => Ok(message),
273 Err(error) => {
274 self.buffer = RecvBuf::Terminated;
275 Err(error)
276 }
277 }
278 } else {
279 Ok(None)
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use http_body_1x::Frame;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    type TestReceiver = Receiver<TestMessage, EventStreamError>;

    /// Serializes `message` into its event stream wire format.
    fn to_wire(message: &Message) -> Bytes {
        let mut wire = Vec::new();
        write_message_to(message, &mut wire).unwrap();
        Bytes::from(wire)
    }

    /// Encodes an `initial-response` event with an empty payload.
    fn encode_initial_response() -> Bytes {
        let message_type = Header::new(":message-type", HeaderValue::String("event".into()));
        let event_type = Header::new(
            ":event-type",
            HeaderValue::String("initial-response".into()),
        );
        to_wire(
            &Message::new(Bytes::new())
                .add_header(message_type)
                .add_header(event_type),
        )
    }

    /// Encodes a header-less message whose payload is the UTF-8 bytes of `payload`.
    fn encode_message(payload: &str) -> Bytes {
        to_wire(&Message::new(Bytes::copy_from_slice(payload.as_bytes())))
    }

    /// Wraps each `Ok` chunk in a data frame and passes errors through unchanged.
    fn map_to_frame(stream: Vec<Result<Bytes, IOError>>) -> Vec<Result<Frame<Bytes>, IOError>> {
        stream.into_iter().map(|chunk| chunk.map(Frame::data)).collect()
    }

    /// Builds a receiver whose body yields `chunks` in order and then ends.
    fn receiver_over(chunks: Vec<Result<Bytes, IOError>>) -> TestReceiver {
        let frames = futures_util::stream::iter(map_to_frame(chunks));
        let body = SdkBody::from_body_1_x(http_body_util::StreamBody::new(frames));
        Receiver::new(Unmarshaller, body)
    }

    /// Asserts that the next events received carry exactly `payloads`, in order.
    async fn expect_events(receiver: &mut TestReceiver, payloads: &[&str]) {
        for &payload in payloads {
            let event = receiver.recv().await.unwrap().unwrap();
            assert_eq!(TestMessage(payload.into()), event);
        }
    }

    #[derive(Debug)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("FakeError")
        }
    }

    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Treats every message as an event whose payload is UTF-8 text.
    #[derive(Debug)]
    struct Unmarshaller;

    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let text = String::from_utf8(message.payload().to_vec()).unwrap();
            Ok(UnmarshalledMessage::Event(TestMessage(text)))
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver =
            receiver_over(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        expect_events(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from_static(&[])),
        ]);
        expect_events(&mut receiver, &["one", "two"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(encode_message("three").split_to(10)),
        ]);
        expect_events(&mut receiver, &["one", "two"]).await;
        let result = receiver.recv().await;
        assert!(matches!(result, Err(SdkError::ResponseError { .. })));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Ok(encode_message("two")),
            Ok(Bytes::from(
                [encode_message("three"), encode_message("four")].concat(),
            )),
        ]);
        expect_events(&mut receiver, &["one", "two", "three", "four"]).await;
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let payloads = ["one", "two", "three", "four", "five", "six", "seven", "eight"];
            let combined = Bytes::from(
                payloads
                    .iter()
                    .map(|payload| encode_message(payload))
                    .collect::<Vec<_>>()
                    .concat(),
            );

            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) = (
                0,
                b1 % midpoint,
                midpoint + b2 % midpoint,
                combined.len()
            );
            println!("[{}, {}], [{}, {}], [{}, {}]", start, boundary1, boundary1, boundary2, boundary2, end);

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let mut receiver = receiver_over(vec![
                    Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                    Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                    Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                ]);
                expect_events(&mut receiver, &payloads).await;
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
        ]);
        expect_events(&mut receiver, &["one"]).await;
        let result = receiver.recv().await;
        assert!(matches!(result, Err(SdkError::DispatchFailure(_))));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_over(vec![
            Ok(encode_message("one")),
            // Twelve zero bytes, which do not form a valid message frame.
            Ok(Bytes::from_static(&[0u8; 12])),
        ]);
        expect_events(&mut receiver, &["one"]).await;
        let result = receiver.recv().await;
        assert!(matches!(result, Err(SdkError::ResponseError { .. })));
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver = receiver_over(vec![
            Ok(encode_initial_response()),
            Ok(encode_message("one")),
        ]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_some());
        expect_events(&mut receiver, &["one"]).await;
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver =
            receiver_over(vec![Ok(encode_message("one")), Ok(encode_message("two"))]);
        let initial = receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap();
        assert!(initial.is_none());
        expect_events(&mut receiver, &["one", "two"]).await;
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    #[tokio::test]
    async fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}