aws_smithy_cbor/
decode.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::borrow::Cow;
7
8use aws_smithy_types::{Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13/// Provides functions for decoding a CBOR object with a known schema.
14///
15/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema
16/// is known in advance. Therefore, the caller can determine which object key exists at the current
17/// position by calling `str` method, and call the relevant function based on the predetermined schema
18/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip
19/// over the element.
20#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22    decoder: minicbor::Decoder<'b>,
23}
24
25/// When any of the decode methods are called they look for that particular data type at the current
26/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned.
27#[derive(Debug)]
28pub struct DeserializeError {
29    #[allow(dead_code)]
30    _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        self._inner.fmt(f)
36    }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42    pub(crate) fn new(inner: Error) -> Self {
43        Self { _inner: inner }
44    }
45
46    /// More than one union variant was detected: `unexpected_type` was unexpected.
47    pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48        Self {
49            _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50                .with_message("encountered unexpected union variant; expected end of union")
51                .at(at),
52        }
53    }
54
55    /// Unknown union variant was detected. Servers reject unknown union varaints.
56    pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57        Self {
58            _inner: Error::message(format!("encountered unknown union variant {variant_name}"))
59                .at(at),
60        }
61    }
62
63    /// More than one union variant was detected, but we never even got to parse the first one.
64    /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR
65    /// map whose length (specified upfront) is a value different than 1.
66    pub fn mixed_union_variants(at: usize) -> Self {
67        Self {
68            _inner: Error::message(
69                "encountered mixed variants in union; expected a single union variant to be set",
70            )
71            .at(at),
72        }
73    }
74
75    /// Expected end of stream but more data is available.
76    pub fn expected_end_of_stream(at: usize) -> Self {
77        Self {
78            _inner: Error::message("encountered additional data; expected end of stream").at(at),
79        }
80    }
81
82    /// Returns a custom error with an offset.
83    pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
84        Self {
85            _inner: Error::message(message.into()).at(at),
86        }
87    }
88
89    /// An unexpected type was encountered.
90    // We handle this one when decoding sparse collections: we have to expect either a `null` or an
91    // item, so we try decoding both.
92    pub fn is_type_mismatch(&self) -> bool {
93        self._inner.is_type_mismatch()
94    }
95}
96
/// Generates thin wrapper methods that forward to the inner decoder.
///
/// Each `wrapper => inner(ReturnType);` entry expands to
/// `pub fn wrapper(&mut self) -> Result<ReturnType, DeserializeError>`, which calls
/// `self.decoder.inner()` and converts the error with `DeserializeError::new`.
///
/// Attributes written in front of an entry, including `///` doc comments, are forwarded to the
/// generated method so that the wrappers show up documented in rustdoc.
///
/// # Example
///
/// ```ignore
/// delegate_method! {
///     /// Reads a boolean at the current position.
///     boolean => bool(bool);
///     /// Reads a long at the current position.
///     long => i64(i64);
/// }
/// ```
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $decoder_name:ident($result_type:ty);)+) => {
        $(
            // Re-emit the captured attributes; without this the doc comments are discarded.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$decoder_name().map_err(DeserializeError::new)
            }
        )+
    };
}
121
122impl<'b> Decoder<'b> {
123    pub fn new(bytes: &'b [u8]) -> Self {
124        Self {
125            decoder: minicbor::Decoder::new(bytes),
126        }
127    }
128
129    pub fn datatype(&self) -> Result<Type, DeserializeError> {
130        self.decoder
131            .datatype()
132            .map(Type::new)
133            .map_err(DeserializeError::new)
134    }
135
136    delegate_method! {
137        /// Skips the current CBOR element.
138        skip => skip(());
139        /// Reads a boolean at the current position.
140        boolean => bool(bool);
141        /// Reads a byte at the current position.
142        byte => i8(i8);
143        /// Reads a short at the current position.
144        short => i16(i16);
145        /// Reads a integer at the current position.
146        integer => i32(i32);
147        /// Reads a long at the current position.
148        long => i64(i64);
149        /// Reads a float at the current position.
150        float => f32(f32);
151        /// Reads a double at the current position.
152        double => f64(f64);
153        /// Reads a null CBOR element at the current position.
154        null => null(());
155        /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`.
156        list => array(Option<u64>);
157        /// Returns the number of elements in a definite map. For indefinite map it returns a `None`.
158        map => map(Option<u64>);
159    }
160
161    /// Returns the current position of the buffer, which will be decoded when any of the methods is called.
162    pub fn position(&self) -> usize {
163        self.decoder.position()
164    }
165
166    /// Set the current decode position.
167    pub fn set_position(&mut self, pos: usize) {
168        self.decoder.set_position(pos)
169    }
170
171    /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite
172    /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an
173    /// indefinite-length string. An error is returned if the element is neither a definite length nor an
174    /// indefinite-length string.
175    pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
176        let bookmark = self.decoder.position();
177        match self.decoder.str() {
178            Ok(str_value) => Ok(Cow::Borrowed(str_value)),
179            Err(e) if e.is_type_mismatch() => {
180                // Move the position back to the start of the CBOR element and then try
181                // decoding it as an indefinite length string.
182                self.decoder.set_position(bookmark);
183                Ok(Cow::Owned(self.string()?))
184            }
185            Err(e) => Err(DeserializeError::new(e)),
186        }
187    }
188
189    /// Allocates and returns a `String` if the element at the current position in the buffer is either a
190    /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type.
191    pub fn string(&mut self) -> Result<String, DeserializeError> {
192        let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
193        let head = iter.next();
194
195        let decoded_string = match head {
196            None => String::new(),
197            Some(head) => {
198                let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
199                for chunk in iter {
200                    combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
201                }
202                combined_chunks
203            }
204        };
205
206        Ok(decoded_string)
207    }
208
209    /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise,
210    /// a `DeserializeError` error is returned.
211    pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
212        let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
213        let parts: Vec<&[u8]> = iter
214            .collect::<Result<_, _>>()
215            .map_err(DeserializeError::new)?;
216
217        Ok(if parts.len() == 1 {
218            Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part.
219        } else {
220            Blob::new(parts.concat()) // Concatenate all parts into a single Blob.
221        })
222    }
223
224    /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise,
225    /// a `DeserializeError` error is returned.
226    pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
227        let tag = self.decoder.tag().map_err(DeserializeError::new)?;
228        let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
229
230        if tag != timestamp_tag {
231            Err(DeserializeError::new(Error::message(
232                "expected timestamp tag",
233            )))
234        } else {
235            // Values that are more granular than millisecond precision SHOULD be truncated to fit
236            // millisecond precision for epoch-seconds:
237            // https://smithy.io/2.0/spec/protocol-traits.html#timestamp-formats
238            //
239            // Without truncation, the `RpcV2CborDateTimeWithFractionalSeconds` protocol test would
240            // fail since the upstream test expect `123000000` in subsec but the decoded actual
241            // subsec would be `123000025`.
242            // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/fractional-seconds.smithy#L17
243            let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
244            let mut result = DateTime::from_secs_f64(epoch_seconds);
245            let subsec_nanos = result.subsec_nanos();
246            result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
247            Ok(result)
248        }
249    }
250}
251
252#[allow(dead_code)] // to avoid `never constructed` warning
253#[derive(Debug)]
254pub struct ArrayIter<'a, 'b, T> {
255    inner: minicbor::decode::ArrayIter<'a, 'b, T>,
256}
257
258impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
259    type Item = Result<T, DeserializeError>;
260
261    fn next(&mut self) -> Option<Self::Item> {
262        self.inner
263            .next()
264            .map(|opt| opt.map_err(DeserializeError::new))
265    }
266}
267
268#[allow(dead_code)] // to avoid `never constructed` warning
269#[derive(Debug)]
270pub struct MapIter<'a, 'b, K, V> {
271    inner: minicbor::decode::MapIter<'a, 'b, K, V>,
272}
273
274impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
275where
276    K: minicbor::Decode<'b, ()>,
277    V: minicbor::Decode<'b, ()>,
278{
279    type Item = Result<(K, V), DeserializeError>;
280
281    fn next(&mut self) -> Option<Self::Item> {
282        self.inner
283            .next()
284            .map(|opt| opt.map_err(DeserializeError::new))
285    }
286}
287
288pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
289where
290    F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
291{
292    match decoder.datatype()? {
293        crate::data::Type::Null => {
294            decoder.null()?;
295            Ok(builder)
296        }
297        _ => f(builder, decoder),
298    }
299}
300
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use aws_smithy_types::date_time::Format;
    use aws_smithy_types::{Blob, DateTime};

    use crate::Decoder;

    /// A definite-length string is handed back as a borrow of the input.
    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // `0x6a` announces a 10-byte text string; the rest spells `thisIsAKey`.
        let input = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let mut cbor = Decoder::new(&input);
        let key = cbor.str().expect("could not decode str");
        assert!(matches!(key, Cow::Borrowed(_)));
        assert_eq!(key, "thisIsAKey");
    }

    /// An indefinite-length string has to be assembled, so it comes back owned.
    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // `0x7f` opens an indefinite-length text string made of the chunks `this`, `Is`, `A`
        // and `Key`; `0xff` closes it.
        let input = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let mut cbor = Decoder::new(&input);
        let key = cbor.str().expect("could not decode str");
        assert!(matches!(key, Cow::Owned(_)));
        assert_eq!(key, "thisIsAKey");
    }

    /// `0x60` is the zero-length text string.
    #[test]
    fn test_empty_str_works() {
        let input = [0x60];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.str().expect("could not decode empty str");
        assert_eq!(decoded, "");
    }

    /// `0x40` is the zero-length byte string.
    #[test]
    fn test_empty_blob_works() {
        let input = [0x40];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.blob().expect("could not decode an empty blob");
        assert_eq!(decoded, Blob::new([]));
    }

    /// The chunks of an indefinite-length byte string are concatenated into one blob.
    #[test]
    fn test_indefinite_length_blob() {
        // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`.
        // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff
        let input = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let expected = Blob::new("indefinite-byte, chunked, on each comma".as_bytes());
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.blob().expect("could not decode blob");
        assert_eq!(decoded, expected);
    }

    /// Sub-millisecond digits of a floating-point timestamp are dropped.
    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // Input bytes are derived from the `RpcV2CborDateTimeWithFractionalSeconds` protocol test,
        // extracting portion representing a timestamp value.
        let input = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let expected = DateTime::from_str("2000-01-02T20:34:56.123Z", Format::DateTime).unwrap();
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.timestamp().expect("should decode timestamp");
        assert_eq!(decoded, expected);
    }
}