1use std::borrow::Cow;
7
8use aws_smithy_types::{Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22    decoder: minicbor::Decoder<'b>,
23}
24
25#[derive(Debug)]
28pub struct DeserializeError {
29    #[allow(dead_code)]
30    _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        self._inner.fmt(f)
36    }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42    pub(crate) fn new(inner: Error) -> Self {
43        Self { _inner: inner }
44    }
45
46    pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48        Self {
49            _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50                .with_message("encountered unexpected union variant; expected end of union")
51                .at(at),
52        }
53    }
54
55    pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57        Self {
58            _inner: Error::message(format!(
59                "encountered unknown union variant {}",
60                variant_name
61            ))
62            .at(at),
63        }
64    }
65
66    pub fn mixed_union_variants(at: usize) -> Self {
70        Self {
71            _inner: Error::message(
72                "encountered mixed variants in union; expected a single union variant to be set",
73            )
74            .at(at),
75        }
76    }
77
78    pub fn expected_end_of_stream(at: usize) -> Self {
80        Self {
81            _inner: Error::message("encountered additional data; expected end of stream").at(at),
82        }
83    }
84
85    pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
87        Self {
88            _inner: Error::message(message.into()).at(at),
89        }
90    }
91
92    pub fn is_type_mismatch(&self) -> bool {
96        self._inner.is_type_mismatch()
97    }
98}
99
/// Generates `pub fn` wrappers that forward to a method on `self.decoder`.
///
/// Each entry `$(#[attr])* wrapper => decoder_method(Ret);` expands to
/// `pub fn wrapper(&mut self) -> Result<Ret, DeserializeError>`, which calls
/// `self.decoder.decoder_method()` and maps any error through `DeserializeError::new`.
///
/// Attributes written before an entry, including `///` doc comments and `#[cfg(..)]`, are
/// forwarded onto the generated method. Previously they were matched but discarded.
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $decoder_method:ident($result_type:ty);)+) => {
        $(
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$decoder_method().map_err(DeserializeError::new)
            }
        )+
    };
}
124
125impl<'b> Decoder<'b> {
126    pub fn new(bytes: &'b [u8]) -> Self {
127        Self {
128            decoder: minicbor::Decoder::new(bytes),
129        }
130    }
131
132    pub fn datatype(&self) -> Result<Type, DeserializeError> {
133        self.decoder
134            .datatype()
135            .map(Type::new)
136            .map_err(DeserializeError::new)
137    }
138
139    delegate_method! {
140        skip => skip(());
142        boolean => bool(bool);
144        byte => i8(i8);
146        short => i16(i16);
148        integer => i32(i32);
150        long => i64(i64);
152        float => f32(f32);
154        double => f64(f64);
156        null => null(());
158        list => array(Option<u64>);
160        map => map(Option<u64>);
162    }
163
164    pub fn position(&self) -> usize {
166        self.decoder.position()
167    }
168
169    pub fn set_position(&mut self, pos: usize) {
171        self.decoder.set_position(pos)
172    }
173
174    pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
179        let bookmark = self.decoder.position();
180        match self.decoder.str() {
181            Ok(str_value) => Ok(Cow::Borrowed(str_value)),
182            Err(e) if e.is_type_mismatch() => {
183                self.decoder.set_position(bookmark);
186                Ok(Cow::Owned(self.string()?))
187            }
188            Err(e) => Err(DeserializeError::new(e)),
189        }
190    }
191
192    pub fn string(&mut self) -> Result<String, DeserializeError> {
195        let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
196        let head = iter.next();
197
198        let decoded_string = match head {
199            None => String::new(),
200            Some(head) => {
201                let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
202                for chunk in iter {
203                    combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
204                }
205                combined_chunks
206            }
207        };
208
209        Ok(decoded_string)
210    }
211
212    pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
215        let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
216        let parts: Vec<&[u8]> = iter
217            .collect::<Result<_, _>>()
218            .map_err(DeserializeError::new)?;
219
220        Ok(if parts.len() == 1 {
221            Blob::new(parts[0]) } else {
223            Blob::new(parts.concat()) })
225    }
226
227    pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
230        let tag = self.decoder.tag().map_err(DeserializeError::new)?;
231        let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
232
233        if tag != timestamp_tag {
234            Err(DeserializeError::new(Error::message(
235                "expected timestamp tag",
236            )))
237        } else {
238            let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
247            let mut result = DateTime::from_secs_f64(epoch_seconds);
248            let subsec_nanos = result.subsec_nanos();
249            result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
250            Ok(result)
251        }
252    }
253}
254
255#[allow(dead_code)] #[derive(Debug)]
257pub struct ArrayIter<'a, 'b, T> {
258    inner: minicbor::decode::ArrayIter<'a, 'b, T>,
259}
260
261impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
262    type Item = Result<T, DeserializeError>;
263
264    fn next(&mut self) -> Option<Self::Item> {
265        self.inner
266            .next()
267            .map(|opt| opt.map_err(DeserializeError::new))
268    }
269}
270
271#[allow(dead_code)] #[derive(Debug)]
273pub struct MapIter<'a, 'b, K, V> {
274    inner: minicbor::decode::MapIter<'a, 'b, K, V>,
275}
276
277impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
278where
279    K: minicbor::Decode<'b, ()>,
280    V: minicbor::Decode<'b, ()>,
281{
282    type Item = Result<(K, V), DeserializeError>;
283
284    fn next(&mut self) -> Option<Self::Item> {
285        self.inner
286            .next()
287            .map(|opt| opt.map_err(DeserializeError::new))
288    }
289}
290
291pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
292where
293    F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
294{
295    match decoder.datatype()? {
296        crate::data::Type::Null => {
297            decoder.null()?;
298            Ok(builder)
299        }
300        _ => f(builder, decoder),
301    }
302}
303
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use aws_smithy_types::date_time::{DateTime, Format};
    use aws_smithy_types::Blob;

    use crate::Decoder;

    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // 0x6a: text string (major type 3) of definite length 10, then "thisIsAKey".
        let input = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let decoded = Decoder::new(&input)
            .str()
            .expect("could not decode str");
        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Borrowed(_)));
    }

    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // 0x7f opens an indefinite-length text string with chunks "this" (0x64),
        // "Is" (0x62), "A" (0x61) and "Key" (0x63); 0xff closes it.
        let input = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let decoded = Decoder::new(&input)
            .str()
            .expect("could not decode str");
        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Owned(_)));
    }

    #[test]
    fn test_empty_str_works() {
        // 0x60: text string of definite length 0.
        let input = [0x60];
        let decoded = Decoder::new(&input)
            .str()
            .expect("could not decode empty str");
        assert_eq!(decoded, "");
    }

    #[test]
    fn test_empty_blob_works() {
        // 0x40: byte string of definite length 0.
        let input = [0x40];
        let decoded = Decoder::new(&input)
            .blob()
            .expect("could not decode an empty blob");
        assert_eq!(decoded, Blob::new([]));
    }

    #[test]
    fn test_indefinite_length_blob() {
        // 0x5f opens an indefinite-length byte string made of three chunks:
        // "indefinite-byte," (0x50, 16 bytes), " chunked," (0x49, 9 bytes) and
        // " on each comma" (0x4e, 14 bytes); 0xff closes it.
        let input = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let expected = Blob::new("indefinite-byte, chunked, on each comma".as_bytes());
        let actual = Decoder::new(&input).blob().expect("could not decode blob");
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // 0xc1: tag 1 (epoch-based date/time). 0xfb: a float64 holding epoch seconds for
        // 2000-01-02T20:34:56.123Z, presumably with sub-millisecond float noise that the
        // decoder must truncate away. The trailing 0xff is never read by `timestamp()`.
        let input = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let expected = DateTime::from_str("2000-01-02T20:34:56.123Z", Format::DateTime).unwrap();
        let actual = Decoder::new(&input)
            .timestamp()
            .expect("should decode timestamp");
        assert_eq!(actual, expected);
    }
}