aws_smithy_eventstream/
test_util.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use crate::frame::{read_message_from, DecodedFrame, MessageFrameDecoder};
7use aws_smithy_types::event_stream::{HeaderValue, Message};
8use std::collections::{BTreeMap, BTreeSet};
9use std::error::Error as StdError;
10
11/// Validate that the bodies match, which includes headers and messages
12///
13/// When `full_stream` is true, it also verifies the length of frames
14pub fn validate_body(
15    expected_body: &[u8],
16    actual_body: &[u8],
17    full_stream: bool,
18) -> Result<(), Box<dyn StdError>> {
19    let expected_frames = decode_frames(expected_body);
20    let actual_frames = decode_frames(actual_body);
21
22    if full_stream {
23        assert_eq!(
24            expected_frames.len(),
25            actual_frames.len(),
26            "Frame count didn't match.\n\
27        Expected: {:?}\n\
28        Actual:   {:?}",
29            expected_frames,
30            actual_frames
31        );
32    }
33
34    for ((expected_wrapper, expected_message), (actual_wrapper, actual_message)) in
35        expected_frames.into_iter().zip(actual_frames.into_iter())
36    {
37        assert_eq!(
38            header_names(&expected_wrapper),
39            header_names(&actual_wrapper)
40        );
41        if let Some(expected_message) = expected_message {
42            let actual_message = actual_message.unwrap();
43            assert_eq!(header_map(&expected_message), header_map(&actual_message));
44            assert_eq!(expected_message.payload(), actual_message.payload());
45        }
46    }
47    Ok(())
48}
49
50// Returned tuples are (SignedWrapperMessage, WrappedMessage).
51// Some signed messages don't have payloads, so in those cases, the wrapped message will be None.
52fn decode_frames(mut body: &[u8]) -> Vec<(Message, Option<Message>)> {
53    let mut result = Vec::new();
54    let mut decoder = MessageFrameDecoder::new();
55    while let DecodedFrame::Complete(msg) = decoder.decode_frame(&mut body).unwrap() {
56        let inner_msg = if msg.payload().is_empty() {
57            None
58        } else {
59            Some(read_message_from(msg.payload().as_ref()).unwrap())
60        };
61        result.push((msg, inner_msg));
62    }
63    result
64}
65
66fn header_names(msg: &Message) -> BTreeSet<String> {
67    msg.headers()
68        .iter()
69        .map(|h| h.name().as_str().into())
70        .collect()
71}
72fn header_map(msg: &Message) -> BTreeMap<String, &HeaderValue> {
73    msg.headers()
74        .iter()
75        .map(|h| (h.name().as_str().to_string(), h.value()))
76        .collect()
77}