aws_smithy_cbor/
encode.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_types::{Blob, DateTime};
7
/// Macro for delegating method calls to the encoder.
///
/// For every `wrapper => encoder_method(params...);` entry this generates a
/// `pub fn wrapper(&mut self, params...) -> &mut Self`. The generated method:
/// - forwards its arguments, in order and comma-separated, to `self.encoder.encoder_method(...)`;
/// - panics with the `INFALLIBLE_WRITE` message if that call returns an error;
/// - returns `self` so calls can be chained.
///
/// Attributes written before an entry, including `///` doc comments, are copied onto the
/// generated method.
///
/// The expansion site must provide a field named `encoder` on `Self` and a `&str` constant
/// named `INFALLIBLE_WRITE` in scope.
///
/// # Example
///
/// ```ignore
/// delegate_method! {
///     /// Wrapper method for encoding method `encode_str` on the encoder.
///     encode_str_wrapper => encode_str(data: &str);
///     /// Wrapper method for encoding method `encode_int` on the encoder.
///     encode_int_wrapper => encode_int(value: i32);
/// }
/// ```
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($($param_name:ident : $param_type:ty),*);)+) => {
        $(
            // Re-emit the captured attributes so doc comments reach the generated method.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self, $($param_name: $param_type),*) -> &mut Self {
                // Arguments are comma-separated so multi-parameter entries expand to a valid call.
                self.encoder.$encoder_name($($param_name),*).expect(INFALLIBLE_WRITE);
                self
            }
        )+
    };
}
33
34#[derive(Debug, Clone)]
35pub struct Encoder {
36    encoder: minicbor::Encoder<Vec<u8>>,
37}
38
39/// We always write to a `Vec<u8>`, which is infallible in `minicbor`.
40/// <https://docs.rs/minicbor/latest/minicbor/encode/write/trait.Write.html#impl-Write-for-Vec%3Cu8%3E>
41const INFALLIBLE_WRITE: &str = "write failed";
42
43impl Encoder {
44    pub fn new(writer: Vec<u8>) -> Self {
45        Self {
46            encoder: minicbor::Encoder::new(writer),
47        }
48    }
49
50    delegate_method! {
51        /// Used when it's not cheap to calculate the size, i.e. when the struct has one or more
52        /// `Option`al members.
53        begin_map => begin_map();
54        /// Writes a boolean value.
55        boolean => bool(x: bool);
56        /// Writes a byte value.
57        byte => i8(x: i8);
58        /// Writes a short value.
59        short => i16(x: i16);
60        /// Writes an integer value.
61        integer => i32(x: i32);
62        /// Writes an long value.
63        long => i64(x: i64);
64        /// Writes an float value.
65        float => f32(x: f32);
66        /// Writes an double value.
67        double => f64(x: f64);
68        /// Writes a null tag.
69        null => null();
70        /// Writes an end tag.
71        end => end();
72    }
73
74    /// Maximum size of a CBOR type+length header: 1 byte major type + up to 8 bytes for the length.
75    const MAX_HEADER_LEN: usize = 9;
76
77    /// Writes a CBOR type+length header directly to the writer.
78    ///
79    /// Encodes the "additional information" field per RFC 8949 §3:
80    /// - 0..=23: length is stored directly in the low 5 bits of the initial byte.
81    /// - 24: one-byte uint follows (value 24..=0xff).
82    /// - 25: two-byte big-endian uint follows (value 0x100..=0xffff).
83    /// - 26: four-byte big-endian uint follows (value 0x1_0000..=0xffff_ffff).
84    /// - 27: eight-byte big-endian uint follows (larger values).
85    #[inline]
86    fn write_type_len(writer: &mut Vec<u8>, major: u8, len: usize) {
87        let mut buf = [0u8; Self::MAX_HEADER_LEN];
88        let n = match len {
89            0..=23 => {
90                buf[0] = major | len as u8;
91                1
92            }
93            24..=0xff => {
94                buf[0] = major | 24;
95                buf[1] = len as u8;
96                2
97            }
98            0x100..=0xffff => {
99                buf[0] = major | 25;
100                buf[1..3].copy_from_slice(&(len as u16).to_be_bytes());
101                3
102            }
103            0x1_0000..=0xffff_ffff => {
104                buf[0] = major | 26;
105                buf[1..5].copy_from_slice(&(len as u32).to_be_bytes());
106                5
107            }
108            _ => {
109                buf[0] = major | 27;
110                buf[1..9].copy_from_slice(&(len as u64).to_be_bytes());
111                9
112            }
113        };
114        writer.extend_from_slice(&buf[..n]);
115    }
116
117    /// Writes a definite length string. Collapses header+data into a single reserve+write.
118    pub fn str(&mut self, x: &str) -> &mut Self {
119        let writer = self.encoder.writer_mut();
120        let len = x.len();
121        writer.reserve(Self::MAX_HEADER_LEN + len);
122        Self::write_type_len(writer, 0x60, len);
123        writer.extend_from_slice(x.as_bytes());
124        self
125    }
126
127    /// Writes a blob. Collapses header+data into a single reserve+write.
128    pub fn blob(&mut self, x: &Blob) -> &mut Self {
129        let data = x.as_ref();
130        let writer = self.encoder.writer_mut();
131        let len = data.len();
132        writer.reserve(Self::MAX_HEADER_LEN + len);
133        Self::write_type_len(writer, 0x40, len);
134        writer.extend_from_slice(data);
135        self
136    }
137
138    /// Writes a fixed length array of given length.
139    pub fn array(&mut self, len: usize) -> &mut Self {
140        Self::write_type_len(self.encoder.writer_mut(), 0x80, len);
141        self
142    }
143
144    /// Writes a fixed length map of given length.
145    /// Used when we know the size in advance, i.e.:
146    /// - when a struct has all non-`Option`al members.
147    /// - when serializing `union` shapes (they can only have one member set).
148    /// - when serializing a `map` shape.
149    pub fn map(&mut self, len: usize) -> &mut Self {
150        Self::write_type_len(self.encoder.writer_mut(), 0xa0, len);
151        self
152    }
153
154    pub fn timestamp(&mut self, x: &DateTime) -> &mut Self {
155        self.encoder
156            .tag(minicbor::data::Tag::from(
157                minicbor::data::IanaTag::Timestamp,
158            ))
159            .expect(INFALLIBLE_WRITE);
160        self.encoder.f64(x.as_secs_f64()).expect(INFALLIBLE_WRITE);
161        self
162    }
163
164    pub fn into_writer(self) -> Vec<u8> {
165        self.encoder.into_writer()
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::Encoder;
    use aws_smithy_types::Blob;

    /// Verify our `str()` produces byte-identical output to minicbor's at every header-size
    /// boundary.
    #[test]
    fn str_matches_minicbor() {
        let cases: Vec<String> = vec![
            String::new(),                     // len 0
            "a".into(),                        // len 1 (in 0..=23 range)
            "hello world!! test str".into(),   // len 22 (still 0..=23)
            "hello world!! test str!".into(),  // len 23 (max inline length)
            "this is exactly 24 chars".into(), // len 24 (0x18, first 1-byte length)
            "x".repeat(0xff),                  // len 255 (max 1-byte length)
            "y".repeat(0x100),                 // len 256 (first 2-byte length)
            "w".repeat(0xffff),                // len 65535 (max 2-byte length)
            "z".repeat(0x1_0000),              // len 65536 (first 4-byte length)
        ];
        // Guard against the length comments above drifting from the actual inputs.
        let expected_lens: [usize; 9] = [0, 1, 22, 23, 24, 0xff, 0x100, 0xffff, 0x1_0000];
        assert_eq!(
            cases.iter().map(String::len).collect::<Vec<_>>(),
            expected_lens
        );

        for input in &cases {
            let mut ours = Encoder::new(Vec::new());
            ours.str(input);

            let mut theirs = minicbor::Encoder::new(Vec::new());
            theirs.str(input).unwrap();

            assert_eq!(
                ours.into_writer(),
                theirs.into_writer(),
                "str mismatch for input len={}",
                input.len()
            );
        }
    }

    /// Verify our `blob()` produces byte-identical output to minicbor's.
    #[test]
    fn blob_matches_minicbor() {
        let cases: Vec<Vec<u8>> = vec![
            vec![],               // empty
            vec![0x42],           // 1 byte
            vec![0xAB; 23],       // max inline length
            vec![0xCD; 24],       // first 1-byte length
            vec![0xEF; 0xff],     // max 1-byte length
            vec![0x01; 0x100],    // first 2-byte length
            vec![0x03; 0xffff],   // max 2-byte length
            vec![0x02; 0x1_0000], // first 4-byte length
        ];
        for input in &cases {
            let mut ours = Encoder::new(Vec::new());
            ours.blob(&Blob::new(input.clone()));

            let mut theirs = minicbor::Encoder::new(Vec::new());
            theirs.bytes(input).unwrap();

            assert_eq!(
                ours.into_writer(),
                theirs.into_writer(),
                "blob mismatch for input len={}",
                input.len()
            );
        }
    }

    /// Verify `array()` and `map()` headers match minicbor's for every header width.
    ///
    /// This includes the 8-byte length form, which headers reach cheaply because no payload is
    /// written.
    #[test]
    fn array_and_map_headers_match_minicbor() {
        let lens: [u64; 11] = [
            0,
            1,
            23,
            24,
            0xff,
            0x100,
            0xffff,
            0x1_0000,
            0xffff_ffff,
            0x1_0000_0000,
            u64::MAX,
        ];
        for len in lens {
            // Lengths that don't fit in this target's `usize` can't be passed to our API.
            let Ok(len_usize) = usize::try_from(len) else {
                continue;
            };

            let mut ours = Encoder::new(Vec::new());
            ours.array(len_usize).map(len_usize);

            let mut theirs = minicbor::Encoder::new(Vec::new());
            theirs.array(len).unwrap().map(len).unwrap();

            assert_eq!(
                ours.into_writer(),
                theirs.into_writer(),
                "array/map header mismatch for len={len}"
            );
        }
    }

    /// Verify chained `str()` calls don't corrupt encoder state for subsequent writes.
    #[test]
    fn str_chained_matches_minicbor() {
        let mut ours = Encoder::new(Vec::new());
        ours.str("key1").str("value1").str("key2").str("value2");

        let mut theirs = minicbor::Encoder::new(Vec::new());
        theirs
            .str("key1")
            .unwrap()
            .str("value1")
            .unwrap()
            .str("key2")
            .unwrap()
            .str("value2")
            .unwrap();

        assert_eq!(ours.into_writer(), theirs.into_writer());
    }

    /// Verify `str()` works correctly inside a map structure (the real-world hot path).
    #[test]
    fn str_inside_map_matches_minicbor() {
        let mut ours = Encoder::new(Vec::new());
        ours.begin_map().str("TableName").str("my-table").end();

        let mut theirs = minicbor::Encoder::new(Vec::new());
        theirs
            .begin_map()
            .unwrap()
            .str("TableName")
            .unwrap()
            .str("my-table")
            .unwrap()
            .end()
            .unwrap();

        assert_eq!(ours.into_writer(), theirs.into_writer());
    }

    /// Verify `str()` handles multi-byte UTF-8 correctly (CBOR text strings must be valid UTF-8).
    #[test]
    fn str_utf8_matches_minicbor() {
        let cases = [
            "café",          // 2-byte UTF-8
            "日本語",        // 3-byte UTF-8
            "🦀🔥",          // 4-byte UTF-8 (emoji)
            "mixed: aé日🦀", // all byte widths
        ];
        for input in &cases {
            let mut ours = Encoder::new(Vec::new());
            ours.str(input);

            let mut theirs = minicbor::Encoder::new(Vec::new());
            theirs.str(input).unwrap();

            assert_eq!(
                ours.into_writer(),
                theirs.into_writer(),
                "str UTF-8 mismatch for {:?}",
                input
            );
        }
    }
}