1use crate::error::{Error, ErrorKind};
7use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
8use aws_smithy_types::str_bytes::StrBytes;
9use aws_smithy_types::{Blob, DateTime};
10
/// Generates a public `fn $fn_name(header: &Header) -> Result<$result_typ, Error>`.
///
/// The generated function succeeds when the header's value is the
/// `HeaderValue::$val_typ` variant: the variant's payload is bound to
/// `$val_name` and `$val_expr` is evaluated to produce the result. For any
/// other variant it returns an `ErrorKind::Unmarshalling` error that names the
/// header and the variant that was expected.
macro_rules! expect_shape_fn {
    (fn $fn_name:ident[$val_typ:ident] -> $result_typ:ident { $val_name:ident -> $val_expr:expr }) => {
        #[doc = concat!("Expects that `header` is a `", stringify!($result_typ), "`.")]
        pub fn $fn_name(header: &Header) -> Result<$result_typ, Error> {
            // Happy path first: the header holds the requested variant.
            if let HeaderValue::$val_typ($val_name) = header.value() {
                return Ok($val_expr);
            }
            Err(Error::from(ErrorKind::Unmarshalling(format!(
                "expected '{}' header value to be {}",
                header.name().as_str(),
                stringify!($val_typ)
            ))))
        }
    };
}
29
// One `expect_*` function per header value variant handled here. Each returns
// the unwrapped value, or an `Unmarshalling` error if the header holds a
// different variant. Values taken with `*value` are copied out of the header;
// the string variant is converted into an owned `String`, and the byte array
// is wrapped in a new `Blob` (presumably copying the bytes — confirm against
// `Blob::new`).
expect_shape_fn!(fn expect_bool[Bool] -> bool { value -> *value });
expect_shape_fn!(fn expect_byte[Byte] -> i8 { value -> *value });
expect_shape_fn!(fn expect_int16[Int16] -> i16 { value -> *value });
expect_shape_fn!(fn expect_int32[Int32] -> i32 { value -> *value });
expect_shape_fn!(fn expect_int64[Int64] -> i64 { value -> *value });
expect_shape_fn!(fn expect_byte_array[ByteArray] -> Blob { bytes -> Blob::new(bytes.as_ref()) });
expect_shape_fn!(fn expect_string[String] -> String { value -> value.as_str().into() });
expect_shape_fn!(fn expect_timestamp[Timestamp] -> DateTime { value -> *value });
38
/// String-valued headers extracted from a response message, borrowed from
/// that message for the lifetime `'a`.
///
/// Instances are produced by `parse_response_headers`.
#[derive(Debug)]
pub struct ResponseHeaders<'a> {
    /// Value of the `:content-type` header, or `None` when the message did
    /// not include one.
    pub content_type: Option<&'a StrBytes>,

    /// Value of the `:message-type` header. `parse_response_headers` only
    /// succeeds when this is `event` or `exception`.
    pub message_type: &'a StrBytes,

    /// Value of the `:event-type` header when `message_type` is `event`, or
    /// of the `:exception-type` header when it is `exception`.
    pub smithy_type: &'a StrBytes,
}
62
63impl ResponseHeaders<'_> {
64    pub fn content_type(&self) -> Option<&str> {
66        self.content_type.map(|ct| ct.as_str())
67    }
68}
69
70fn expect_header_str_value<'a>(
71    header: Option<&'a Header>,
72    name: &str,
73) -> Result<&'a StrBytes, Error> {
74    match header {
75        Some(header) => Ok(header.value().as_string().map_err(|value| {
76            Error::from(ErrorKind::Unmarshalling(format!(
77                "expected response {} header to be string, received {:?}",
78                name, value
79            )))
80        })?),
81        None => Err(ErrorKind::Unmarshalling(format!(
82            "expected response to include {} header, but it was missing",
83            name
84        ))
85        .into()),
86    }
87}
88
89pub fn parse_response_headers(message: &Message) -> Result<ResponseHeaders<'_>, Error> {
94    let (mut content_type, mut message_type, mut event_type, mut exception_type) =
95        (None, None, None, None);
96    for header in message.headers() {
97        match header.name().as_str() {
98            ":content-type" => content_type = Some(header),
99            ":message-type" => message_type = Some(header),
100            ":event-type" => event_type = Some(header),
101            ":exception-type" => exception_type = Some(header),
102            _ => {}
103        }
104    }
105    let message_type = expect_header_str_value(message_type, ":message-type")?;
106    Ok(ResponseHeaders {
107        content_type: content_type
108            .map(|ct| expect_header_str_value(Some(ct), ":content-type"))
109            .transpose()?,
110        message_type,
111        smithy_type: if message_type.as_str() == "event" {
112            expect_header_str_value(event_type, ":event-type")?
113        } else if message_type.as_str() == "exception" {
114            expect_header_str_value(exception_type, ":exception-type")?
115        } else {
116            return Err(ErrorKind::Unmarshalling(format!(
117                "unrecognized `:message-type`: {}",
118                message_type.as_str()
119            ))
120            .into());
121        },
122    })
123}
124
#[cfg(test)]
mod tests {
    use super::parse_response_headers;
    use aws_smithy_types::event_stream::{Header, HeaderValue, Message};

    /// Builds a header named `name` whose value is the string `value`.
    fn string_header(name: &'static str, value: &'static str) -> Header {
        Header::new(name, HeaderValue::String(value.into()))
    }

    /// Builds a message with the payload `test` and `headers` added in order.
    fn message_with(headers: Vec<Header>) -> Message {
        let mut message = Message::new(&b"test"[..]);
        for header in headers {
            message = message.add_header(header);
        }
        message
    }

    /// Parses `message`, which is expected to be rejected, and returns the
    /// rendered error text.
    fn parse_failure(message: &Message) -> String {
        parse_response_headers(message).err().unwrap().to_string()
    }

    #[test]
    fn normal_message() {
        let message = message_with(vec![
            string_header(":event-type", "Foo"),
            string_header(":content-type", "application/json"),
            string_header(":message-type", "event"),
        ]);

        let headers = parse_response_headers(&message).unwrap();

        assert_eq!("Foo", headers.smithy_type.as_str());
        assert_eq!(Some("application/json"), headers.content_type());
        assert_eq!("event", headers.message_type.as_str());
    }

    #[test]
    fn error_message() {
        let message = message_with(vec![
            string_header(":exception-type", "BadRequestException"),
            string_header(":content-type", "application/json"),
            string_header(":message-type", "exception"),
        ]);

        let headers = parse_response_headers(&message).unwrap();

        assert_eq!("BadRequestException", headers.smithy_type.as_str());
        assert_eq!(Some("application/json"), headers.content_type());
        assert_eq!("exception", headers.message_type.as_str());
    }

    #[test]
    fn missing_exception_type() {
        let message = message_with(vec![
            string_header(":content-type", "application/json"),
            string_header(":message-type", "exception"),
        ]);

        assert_eq!(
            "failed to unmarshall message: expected response to include :exception-type \
             header, but it was missing",
            parse_failure(&message)
        );
    }

    #[test]
    fn missing_event_type() {
        let message = message_with(vec![
            string_header(":content-type", "application/json"),
            string_header(":message-type", "event"),
        ]);

        assert_eq!(
            "failed to unmarshall message: expected response to include :event-type \
             header, but it was missing",
            parse_failure(&message)
        );
    }

    #[test]
    fn missing_content_type() {
        let message = message_with(vec![
            string_header(":event-type", "Foo"),
            string_header(":message-type", "event"),
        ]);

        let headers = parse_response_headers(&message).ok().unwrap();

        assert_eq!(None, headers.content_type);
        assert_eq!("Foo", headers.smithy_type.as_str());
        assert_eq!("event", headers.message_type.as_str());
    }

    #[test]
    fn content_type_wrong_type() {
        let message = message_with(vec![
            string_header(":event-type", "Foo"),
            Header::new(":content-type", HeaderValue::Int32(16)),
            string_header(":message-type", "event"),
        ]);

        assert_eq!(
            "failed to unmarshall message: expected response :content-type \
             header to be string, received Int32(16)",
            parse_failure(&message)
        );
    }
}