aws_smithy_cbor/
decode.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::borrow::Cow;
7
8use aws_smithy_types::{Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13/// Provides functions for decoding a CBOR object with a known schema.
14///
15/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema
16/// is known in advance. Therefore, the caller can determine which object key exists at the current
17/// position by calling `str` method, and call the relevant function based on the predetermined schema
18/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip
19/// over the element.
20#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22    decoder: minicbor::Decoder<'b>,
23}
24
25/// When any of the decode methods are called they look for that particular data type at the current
26/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned.
27#[derive(Debug)]
28pub struct DeserializeError {
29    #[allow(dead_code)]
30    _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        self._inner.fmt(f)
36    }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42    pub(crate) fn new(inner: Error) -> Self {
43        Self { _inner: inner }
44    }
45
46    /// More than one union variant was detected: `unexpected_type` was unexpected.
47    pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48        Self {
49            _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50                .with_message("encountered unexpected union variant; expected end of union")
51                .at(at),
52        }
53    }
54
55    /// Unknown union variant was detected. Servers reject unknown union varaints.
56    pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57        Self {
58            _inner: Error::message(format!(
59                "encountered unknown union variant {}",
60                variant_name
61            ))
62            .at(at),
63        }
64    }
65
66    /// More than one union variant was detected, but we never even got to parse the first one.
67    /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR
68    /// map whose length (specified upfront) is a value different than 1.
69    pub fn mixed_union_variants(at: usize) -> Self {
70        Self {
71            _inner: Error::message(
72                "encountered mixed variants in union; expected a single union variant to be set",
73            )
74            .at(at),
75        }
76    }
77
78    /// Expected end of stream but more data is available.
79    pub fn expected_end_of_stream(at: usize) -> Self {
80        Self {
81            _inner: Error::message("encountered additional data; expected end of stream").at(at),
82        }
83    }
84
85    /// Returns a custom error with an offset.
86    pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
87        Self {
88            _inner: Error::message(message.into()).at(at),
89        }
90    }
91
92    /// An unexpected type was encountered.
93    // We handle this one when decoding sparse collections: we have to expect either a `null` or an
94    // item, so we try decoding both.
95    pub fn is_type_mismatch(&self) -> bool {
96        self._inner.is_type_mismatch()
97    }
98}
99
/// Generates thin wrapper methods that delegate to the `decoder` field.
///
/// Each `wrapper => inner(ReturnType);` entry expands to a public `&mut self` method named
/// `wrapper` that calls `self.decoder.inner()` and converts a failure with
/// `DeserializeError::new`. Attributes written in front of an entry, including `///` doc
/// comments, are forwarded onto the generated method so the wrappers stay documented.
///
/// # Example
///
/// ```ignore
/// delegate_method! {
///     /// Reads a boolean at the current position.
///     boolean => bool(bool);
///     /// Reads an integer at the current position.
///     integer => i32(i32);
/// }
/// ```
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $decoder_name:ident($result_type:ty);)+) => {
        $(
            // Re-emit the captured attributes; without this the doc comments are dropped.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$decoder_name().map_err(DeserializeError::new)
            }
        )+
    };
}
124
125impl<'b> Decoder<'b> {
126    pub fn new(bytes: &'b [u8]) -> Self {
127        Self {
128            decoder: minicbor::Decoder::new(bytes),
129        }
130    }
131
132    pub fn datatype(&self) -> Result<Type, DeserializeError> {
133        self.decoder
134            .datatype()
135            .map(Type::new)
136            .map_err(DeserializeError::new)
137    }
138
139    delegate_method! {
140        /// Skips the current CBOR element.
141        skip => skip(());
142        /// Reads a boolean at the current position.
143        boolean => bool(bool);
144        /// Reads a byte at the current position.
145        byte => i8(i8);
146        /// Reads a short at the current position.
147        short => i16(i16);
148        /// Reads a integer at the current position.
149        integer => i32(i32);
150        /// Reads a long at the current position.
151        long => i64(i64);
152        /// Reads a float at the current position.
153        float => f32(f32);
154        /// Reads a double at the current position.
155        double => f64(f64);
156        /// Reads a null CBOR element at the current position.
157        null => null(());
158        /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`.
159        list => array(Option<u64>);
160        /// Returns the number of elements in a definite map. For indefinite map it returns a `None`.
161        map => map(Option<u64>);
162    }
163
164    /// Returns the current position of the buffer, which will be decoded when any of the methods is called.
165    pub fn position(&self) -> usize {
166        self.decoder.position()
167    }
168
169    /// Set the current decode position.
170    pub fn set_position(&mut self, pos: usize) {
171        self.decoder.set_position(pos)
172    }
173
174    /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite
175    /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an
176    /// indefinite-length string. An error is returned if the element is neither a definite length nor an
177    /// indefinite-length string.
178    pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
179        let bookmark = self.decoder.position();
180        match self.decoder.str() {
181            Ok(str_value) => Ok(Cow::Borrowed(str_value)),
182            Err(e) if e.is_type_mismatch() => {
183                // Move the position back to the start of the CBOR element and then try
184                // decoding it as an indefinite length string.
185                self.decoder.set_position(bookmark);
186                Ok(Cow::Owned(self.string()?))
187            }
188            Err(e) => Err(DeserializeError::new(e)),
189        }
190    }
191
192    /// Allocates and returns a `String` if the element at the current position in the buffer is either a
193    /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type.
194    pub fn string(&mut self) -> Result<String, DeserializeError> {
195        let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
196        let head = iter.next();
197
198        let decoded_string = match head {
199            None => String::new(),
200            Some(head) => {
201                let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
202                for chunk in iter {
203                    combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
204                }
205                combined_chunks
206            }
207        };
208
209        Ok(decoded_string)
210    }
211
212    /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise,
213    /// a `DeserializeError` error is returned.
214    pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
215        let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
216        let parts: Vec<&[u8]> = iter
217            .collect::<Result<_, _>>()
218            .map_err(DeserializeError::new)?;
219
220        Ok(if parts.len() == 1 {
221            Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part.
222        } else {
223            Blob::new(parts.concat()) // Concatenate all parts into a single Blob.
224        })
225    }
226
227    /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise,
228    /// a `DeserializeError` error is returned.
229    pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
230        let tag = self.decoder.tag().map_err(DeserializeError::new)?;
231        let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
232
233        if tag != timestamp_tag {
234            Err(DeserializeError::new(Error::message(
235                "expected timestamp tag",
236            )))
237        } else {
238            // Values that are more granular than millisecond precision SHOULD be truncated to fit
239            // millisecond precision for epoch-seconds:
240            // https://smithy.io/2.0/spec/protocol-traits.html#timestamp-formats
241            //
242            // Without truncation, the `RpcV2CborDateTimeWithFractionalSeconds` protocol test would
243            // fail since the upstream test expect `123000000` in subsec but the decoded actual
244            // subsec would be `123000025`.
245            // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/fractional-seconds.smithy#L17
246            let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
247            let mut result = DateTime::from_secs_f64(epoch_seconds);
248            let subsec_nanos = result.subsec_nanos();
249            result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
250            Ok(result)
251        }
252    }
253}
254
255#[allow(dead_code)] // to avoid `never constructed` warning
256#[derive(Debug)]
257pub struct ArrayIter<'a, 'b, T> {
258    inner: minicbor::decode::ArrayIter<'a, 'b, T>,
259}
260
261impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> {
262    type Item = Result<T, DeserializeError>;
263
264    fn next(&mut self) -> Option<Self::Item> {
265        self.inner
266            .next()
267            .map(|opt| opt.map_err(DeserializeError::new))
268    }
269}
270
271#[allow(dead_code)] // to avoid `never constructed` warning
272#[derive(Debug)]
273pub struct MapIter<'a, 'b, K, V> {
274    inner: minicbor::decode::MapIter<'a, 'b, K, V>,
275}
276
277impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V>
278where
279    K: minicbor::Decode<'b, ()>,
280    V: minicbor::Decode<'b, ()>,
281{
282    type Item = Result<(K, V), DeserializeError>;
283
284    fn next(&mut self) -> Option<Self::Item> {
285        self.inner
286            .next()
287            .map(|opt| opt.map_err(DeserializeError::new))
288    }
289}
290
291pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
292where
293    F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
294{
295    match decoder.datatype()? {
296        crate::data::Type::Null => {
297            decoder.null()?;
298            Ok(builder)
299        }
300        _ => f(builder, decoder),
301    }
302}
303
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use aws_smithy_types::date_time::{DateTime, Format};
    use aws_smithy_types::Blob;

    use crate::Decoder;

    /// A definite-length text string must be returned without allocating.
    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // Definite length key `thisIsAKey`.
        let input = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let mut decoder = Decoder::new(&input);
        let key = decoder.str().expect("could not decode str");
        assert_eq!(key, "thisIsAKey");
        assert!(matches!(key, Cow::Borrowed(_)));
    }

    /// An indefinite-length text string has to be stitched together, hence owned.
    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // Indefinite length key `this`, `Is`, `A` and `Key`.
        let input = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let mut decoder = Decoder::new(&input);
        let key = decoder.str().expect("could not decode str");
        assert_eq!(key, "thisIsAKey");
        assert!(matches!(key, Cow::Owned(_)));
    }

    /// A zero-length text string decodes to the empty string.
    #[test]
    fn test_empty_str_works() {
        let input = [0x60];
        let mut decoder = Decoder::new(&input);
        let decoded = decoder.str().expect("could not decode empty str");
        assert_eq!(decoded, "");
    }

    /// A zero-length byte string decodes to an empty blob.
    #[test]
    fn test_empty_blob_works() {
        let input = [0x40];
        let mut decoder = Decoder::new(&input);
        let decoded = decoder.blob().expect("could not decode an empty blob");
        assert_eq!(decoded, Blob::new(&[]));
    }

    /// The chunks of an indefinite-length byte string are concatenated in order.
    #[test]
    fn test_indefinite_length_blob() {
        // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`.
        // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff
        let input = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let expected = Blob::new("indefinite-byte, chunked, on each comma".as_bytes());

        let mut decoder = Decoder::new(&input);
        let decoded = decoder.blob().expect("could not decode blob");
        assert_eq!(decoded, expected);
    }

    /// Sub-millisecond digits of a floating-point timestamp are dropped.
    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // Input bytes are derived from the `RpcV2CborDateTimeWithFractionalSeconds` protocol test,
        // extracting portion representing a timestamp value.
        let input = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let expected = DateTime::from_str("2000-01-02T20:34:56.123Z", Format::DateTime).unwrap();

        let mut decoder = Decoder::new(&input);
        let decoded = decoder.timestamp().expect("should decode timestamp");
        assert_eq!(decoded, expected);
    }
}