1use aws_smithy_eventstream::frame::{
7 DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
8};
9use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
10use aws_smithy_types::body::SdkBody;
11use aws_smithy_types::event_stream::{Message, RawMessage};
12use bytes::Buf;
13use bytes::Bytes;
14use bytes_utils::SegmentedBuf;
15use std::error::Error as StdError;
16use std::fmt;
17use std::marker::PhantomData;
18use std::mem;
19use tracing::trace;
20
21#[derive(Debug)]
23enum RecvBuf {
24 Empty,
26 Partial(SegmentedBuf<Bytes>),
29 EosPartial(SegmentedBuf<Bytes>),
31 Terminated,
33}
34
35impl RecvBuf {
36 fn has_data(&self) -> bool {
38 match self {
39 RecvBuf::Empty | RecvBuf::Terminated => false,
40 RecvBuf::Partial(segments) | RecvBuf::EosPartial(segments) => segments.remaining() > 0,
41 }
42 }
43
44 fn is_eos(&self) -> bool {
46 matches!(self, RecvBuf::EosPartial(_) | RecvBuf::Terminated)
47 }
48
49 fn buffered(&mut self) -> &mut SegmentedBuf<Bytes> {
51 match self {
52 RecvBuf::Empty => panic!("buffer must be populated before reading; this is a bug"),
53 RecvBuf::Partial(segmented) => segmented,
54 RecvBuf::EosPartial(segmented) => segmented,
55 RecvBuf::Terminated => panic!("buffer has been terminated; this is a bug"),
56 }
57 }
58
59 fn with_partial(self, partial: Bytes) -> Self {
62 match self {
63 RecvBuf::Empty => {
64 let mut segmented = SegmentedBuf::new();
65 segmented.push(partial);
66 RecvBuf::Partial(segmented)
67 }
68 RecvBuf::Partial(mut segmented) => {
69 segmented.push(partial);
70 RecvBuf::Partial(segmented)
71 }
72 RecvBuf::EosPartial(_) | RecvBuf::Terminated => {
73 panic!("cannot buffer more data after the stream has ended or been terminated; this is a bug")
74 }
75 }
76 }
77
78 fn ended(self) -> Self {
80 match self {
81 RecvBuf::Empty => RecvBuf::EosPartial(SegmentedBuf::new()),
82 RecvBuf::Partial(segmented) => RecvBuf::EosPartial(segmented),
83 RecvBuf::EosPartial(_) => panic!("already end of stream; this is a bug"),
84 RecvBuf::Terminated => panic!("stream terminated; this is a bug"),
85 }
86 }
87}
88
/// The specific reason a [`ReceiverError`] was raised.
#[derive(Debug)]
enum ReceiverErrorKind {
    /// The body ended while undecoded bytes were still buffered.
    UnexpectedEndOfStream,
}

/// Error raised by the event stream receiver itself (as opposed to the decoder, the transport,
/// or the service).
#[derive(Debug)]
pub struct ReceiverError {
    kind: ReceiverErrorKind,
}

impl fmt::Display for ReceiverError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self.kind {
            ReceiverErrorKind::UnexpectedEndOfStream => "unexpected end of stream",
        };
        f.write_str(description)
    }
}

impl StdError for ReceiverError {}
110
111#[derive(Debug)]
113pub struct Receiver<T, E> {
114 unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
115 decoder: MessageFrameDecoder,
116 buffer: RecvBuf,
117 body: SdkBody,
118 buffered_message: Option<Message>,
123 _phantom: PhantomData<E>,
124}
125
/// Which kind of initial message `Receiver::try_recv_initial` should look for.
#[doc(hidden)]
#[non_exhaustive]
pub enum InitialMessageType {
    /// A message whose `:event-type` header is `initial-request`.
    Request,
    /// A message whose `:event-type` header is `initial-response`.
    Response,
}

impl InitialMessageType {
    /// The `:event-type` header value that identifies this kind of initial message.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Request => "initial-request",
            Self::Response => "initial-response",
        }
    }
}
142
143impl<T, E> Receiver<T, E> {
144 pub fn new(
146 unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
147 body: SdkBody,
148 ) -> Self {
149 Receiver {
150 unmarshaller: Box::new(unmarshaller),
151 decoder: MessageFrameDecoder::new(),
152 buffer: RecvBuf::Empty,
153 body,
154 buffered_message: None,
155 _phantom: Default::default(),
156 }
157 }
158
159 fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
160 match self.unmarshaller.unmarshall(&message) {
161 Ok(unmarshalled) => match unmarshalled {
162 UnmarshalledMessage::Event(event) => Ok(Some(event)),
163 UnmarshalledMessage::Error(err) => {
164 Err(SdkError::service_error(err, RawMessage::Decoded(message)))
165 }
166 },
167 Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
168 }
169 }
170
171 async fn buffer_next_chunk(&mut self) -> Result<(), SdkError<E, RawMessage>> {
172 use http_body_04x::Body;
173
174 if !self.buffer.is_eos() {
175 let next_chunk = self
176 .body
177 .data()
178 .await
179 .transpose()
180 .map_err(|err| SdkError::dispatch_failure(ConnectorError::io(err)))?;
181 let buffer = mem::replace(&mut self.buffer, RecvBuf::Empty);
182 if let Some(chunk) = next_chunk {
183 self.buffer = buffer.with_partial(chunk);
184 } else {
185 self.buffer = buffer.ended();
186 }
187 }
188 Ok(())
189 }
190
191 async fn next_message(&mut self) -> Result<Option<Message>, SdkError<E, RawMessage>> {
192 while !self.buffer.is_eos() {
193 if self.buffer.has_data() {
194 if let DecodedFrame::Complete(message) = self
195 .decoder
196 .decode_frame(self.buffer.buffered())
197 .map_err(|err| {
198 SdkError::response_error(
199 err,
200 RawMessage::Invalid(None),
202 )
203 })?
204 {
205 trace!(message = ?message, "received complete event stream message");
206 return Ok(Some(message));
207 }
208 }
209
210 self.buffer_next_chunk().await?;
211 }
212 if self.buffer.has_data() {
213 trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
214 let buf = self.buffer.buffered();
215 return Err(SdkError::response_error(
216 ReceiverError {
217 kind: ReceiverErrorKind::UnexpectedEndOfStream,
218 },
219 RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
220 ));
221 }
222 Ok(None)
223 }
224
225 #[doc(hidden)]
228 pub async fn try_recv_initial(
229 &mut self,
230 message_type: InitialMessageType,
231 ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
232 self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
233 .await
234 .map(|opt| opt.map(|(msg, _)| msg))
235 }
236
237 #[doc(hidden)]
244 pub async fn try_recv_initial_with_preprocessor<F, M>(
245 &mut self,
246 message_type: InitialMessageType,
247 preprocessor: F,
248 ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
249 where
250 F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
251 {
252 if let Some(message) = self.next_message().await? {
253 let (processed_message, metadata) =
254 preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
255
256 if let Some(event_type) = processed_message
257 .headers()
258 .iter()
259 .find(|h| h.name().as_str() == ":event-type")
260 {
261 if event_type
262 .value()
263 .as_string()
264 .map(|s| s.as_str() == message_type.as_str())
265 .unwrap_or(false)
266 {
267 return Ok(Some((processed_message, metadata)));
268 }
269 }
270 self.buffered_message = Some(message);
272 }
273 Ok(None)
274 }
275
276 pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
281 if let Some(buffered) = self.buffered_message.take() {
282 return match self.unmarshall(buffered) {
283 Ok(message) => Ok(message),
284 Err(error) => {
285 self.buffer = RecvBuf::Terminated;
286 Err(error)
287 }
288 };
289 }
290 if let Some(message) = self.next_message().await? {
291 match self.unmarshall(message) {
292 Ok(message) => Ok(message),
293 Err(error) => {
294 self.buffer = RecvBuf::Terminated;
295 Err(error)
296 }
297 }
298 } else {
299 Ok(None)
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::{InitialMessageType, Receiver, UnmarshallMessage};
    use aws_smithy_eventstream::error::Error as EventStreamError;
    use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
    use aws_smithy_runtime_api::client::result::SdkError;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
    use bytes::Bytes;
    use hyper::body::Body;
    use std::error::Error as StdError;
    use std::io::{Error as IOError, ErrorKind};

    /// Encodes an event message whose `:event-type` header is `initial-response`.
    fn encode_initial_response() -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::new())
            .add_header(Header::new(
                ":message-type",
                HeaderValue::String("event".into()),
            ))
            .add_header(Header::new(
                ":event-type",
                HeaderValue::String("initial-response".into()),
            ));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    /// Encodes a header-less message whose payload is `message`.
    fn encode_message(message: &str) -> Bytes {
        let mut buffer = Vec::new();
        let message = Message::new(Bytes::copy_from_slice(message.as_bytes()));
        write_message_to(&message, &mut buffer).unwrap();
        buffer.into()
    }

    /// Builds a receiver whose body yields `chunks` one at a time, in order.
    fn receiver_from_chunks<U>(
        unmarshaller: U,
        chunks: Vec<Result<Bytes, IOError>>,
    ) -> Receiver<U::Output, U::Error>
    where
        U: UnmarshallMessage + Send + Sync + 'static,
    {
        let chunk_stream = futures_util::stream::iter(chunks);
        let body = SdkBody::from_body_0_4(Body::wrap_stream(chunk_stream));
        Receiver::new(unmarshaller, body)
    }

    #[derive(Debug)]
    struct FakeError;
    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "FakeError")
        }
    }
    impl StdError for FakeError {}

    #[derive(Debug, Eq, PartialEq)]
    struct TestMessage(String);

    /// Unmarshals every message payload into a `TestMessage`.
    #[derive(Debug)]
    struct Unmarshaller;
    impl UnmarshallMessage for Unmarshaller {
        type Output = TestMessage;
        type Error = EventStreamError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            Ok(UnmarshalledMessage::Event(TestMessage(
                std::str::from_utf8(&message.payload()[..]).unwrap().into(),
            )))
        }
    }

    /// Like `Unmarshaller`, but reports a payload of `"error"` as a modeled `FakeError`.
    #[derive(Debug)]
    struct ErrorUnmarshaller;
    impl UnmarshallMessage for ErrorUnmarshaller {
        type Output = TestMessage;
        type Error = FakeError;

        fn unmarshall(
            &self,
            message: &Message,
        ) -> Result<UnmarshalledMessage<Self::Output, Self::Error>, EventStreamError> {
            let payload = std::str::from_utf8(&message.payload()[..]).unwrap();
            if payload == "error" {
                Ok(UnmarshalledMessage::Error(FakeError))
            } else {
                Ok(UnmarshalledMessage::Event(TestMessage(payload.into())))
            }
        }
    }

    #[tokio::test]
    async fn receive_success() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![Ok(encode_message("one")), Ok(encode_message("two"))],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_empty() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![
                Ok(encode_message("one")),
                Ok(encode_message("two")),
                Ok(Bytes::from_static(&[])),
            ],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_last_chunk_not_full_message() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![
                Ok(encode_message("one")),
                Ok(encode_message("two")),
                Ok(encode_message("three").split_to(10)),
            ],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. }),
        ));
    }

    #[tokio::test]
    async fn receive_last_chunk_has_multiple_messages() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![
                Ok(encode_message("one")),
                Ok(encode_message("two")),
                Ok(Bytes::from(
                    [encode_message("three"), encode_message("four")].concat(),
                )),
            ],
        );
        for payload in ["one", "two", "three", "four"] {
            assert_eq!(
                TestMessage(payload.into()),
                receiver.recv().await.unwrap().unwrap()
            );
        }
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    proptest::proptest! {
        #[test]
        fn receive_multiple_messages_split_unevenly_across_chunks(b1: usize, b2: usize) {
            let combined = Bytes::from([
                encode_message("one"),
                encode_message("two"),
                encode_message("three"),
                encode_message("four"),
                encode_message("five"),
                encode_message("six"),
                encode_message("seven"),
                encode_message("eight"),
            ].concat());

            let midpoint = combined.len() / 2;
            let (start, boundary1, boundary2, end) = (
                0,
                b1 % midpoint,
                midpoint + b2 % midpoint,
                combined.len()
            );
            println!("[{start}, {boundary1}], [{boundary1}, {boundary2}], [{boundary2}, {end}]");

            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async move {
                let mut receiver = receiver_from_chunks(
                    Unmarshaller,
                    vec![
                        Ok(Bytes::copy_from_slice(&combined[start..boundary1])),
                        Ok(Bytes::copy_from_slice(&combined[boundary1..boundary2])),
                        Ok(Bytes::copy_from_slice(&combined[boundary2..end])),
                    ],
                );
                for payload in &["one", "two", "three", "four", "five", "six", "seven", "eight"] {
                    assert_eq!(
                        TestMessage((*payload).into()),
                        receiver.recv().await.unwrap().unwrap()
                    );
                }
                assert_eq!(None, receiver.recv().await.unwrap());
            });
        }
    }

    #[tokio::test]
    async fn receive_network_failure() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![
                Ok(encode_message("one")),
                Err(IOError::new(ErrorKind::ConnectionReset, FakeError)),
            ],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::DispatchFailure(_))
        ));
    }

    #[tokio::test]
    async fn receive_message_parse_failure() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![
                Ok(encode_message("one")),
                // Bytes that cannot form a valid event stream frame.
                Ok(Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
            ],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ResponseError { .. })
        ));
    }

    #[tokio::test]
    async fn receive_service_error_terminates_stream() {
        let mut receiver = receiver_from_chunks(
            ErrorUnmarshaller,
            vec![
                Ok(encode_message("one")),
                Ok(encode_message("error")),
                Ok(encode_message("two")),
            ],
        );
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert!(matches!(
            receiver.recv().await,
            Err(SdkError::ServiceError(_))
        ));
        // A modeled error terminates the receiver, so the remaining message is never yielded.
        assert_eq!(None, receiver.recv().await.unwrap());
    }

    #[tokio::test]
    async fn receive_initial_response() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![Ok(encode_initial_response()), Ok(encode_message("one"))],
        );
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_some());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn receive_initial_response_with_preprocessor_metadata() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![Ok(encode_initial_response()), Ok(encode_message("one"))],
        );
        let (_, metadata) = receiver
            .try_recv_initial_with_preprocessor(InitialMessageType::Response, |message| {
                Ok((message, "metadata"))
            })
            .await
            .unwrap()
            .expect("initial response should be recognized");
        assert_eq!("metadata", metadata);
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    #[tokio::test]
    async fn receive_no_initial_response() {
        let mut receiver = receiver_from_chunks(
            Unmarshaller,
            vec![Ok(encode_message("one")), Ok(encode_message("two"))],
        );
        assert!(receiver
            .try_recv_initial(InitialMessageType::Response)
            .await
            .unwrap()
            .is_none());
        assert_eq!(
            TestMessage("one".into()),
            receiver.recv().await.unwrap().unwrap()
        );
        assert_eq!(
            TestMessage("two".into()),
            receiver.recv().await.unwrap().unwrap()
        );
    }

    fn assert_send_and_sync<T: Send + Sync>() {}

    #[test]
    fn receiver_is_send_and_sync() {
        assert_send_and_sync::<Receiver<(), ()>>();
    }
}