1 1 | /*
|
2 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 3 | * SPDX-License-Identifier: Apache-2.0
|
4 4 | */
|
5 5 |
|
6 6 | use aws_smithy_eventstream::frame::{
|
7 7 | DecodedFrame, MessageFrameDecoder, UnmarshallMessage, UnmarshalledMessage,
|
8 8 | };
|
9 - | use aws_smithy_runtime_api::client::result::{ConnectorError, SdkError};
|
9 + | use aws_smithy_runtime_api::client::result::{ConnectorError, ResponseError, SdkError};
|
10 10 | use aws_smithy_types::body::SdkBody;
|
11 11 | use aws_smithy_types::event_stream::{Message, RawMessage};
|
12 12 | use bytes::Buf;
|
13 13 | use bytes::Bytes;
|
14 14 | use bytes_utils::SegmentedBuf;
|
15 15 | use std::error::Error as StdError;
|
16 16 | use std::fmt;
|
17 17 | use std::marker::PhantomData;
|
18 18 | use std::mem;
|
19 19 | use tracing::trace;
|
pub struct Receiver<T, E> {
    /// Converts each decoded `Message` into either an event (`T`) or a modeled error (`E`).
    unmarshaller: Box<dyn UnmarshallMessage<Output = T, Error = E> + Send + Sync>,
    /// Frame decoder, created fresh in `new()`.
    // NOTE(review): presumably consumed by `next_message()`, which is outside the visible
    // part of this file — confirm.
    decoder: MessageFrameDecoder,
    /// Receive-buffer state. Starts as `RecvBuf::Empty`; `recv()` switches it to
    /// `RecvBuf::Terminated` when a message fails to unmarshall.
    buffer: RecvBuf,
    /// The SDK body handed to `new()`.
    // NOTE(review): presumably the source of the raw event stream bytes — confirm.
    body: SdkBody,
    /// Event Stream has optional initial response frames with an `:event-type` of
    /// `initial-response`. If `try_recv_initial()` is called and the next message isn't the
    /// requested initial message, then the message will be stored in `buffered_message` so
    /// that it can be returned by the next call to `recv()`.
    buffered_message: Option<Message>,
    /// Stores initial-request or initial-response messages that were filtered out by `recv()`.
    /// These are only returned when `try_recv_initial()` is explicitly called.
    buffered_initial_message: Option<Message>,
    /// Zero-sized marker for the error type parameter `E`.
    _phantom: PhantomData<E>,
}
|
125 128 |
|
// Used by `Receiver::try_recv_initial`, hence this enum is also doc hidden
/// Selects which kind of initial message `Receiver::try_recv_initial` should look for.
#[doc(hidden)]
#[non_exhaustive]
pub enum InitialMessageType {
    /// Matches messages whose `:event-type` header is `initial-request`.
    Request,
    /// Matches messages whose `:event-type` header is `initial-response`.
    Response,
}
|
133 136 |
|
134 137 | impl InitialMessageType {
|
135 138 | fn as_str(&self) -> &'static str {
|
136 139 | match self {
|
137 140 | InitialMessageType::Request => "initial-request",
|
138 141 | InitialMessageType::Response => "initial-response",
|
139 142 | }
|
140 143 | }
|
141 144 | }
|
142 145 |
|
143 146 | impl<T, E> Receiver<T, E> {
|
144 147 | /// Creates a new `Receiver` with the given message unmarshaller and SDK body.
|
145 148 | pub fn new(
|
146 149 | unmarshaller: impl UnmarshallMessage<Output = T, Error = E> + Send + Sync + 'static,
|
147 150 | body: SdkBody,
|
148 151 | ) -> Self {
|
149 152 | Receiver {
|
150 153 | unmarshaller: Box::new(unmarshaller),
|
151 154 | decoder: MessageFrameDecoder::new(),
|
152 155 | buffer: RecvBuf::Empty,
|
153 156 | body,
|
154 157 | buffered_message: None,
|
158 + | buffered_initial_message: None,
|
155 159 | _phantom: Default::default(),
|
156 160 | }
|
157 161 | }
|
158 162 |
|
163 + | /// Checks if a message is an initial-request or initial-response message.
|
164 + | fn is_initial_message(message: &Message) -> bool {
|
165 + | message
|
166 + | .headers()
|
167 + | .iter()
|
168 + | .find(|h| h.name().as_str() == ":event-type")
|
169 + | .and_then(|h| h.value().as_string().ok())
|
170 + | .map(|s| s.as_str() == "initial-request" || s.as_str() == "initial-response")
|
171 + | .unwrap_or(false)
|
172 + | }
|
173 + |
|
159 174 | fn unmarshall(&self, message: Message) -> Result<Option<T>, SdkError<E, RawMessage>> {
|
160 175 | match self.unmarshaller.unmarshall(&message) {
|
161 176 | Ok(unmarshalled) => match unmarshalled {
|
162 177 | UnmarshalledMessage::Event(event) => Ok(Some(event)),
|
163 178 | UnmarshalledMessage::Error(err) => {
|
164 179 | Err(SdkError::service_error(err, RawMessage::Decoded(message)))
|
165 180 | }
|
166 181 | },
|
167 182 | Err(err) => Err(SdkError::response_error(err, RawMessage::Decoded(message))),
|
168 183 | }
|
222 237 | Ok(None)
|
223 238 | }
|
224 239 |
|
225 240 | /// Tries to receive the initial response message that has `:event-type` of a given `message_type`.
|
226 241 | /// If a different event type is received, then it is buffered and `Ok(None)` is returned.
|
227 242 | #[doc(hidden)]
|
228 243 | pub async fn try_recv_initial(
|
229 244 | &mut self,
|
230 245 | message_type: InitialMessageType,
|
231 246 | ) -> Result<Option<Message>, SdkError<E, RawMessage>> {
|
232 - | if let Some(message) = self.next_message().await? {
|
233 - | if let Some(event_type) = message
|
247 + | self.try_recv_initial_with_preprocessor(message_type, |msg| Ok((msg, ())))
|
248 + | .await
|
249 + | .map(|opt| opt.map(|(msg, _)| msg))
|
250 + | }
|
251 + |
|
252 + | /// Tries to receive the initial response message with preprocessing support.
|
253 + | ///
|
254 + | /// The preprocessor function can transform the raw message (e.g., unwrap envelopes)
|
255 + | /// and return metadata alongside the transformed message. If the transformed message
|
256 + | /// matches the expected `message_type`, both the message and metadata are returned.
|
257 + | /// Otherwise, the transformed message is buffered and `Ok(None)` is returned.
|
258 + | ///
|
259 + | /// This method will block waiting for a message if no initial message has been buffered yet.
|
260 + | #[doc(hidden)]
|
261 + | pub async fn try_recv_initial_with_preprocessor<F, M>(
|
262 + | &mut self,
|
263 + | message_type: InitialMessageType,
|
264 + | preprocessor: F,
|
265 + | ) -> Result<Option<(Message, M)>, SdkError<E, RawMessage>>
|
266 + | where
|
267 + | F: FnOnce(Message) -> Result<(Message, M), ResponseError<RawMessage>>,
|
268 + | {
|
269 + | // Check if we already have a buffered initial message from recv()
|
270 + | let message = if let Some(buffered) = self.buffered_initial_message.take() {
|
271 + | buffered
|
272 + | } else {
|
273 + | // Otherwise, block waiting for the next message
|
274 + | match self.next_message().await? {
|
275 + | Some(msg) => msg,
|
276 + | None => return Ok(None),
|
277 + | }
|
278 + | };
|
279 + |
|
280 + | let (processed_message, metadata) =
|
281 + | preprocessor(message.clone()).map_err(|err| SdkError::ResponseError(err))?;
|
282 + |
|
283 + | if let Some(event_type) = processed_message
|
234 284 | .headers()
|
235 285 | .iter()
|
236 286 | .find(|h| h.name().as_str() == ":event-type")
|
237 287 | {
|
238 288 | if event_type
|
239 289 | .value()
|
240 290 | .as_string()
|
291 + | .ok()
|
241 292 | .map(|s| s.as_str() == message_type.as_str())
|
242 293 | .unwrap_or(false)
|
243 294 | {
|
244 - | return Ok(Some(message));
|
295 + | return Ok(Some((processed_message, metadata)));
|
245 296 | }
|
246 297 | }
|
247 - | // Buffer the message so that it can be returned by the next call to `recv()`
|
298 + | // Buffer the original message so that it can be returned by the next call to `recv()`
|
248 299 | self.buffered_message = Some(message);
|
249 - | }
|
250 300 | Ok(None)
|
251 301 | }
|
252 302 |
|
253 303 | /// Asynchronously tries to receive a message from the stream. If the stream has ended,
|
254 304 | /// it returns an `Ok(None)`. If there is a transport layer error, it will return
|
255 305 | /// `Err(SdkError::DispatchFailure)`. Service-modeled errors will be a part of the returned
|
256 306 | /// messages.
|
307 + | ///
|
308 + | /// This method filters out initial-request and initial-response messages. Those messages
|
309 + | /// are only returned when `try_recv_initial()` is explicitly called.
|
257 310 | pub async fn recv(&mut self) -> Result<Option<T>, SdkError<E, RawMessage>> {
|
311 + | loop {
|
258 312 | if let Some(buffered) = self.buffered_message.take() {
|
259 313 | return match self.unmarshall(buffered) {
|
260 314 | Ok(message) => Ok(message),
|
261 315 | Err(error) => {
|
262 316 | self.buffer = RecvBuf::Terminated;
|
263 317 | Err(error)
|
264 318 | }
|
265 319 | };
|
266 320 | }
|
267 321 | if let Some(message) = self.next_message().await? {
|
322 + | // Filter out initial messages - store them separately
|
323 + | if Self::is_initial_message(&message) {
|
324 + | self.buffered_initial_message = Some(message);
|
325 + | continue;
|
326 + | }
|
268 327 | match self.unmarshall(message) {
|
269 - | Ok(message) => Ok(message),
|
328 + | Ok(message) => return Ok(message),
|
270 329 | Err(error) => {
|
271 330 | self.buffer = RecvBuf::Terminated;
|
272 - | Err(error)
|
331 + | return Err(error);
|
273 332 | }
|
274 333 | }
|
275 334 | } else {
|
276 - | Ok(None)
|
335 + | return Ok(None);
|
336 + | }
|
277 337 | }
|
278 338 | }
|
279 339 | }
|
280 340 |
|
281 341 | #[cfg(test)]
|
282 342 | mod tests {
|
283 343 | use super::{InitialMessageType, Receiver, UnmarshallMessage};
|
284 344 | use aws_smithy_eventstream::error::Error as EventStreamError;
|
285 345 | use aws_smithy_eventstream::frame::{write_message_to, UnmarshalledMessage};
|
286 346 | use aws_smithy_runtime_api::client::result::SdkError;
|