1use std::borrow::Cow;
7
8use aws_smithy_types::{Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22 decoder: minicbor::Decoder<'b>,
23}
24
25#[derive(Debug)]
28pub struct DeserializeError {
29 #[allow(dead_code)]
30 _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 self._inner.fmt(f)
36 }
37}
38
// Lets `DeserializeError` be used wherever a `std::error::Error` is expected.
// No method is overridden, so the trait's default implementations apply.
impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42 pub(crate) fn new(inner: Error) -> Self {
43 Self { _inner: inner }
44 }
45
46 pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48 Self {
49 _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50 .with_message("encountered unexpected union variant; expected end of union")
51 .at(at),
52 }
53 }
54
55 pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57 Self {
58 _inner: Error::message(format!("encountered unknown union variant {variant_name}"))
59 .at(at),
60 }
61 }
62
63 pub fn mixed_union_variants(at: usize) -> Self {
67 Self {
68 _inner: Error::message(
69 "encountered mixed variants in union; expected a single union variant to be set",
70 )
71 .at(at),
72 }
73 }
74
75 pub fn expected_end_of_stream(at: usize) -> Self {
77 Self {
78 _inner: Error::message("encountered additional data; expected end of stream").at(at),
79 }
80 }
81
82 pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
84 Self {
85 _inner: Error::message(message.into()).at(at),
86 }
87 }
88
89 pub fn is_type_mismatch(&self) -> bool {
93 self._inner.is_type_mismatch()
94 }
95}
96
// Generates public methods that forward to a method of `self.decoder` and convert
// its error with `DeserializeError::new`.
//
// Each entry has the form `wrapper_name => inner_name(ReturnType);` and may be
// preceded by any number of attributes, including `///` doc comments. Those
// attributes are attached to the generated method, so documentation and `#[cfg]`
// written at the invocation site take effect.
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $encoder_name:ident($result_type:ty);)+) => {
        $(
            // Forward the captured attributes; without this they would be
            // accepted by the matcher and then silently discarded.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$encoder_name().map_err(DeserializeError::new)
            }
        )+
    };
}
121
122impl<'b> Decoder<'b> {
123 pub fn new(bytes: &'b [u8]) -> Self {
124 Self {
125 decoder: minicbor::Decoder::new(bytes),
126 }
127 }
128
129 pub fn datatype(&self) -> Result<Type, DeserializeError> {
130 self.decoder
131 .datatype()
132 .map(Type::new)
133 .map_err(DeserializeError::new)
134 }
135
136 delegate_method! {
137 skip => skip(());
139 boolean => bool(bool);
141 byte => i8(i8);
143 short => i16(i16);
145 integer => i32(i32);
147 long => i64(i64);
149 float => f32(f32);
151 double => f64(f64);
153 null => null(());
155 list => array(Option<u64>);
157 map => map(Option<u64>);
159 }
160
161 pub fn position(&self) -> usize {
163 self.decoder.position()
164 }
165
166 pub fn set_position(&mut self, pos: usize) {
168 self.decoder.set_position(pos)
169 }
170
171 pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
176 let bookmark = self.decoder.position();
177 match self.decoder.str() {
178 Ok(str_value) => Ok(Cow::Borrowed(str_value)),
179 Err(e) if e.is_type_mismatch() => {
180 self.decoder.set_position(bookmark);
183 Ok(Cow::Owned(self.string()?))
184 }
185 Err(e) => Err(DeserializeError::new(e)),
186 }
187 }
188
189 pub fn string(&mut self) -> Result<String, DeserializeError> {
192 let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
193 let head = iter.next();
194
195 let decoded_string = match head {
196 None => String::new(),
197 Some(head) => {
198 let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
199 for chunk in iter {
200 combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
201 }
202 combined_chunks
203 }
204 };
205
206 Ok(decoded_string)
207 }
208
209 pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
212 let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
213 let parts: Vec<&[u8]> = iter
214 .collect::<Result<_, _>>()
215 .map_err(DeserializeError::new)?;
216
217 Ok(if parts.len() == 1 {
218 Blob::new(parts[0]) } else {
220 Blob::new(parts.concat()) })
222 }
223
224 pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
227 let tag = self.decoder.tag().map_err(DeserializeError::new)?;
228 let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
229
230 if tag != timestamp_tag {
231 Err(DeserializeError::new(Error::message(
232 "expected timestamp tag",
233 )))
234 } else {
235 let epoch_seconds = self.decoder.f64().map_err(DeserializeError::new)?;
244 let mut result = DateTime::from_secs_f64(epoch_seconds);
245 let subsec_nanos = result.subsec_nanos();
246 result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
247 Ok(result)
248 }
249 }
250}
251
252#[allow(dead_code)] #[derive(Debug)]
254pub struct ArrayIter<'a, 'b, T> {
255 inner: minicbor::decode::ArrayIter<'a, 'b, T>,
256}
257
258impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
259 type Item = Result<T, DeserializeError>;
260
261 fn next(&mut self) -> Option<Self::Item> {
262 self.inner
263 .next()
264 .map(|opt| opt.map_err(DeserializeError::new))
265 }
266}
267
268#[allow(dead_code)] #[derive(Debug)]
270pub struct MapIter<'a, 'b, K, V> {
271 inner: minicbor::decode::MapIter<'a, 'b, K, V>,
272}
273
274impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
275where
276 K: minicbor::Decode<'b, ()>,
277 V: minicbor::Decode<'b, ()>,
278{
279 type Item = Result<(K, V), DeserializeError>;
280
281 fn next(&mut self) -> Option<Self::Item> {
282 self.inner
283 .next()
284 .map(|opt| opt.map_err(DeserializeError::new))
285 }
286}
287
288pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
289where
290 F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
291{
292 match decoder.datatype()? {
293 crate::data::Type::Null => {
294 decoder.null()?;
295 Ok(builder)
296 }
297 _ => f(builder, decoder),
298 }
299}
300
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use aws_smithy_types::date_time::Format;
    use aws_smithy_types::{Blob, DateTime};

    use super::Decoder;

    /// A definite-length text string is returned without copying.
    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // 0x6a: text string of length 10, followed by "thisIsAKey".
        let input = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.str().expect("could not decode str");
        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Borrowed(_)));
    }

    /// An indefinite-length text string is assembled into an owned value.
    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // 0x7f opens the string; chunks "this", "Is", "A", "Key"; 0xff closes it.
        let input = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.str().expect("could not decode str");
        assert_eq!(decoded, "thisIsAKey");
        assert!(matches!(decoded, Cow::Owned(_)));
    }

    /// A zero-length text string decodes to the empty string.
    #[test]
    fn test_empty_str_works() {
        let input = [0x60];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.str().expect("could not decode empty str");
        assert_eq!(decoded, "");
    }

    /// A zero-length byte string decodes to an empty blob.
    #[test]
    fn test_empty_blob_works() {
        let input = [0x40];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.blob().expect("could not decode an empty blob");
        assert_eq!(decoded, Blob::new([]));
    }

    /// The chunks of an indefinite-length byte string are concatenated.
    #[test]
    fn test_indefinite_length_blob() {
        // 0x5f opens the byte string; three chunks follow, split after each
        // comma; 0xff closes it.
        let input = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.blob().expect("could not decode blob");
        let expected = Blob::new("indefinite-byte, chunked, on each comma".as_bytes());
        assert_eq!(decoded, expected);
    }

    /// Precision finer than one millisecond is dropped from decoded timestamps.
    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // 0xc1: tag 1; 0xfb: 64-bit float; then the eight payload bytes.
        let input = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let mut cbor = Decoder::new(&input);
        let decoded = cbor.timestamp().expect("should decode timestamp");
        let expected = DateTime::from_str("2000-01-02T20:34:56.123Z", Format::DateTime).unwrap();
        assert_eq!(decoded, expected);
    }
}