1 1 | /*
|
2 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 3 | * SPDX-License-Identifier: Apache-2.0
|
4 4 | */
|
5 5 |
|
6 6 | use aws_smithy_eventstream::frame::{
|
7 7 | DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
|
8 8 | };
|
9 - | use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
|
9 + | use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
|
10 10 | use aws_smithy_types::body::SdkBody;
|
11 11 | use aws_smithy_types::event_stream::{Message, RawMessage};
|
12 12 | use bytes::Buf;
|
13 13 | use bytes::Bytes;
|
14 14 | use bytes_utils::SegmentedBuf;
|
15 15 | use std::error::Error as StdError;
|
16 16 | use std::fmt;
|
17 17 | use std::marker::PhantomData;
|
18 18 | use std::mem;
|
19 19 | use tracing::trace;
|
20 20 |
|
21 21 | /// Wrapper around SegmentedBuf that tracks the state of the stream.
|
22 22 | #[derive(Debug)]
|
23 23 | enum RecvBuf {
|
24 24 | /// Nothing has been buffered yet.
|
25 25 | Empty,
|
26 26 | /// Some data has been buffered.
|
27 27 | /// The SegmentedBuf will automatically purge when it reads off the end of a chunk boundary.
|
28 28 | Partial(SegmentedBuf<Bytes>),
|
29 29 | /// The end of the stream has been reached, but there may still be some buffered data.
|
30 30 | EosPartial(SegmentedBuf<Bytes>),
|
31 31 | /// An exception terminated this stream.
|
32 32 | Terminated,
|
33 33 | }
|
34 34 |
|
35 35 | impl RecvBuf {
|
36 36 | /// Returns true if there's more buffered data.
|
37 37 | fn has_data(&self) -> bool {
|
38 38 | match self {
|
39 39 | RecvBuf::Empty | RecvBuf::Terminated => false,
|
202 202 | )
|
203 203 | })?
|
204 204 | {
|
205 205 | trace!(message = ?message, "received complete event stream message");
|
206 206 | return Ok(Some(message));
|
207 207 | }
|
208 208 | }
|
209 209 |
|
210 210 | self.buffer_next_chunk().await?;
|
211 211 | }
|
212 212 | if self.buffer.has_data() {
|
213 213 | trace!(remaining_data = ?self.buffer, "data left over in the event stream response stream");
|
214 214 | let buf = self.buffer.buffered();
|
215 215 | return Err(SdkError::response_error(
|
216 216 | ReceiverError {
|
217 217 | kind: ReceiverErrorKind::UnexpectedEndOfStream,
|
218 218 | },
|
219 219 | RawMessage::invalid(Some(buf.copy_to_bytes(buf.remaining()))),
|
220 220 | ));
|
221 221 | }
|
222 222 | Ok(None)
|
223 223 | }
|
224 224 |
|
225 225 | /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
|
226 226 | /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
|
227 227 | #[doc(hidden)]
|
228 228 | pub async fn try_recv_initial(
|
229 229 | &mut self,
|
230 230 | message_type: InitialMessageType,
|
231 231 | ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
|
232 + | self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
|
233 + | .await
|
234 + | .map(|opt| opt.map(|(msg, _)| msg))
|
235 + | }
|
236 + |
|
237 + | /// Tries to receive the initial response message with preprocessing support.
|
238 + | ///
|
239 + | /// The preprocessor function can transform the raw message (e.g., unwrap envelopes)
|
240 + | /// and return metadata alongside the transformed message. If the transformed message
|
241 + | /// matches the expected `message_type`, both the message and metadata are returned.
|
242 + | /// Otherwise, the transformed message is buffered and `Ok(None)` is returned.
|
243 + | #[doc(hidden)]
|
244 + | pub async fn try_recv_initial_with_preprocessor<F, M>(
|
245 + | &mut self,
|
246 + | message_type: InitialMessageType,
|
247 + | preprocessor: F,
|
248 + | ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
|
249 + | where
|
250 + | F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
|
251 + | {
|
232 252 | if let Some(message) = self.next_message().await? {
|
233 - | if let Some(event_type) = message
|
253 + | let (processed_message, metadata) =
|
254 + | preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
|
255 + |
|
256 + | if let Some(event_type) = processed_message
|
234 257 | .headers()
|
235 258 | .iter()
|
236 259 | .find(|h| h.name().as_str() == ":event-type")
|
237 260 | {
|
238 261 | if event_type
|
239 262 | .value()
|
240 263 | .as_string()
|
241 264 | .map(|s| s.as_str() == message_type.as_str())
|
242 265 | .unwrap_or(false)
|
243 266 | {
|
244 - | return Ok(Some(message));
|
267 + | return Ok(Some((processed_message, metadata)));
|
245 268 | }
|
246 269 | }
|
247 - | // Buffer the message so that it can be returned by the next call to `recv()`
|
270 + | // Buffer the processed message so that it can be returned by the next call to `recv()`
|
248 271 | self.buffered_message = Some(message);
|
249 272 | }
|
250 273 | Ok(None)
|
251 274 | }
|
252 275 |
|
253 276 | /// Asynchronously tries to receive a message from the stream. If the stream has ended,
|
254 277 | /// it returns an `Ok(None)`. If there is a transport layer error, it will return
|
255 278 | /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
|
256 279 | /// messages.
|
257 280 | pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
|