1use std::borrow::Cow;
7
8use aws_smithy_types::{Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22 decoder: minicbor::Decoder<'b>,
23}
24
25#[derive(Debug)]
28pub struct DeserializeError {
29 #[allow(dead_code)]
30 _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 self._inner.fmt(f)
36 }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42 pub(crate) fn new(inner: Error) -> Self {
43 Self { _inner: inner }
44 }
45
46 pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48 Self {
49 _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50 .with_message("encountered unexpected union variant; expected end of union")
51 .at(at),
52 }
53 }
54
55 pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57 Self {
58 _inner: Error::message(format!(
59 "encountered unknown union variant {}",
60 variant_name
61 ))
62 .at(at),
63 }
64 }
65
66 pub fn mixed_union_variants(at: usize) -> Self {
70 Self {
71 _inner: Error::message(
72 "encountered mixed variants in union; expected a single union variant to be set",
73 )
74 .at(at),
75 }
76 }
77
78 pub fn expected_end_of_stream(at: usize) -> Self {
80 Self {
81 _inner: Error::message("encountered additional data; expected end of stream").at(at),
82 }
83 }
84
85 pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
87 Self {
88 _inner: Error::message(message.into()).at(at),
89 }
90 }
91
92 pub fn is_type_mismatch(&self) -> bool {
96 self._inner.is_type_mismatch()
97 }
98}
99
/// Generates public `&mut self` methods that forward to a same-shaped method on `self.decoder`.
///
/// Each entry has the form `wrapper_name => inner_name(ReturnType);` and expands to
///
/// ```ignore
/// pub fn wrapper_name(&mut self) -> Result<ReturnType, DeserializeError> {
///     self.decoder.inner_name().map_err(DeserializeError::new)
/// }
/// ```
///
/// Attributes written in front of an entry (doc comments, `#[cfg(...)]`, `#[inline]`, ...) are
/// forwarded onto the generated method, so the methods can be documented at the invocation site.
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => {
        $(
            // Re-emit the captured attributes; without this they would be parsed and dropped.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$encoder_name().map_err(DeserializeError::new)
            }
        )+
    };
}
124
125impl<'b> Decoder<'b> {
126 pub fn new(bytes: &'b [u8]) -> Self {
127 Self {
128 decoder: minicbor::Decoder::new(bytes),
129 }
130 }
131
132 pub fn datatype(&self) -> Result<Type, DeserializeError> {
133 self.decoder
134 .datatype()
135 .map(Type::new)
136 .map_err(DeserializeError::new)
137 }
138
139 delegate_method! {
140 skip => skip(());
142 boolean => bool(bool);
144 byte => i8(i8);
146 short => i16(i16);
148 integer => i32(i32);
150 long => i64(i64);
152 float => f32(f32);
154 double => f64(f64);
156 null => null(());
158 list => array(Option<u64>);
160 map => map(Option<u64>);
162 }
163
164 pub fn position(&self) -> usize {
166 self.decoder.position()
167 }
168
169 pub fn set_position(&mut self, pos: usize) {
171 self.decoder.set_position(pos)
172 }
173
174 pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
179 let bookmark = self.decoder.position();
180 match self.decoder.str() {
181 Ok(str_value) => Ok(Cow::Borrowed(str_value)),
182 Err(e) if e.is_type_mismatch() => {
183 self.decoder.set_position(bookmark);
186 Ok(Cow::Owned(self.string()?))
187 }
188 Err(e) => Err(DeserializeError::new(e)),
189 }
190 }
191
192 pub fn string(&mut self) -> Result<String, DeserializeError> {
195 let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
196 let head = iter.next();
197
198 let decoded_string = match head {
199 None => String::new(),
200 Some(head) => {
201 let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
202 for chunk in iter {
203 combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
204 }
205 combined_chunks
206 }
207 };
208
209 Ok(decoded_string)
210 }
211
212 pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
215 let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
216 let parts: Vec<&[u8]> = iter
217 .collect::<Result<_, _>>()
218 .map_err(DeserializeError::new)?;
219
220 Ok(if parts.len() == 1 {
221 Blob::new(parts[0]) } else {
223 Blob::new(parts.concat()) })
225 }
226
227 pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
230 let tag = self.decoder.tag().map_err(DeserializeError::new)?;
231 let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
232
233 if tag != timestamp_tag {
234 Err(DeserializeError::new(Error::message(
235 "expected timestamp tag",
236 )))
237 } else {
238 let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
247 let mut result = DateTime::from_secs_f64(epoch_seconds);
248 let subsec_nanos = result.subsec_nanos();
249 result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
250 Ok(result)
251 }
252 }
253}
254
255#[allow(dead_code)] #[derive(Debug)]
257pub struct ArrayIter<'a, 'b, T> {
258 inner: minicbor::decode::ArrayIter<'a, 'b, T>,
259}
260
261impl<'a, 'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'a, 'b, T> {
262 type Item = Result<T, DeserializeError>;
263
264 fn next(&mut self) -> Option<Self::Item> {
265 self.inner
266 .next()
267 .map(|opt| opt.map_err(DeserializeError::new))
268 }
269}
270
271#[allow(dead_code)] #[derive(Debug)]
273pub struct MapIter<'a, 'b, K, V> {
274 inner: minicbor::decode::MapIter<'a, 'b, K, V>,
275}
276
277impl<'a, 'b, K, V> Iterator for MapIter<'a, 'b, K, V>
278where
279 K: minicbor::Decode<'b, ()>,
280 V: minicbor::Decode<'b, ()>,
281{
282 type Item = Result<(K, V), DeserializeError>;
283
284 fn next(&mut self) -> Option<Self::Item> {
285 self.inner
286 .next()
287 .map(|opt| opt.map_err(DeserializeError::new))
288 }
289}
290
291pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
292where
293 F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
294{
295 match decoder.datatype()? {
296 crate::data::Type::Null => {
297 decoder.null()?;
298 Ok(builder)
299 }
300 _ => f(builder, decoder),
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use aws_smithy_types::date_time::Format;
    use aws_smithy_types::{Blob, DateTime};

    use crate::Decoder;

    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // 0x6a: text string with a definite length of 10, followed by "thisIsAKey".
        let input = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.str().expect("could not decode str");

        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Borrowed(_)));
    }

    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // 0x7f opens an indefinite-length text string made of the chunks
        // "this", "Is", "A" and "Key"; 0xff closes it.
        let input = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.str().expect("could not decode str");

        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Owned(_)));
    }

    #[test]
    fn test_empty_str_works() {
        // 0x60: text string of length 0.
        let input = [0x60];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.str().expect("could not decode empty str");

        assert_eq!(decoded, "");
    }

    #[test]
    fn test_empty_blob_works() {
        // 0x40: byte string of length 0.
        let input = [0x40];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.blob().expect("could not decode an empty blob");

        assert_eq!(decoded, Blob::new(Vec::<u8>::new()));
    }

    #[test]
    fn test_indefinite_length_blob() {
        // 0x5f opens an indefinite-length byte string made of the chunks
        // "indefinite-byte,", " chunked," and " on each comma"; 0xff closes it.
        let input = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.blob().expect("could not decode blob");

        let expected = Blob::new("indefinite-byte, chunked, on each comma".as_bytes());
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // 0xc1: tag 1 (epoch-based timestamp); 0xfb: 64-bit float, followed by its 8 bytes.
        // The trailing 0xff is not part of the timestamp.
        let input = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let mut decoder = Decoder::new(&input);

        let decoded = decoder.timestamp().expect("should decode timestamp");

        let expected = DateTime::from_str("2000-01-02T20:34:56.123Z", Format::DateTime).unwrap();
        assert_eq!(decoded, expected);
    }
}