aws_smithy_json/
escape.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::borrow::Cow;
7use std::fmt;
8
/// The specific reason a JSON string could not be unescaped.
#[derive(Debug, Eq, PartialEq)]
enum EscapeErrorKind {
    /// A high surrogate was not followed by a `\u` escape; holds the six bytes
    /// (lossily decoded) that were found where the low word was expected.
    ExpectedSurrogatePair(String),
    /// A backslash was followed by a character that is not a recognized JSON escape.
    InvalidEscapeCharacter(char),
    /// A high surrogate was followed by a `\u` escape that is not a low surrogate;
    /// holds the high and low words in that order.
    InvalidSurrogatePair(u16, u16),
    /// A `\u` escape that does not describe a usable codepoint (for example, one
    /// containing non-hexadecimal characters).
    InvalidUnicodeEscape(String),
    /// The bytes being processed were not valid UTF-8.
    InvalidUtf8,
    /// The input ended in the middle of an escape sequence.
    UnexpectedEndOfString,
}
18
19#[derive(Debug)]
20#[cfg_attr(test, derive(PartialEq, Eq))]
21pub struct EscapeError {
22    kind: EscapeErrorKind,
23}
24
25impl std::error::Error for EscapeError {}
26
27impl fmt::Display for EscapeError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        use EscapeErrorKind::*;
30        match &self.kind {
31            ExpectedSurrogatePair(low) => {
32                write!(
33                    f,
34                    "expected a UTF-16 surrogate pair, but got {low} as the low word"
35                )
36            }
37            InvalidEscapeCharacter(chr) => write!(f, "invalid JSON escape: \\{chr}"),
38            InvalidSurrogatePair(high, low) => {
39                write!(f, "invalid surrogate pair: \\u{high:04X}\\u{low:04X}")
40            }
41            InvalidUnicodeEscape(escape) => write!(f, "invalid JSON Unicode escape: \\u{escape}"),
42            InvalidUtf8 => write!(f, "invalid UTF-8 codepoint in JSON string"),
43            UnexpectedEndOfString => write!(f, "unexpected end of string"),
44        }
45    }
46}
47
48impl From<EscapeErrorKind> for EscapeError {
49    fn from(kind: EscapeErrorKind) -> Self {
50        Self { kind }
51    }
52}
53
54/// Escapes a string for embedding in a JSON string value.
55pub(crate) fn escape_string(value: &str) -> Cow<'_, str> {
56    let bytes = value.as_bytes();
57    for (index, byte) in bytes.iter().enumerate() {
58        match byte {
59            0..=0x1F | b'"' | b'\\' => {
60                return Cow::Owned(escape_string_inner(&bytes[0..index], &bytes[index..]))
61            }
62            _ => {}
63        }
64    }
65    Cow::Borrowed(value)
66}
67
/// Builds the escaped string once `escape_string` has found a byte that needs escaping.
///
/// `start` is copied verbatim; every byte of `rest` is either copied or replaced with
/// its JSON escape sequence. Control characters without a short form are written as
/// lowercase `\u00xx`.
fn escape_string_inner(start: &[u8], rest: &[u8]) -> String {
    let mut out = Vec::with_capacity(start.len() + rest.len() + 1);
    out.extend_from_slice(start);

    for &byte in rest {
        // Pick the short two-character escape when one exists; the remaining cases
        // write directly to `out` and move on to the next byte.
        let short_escape: &[u8] = match byte {
            b'"' => b"\\\"",
            b'\\' => b"\\\\",
            0x08 => b"\\b",
            0x0C => b"\\f",
            b'\n' => b"\\n",
            b'\r' => b"\\r",
            b'\t' => b"\\t",
            0..=0x1F => {
                out.extend_from_slice(format!("\\u{byte:04x}").as_bytes());
                continue;
            }
            _ => {
                out.push(byte);
                continue;
            }
        };
        out.extend_from_slice(short_escape);
    }

    // SAFETY:
    // - The original input was valid UTF-8 since it came in as a `&str`
    // - Only single-byte code points were escaped
    // - The escape sequences are valid UTF-8
    debug_assert!(std::str::from_utf8(&out).is_ok());
    unsafe { String::from_utf8_unchecked(out) }
}
93
94/// Unescapes a JSON-escaped string.
95/// If there are no escape sequences, it directly returns the reference.
96pub(crate) fn unescape_string(value: &str) -> Result<Cow<'_, str>, EscapeError> {
97    let bytes = value.as_bytes();
98    for (index, byte) in bytes.iter().enumerate() {
99        if *byte == b'\\' {
100            return unescape_string_inner(&bytes[0..index], &bytes[index..]).map(Cow::Owned);
101        }
102    }
103    Ok(Cow::Borrowed(value))
104}
105
106fn unescape_string_inner(start: &[u8], rest: &[u8]) -> Result<String, EscapeError> {
107    let mut unescaped = Vec::with_capacity(start.len() + rest.len());
108    unescaped.extend(start);
109
110    let mut index = 0;
111    while index < rest.len() {
112        match rest[index] {
113            b'\\' => {
114                index += 1;
115                if index == rest.len() {
116                    return Err(EscapeErrorKind::UnexpectedEndOfString.into());
117                }
118                match rest[index] {
119                    b'u' => {
120                        index -= 1;
121                        index += read_unicode_escapes(&rest[index..], &mut unescaped)?;
122                    }
123                    byte => {
124                        match byte {
125                            b'\\' => unescaped.push(b'\\'),
126                            b'/' => unescaped.push(b'/'),
127                            b'"' => unescaped.push(b'"'),
128                            b'b' => unescaped.push(0x08),
129                            b'f' => unescaped.push(0x0C),
130                            b'n' => unescaped.push(b'\n'),
131                            b'r' => unescaped.push(b'\r'),
132                            b't' => unescaped.push(b'\t'),
133                            _ => {
134                                return Err(
135                                    EscapeErrorKind::InvalidEscapeCharacter(byte.into()).into()
136                                )
137                            }
138                        }
139                        index += 1;
140                    }
141                }
142            }
143            byte => {
144                unescaped.push(byte);
145                index += 1
146            }
147        }
148    }
149
150    String::from_utf8(unescaped).map_err(|_| EscapeErrorKind::InvalidUtf8.into())
151}
152
/// Returns `true` if `codepoint` is a UTF-16 low (trailing) surrogate,
/// i.e. lies in `0xDC00..=0xDFFF`.
fn is_utf16_low_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xDC00..=0xDFFF)
}
156
/// Returns `true` if `codepoint` is a UTF-16 high (leading) surrogate,
/// i.e. lies in `0xD800..=0xDBFF`.
fn is_utf16_high_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xD800..=0xDBFF)
}
160
161fn read_codepoint(rest: &[u8]) -> Result<u16, EscapeError> {
162    if rest.len() < 6 {
163        return Err(EscapeErrorKind::UnexpectedEndOfString.into());
164    }
165    if &rest[0..2] != b"\\u" {
166        // The first codepoint is always prefixed with "\u" since unescape_string_inner does
167        // that check, so this error will always be for the low word of a surrogate pair.
168        return Err(EscapeErrorKind::ExpectedSurrogatePair(
169            String::from_utf8_lossy(&rest[0..6]).into(),
170        )
171        .into());
172    }
173
174    let codepoint_str =
175        std::str::from_utf8(&rest[2..6]).map_err(|_| EscapeErrorKind::InvalidUtf8)?;
176
177    // Error on characters `u16::from_str_radix` would otherwise accept, such as `+`
178    if codepoint_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
179        return Err(EscapeErrorKind::InvalidUnicodeEscape(codepoint_str.into()).into());
180    }
181    Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value"))
182}
183
184/// Reads JSON Unicode escape sequences (i.e., "\u1234"). Will also read
185/// an additional codepoint if the first codepoint is the start of a surrogate pair.
186fn read_unicode_escapes(bytes: &[u8], into: &mut Vec<u8>) -> Result<usize, EscapeError> {
187    let high = read_codepoint(bytes)?;
188    let (bytes_read, chr) = if is_utf16_high_surrogate(high) {
189        let low = read_codepoint(&bytes[6..])?;
190        if !is_utf16_low_surrogate(low) {
191            return Err(EscapeErrorKind::InvalidSurrogatePair(high, low).into());
192        }
193
194        let codepoint =
195            std::char::from_u32(0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32)
196                .ok_or(EscapeErrorKind::InvalidSurrogatePair(high, low))?;
197        (12, codepoint)
198    } else {
199        let codepoint = std::char::from_u32(high as u32).ok_or_else(|| {
200            EscapeErrorKind::InvalidUnicodeEscape(String::from_utf8_lossy(&bytes[0..6]).into())
201        })?;
202        (6, codepoint)
203    };
204
205    match chr.len_utf8() {
206        1 => into.push(chr as u8),
207        _ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()),
208    }
209    Ok(bytes_read)
210}
211
/// Unit and property tests for JSON string escaping and unescaping.
#[cfg(test)]
mod test {
    use super::{escape_string, unescape_string, EscapeErrorKind};
    use std::borrow::Cow;

    /// Escaping covers short escapes, quotes, backslashes, and `\u00xx` control characters.
    #[test]
    fn escape() {
        assert_eq!("", escape_string("").as_ref());
        assert_eq!("foo", escape_string("foo").as_ref());
        assert_eq!("foo\\r\\n", escape_string("foo\r\n").as_ref());
        assert_eq!("foo\\r\\nbar", escape_string("foo\r\nbar").as_ref());
        assert_eq!(r"foo\\bar", escape_string(r"foo\bar").as_ref());
        assert_eq!(r"\\foobar", escape_string(r"\foobar").as_ref());
        assert_eq!(
            r"\bf\fo\to\r\n",
            escape_string("\u{08}f\u{0C}o\to\r\n").as_ref()
        );
        assert_eq!("\\\"test\\\"", escape_string("\"test\"").as_ref());
        assert_eq!("\\u0000", escape_string("\u{0}").as_ref());
        assert_eq!("\\u001f", escape_string("\u{1f}").as_ref());
    }

    /// Input without escapes must be returned borrowed rather than copied.
    #[test]
    fn unescape_no_escapes() {
        let unescaped = unescape_string("test test").unwrap();
        assert_eq!("test test", unescaped);
        assert!(matches!(unescaped, Cow::Borrowed(_)));
    }

    /// Unescaping handles every escape form and reports each error kind.
    #[test]
    fn unescape() {
        assert_eq!(
            "\x08f\x0Co\to\r\n",
            unescape_string(r"\bf\fo\to\r\n").unwrap()
        );
        assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap());
        assert_eq!("\x00", unescape_string("\\u0000").unwrap());
        assert_eq!("\x1f", unescape_string("\\u001f").unwrap());
        assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap());
        assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap());
        assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap());
        assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap());

        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()),
            unescape_string("\\z")
        );

        assert_eq!(
            Err(EscapeErrorKind::ExpectedSurrogatePair("\\nasdf".into()).into()),
            unescape_string("\\uD801\\nasdf")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\uD801\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidSurrogatePair(0xD801, 0xC501).into()),
            unescape_string("\\uD801\\uC501")
        );

        assert_eq!(
            Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()),
            unescape_string("\\u+04D")
        );
    }

    use proptest::proptest;
    proptest! {
        // Escaped output must match serde_json's (minus its surrounding quotes).
        #[test]
        fn matches_serde_json(s in ".*") {
            let serde_escaped = serde_json::to_string(&s).unwrap();
            let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
            assert_eq!(serde_escaped, escape_string(&s))
        }

        // Escaping then unescaping any single character yields the original.
        #[test]
        fn round_trip(chr in proptest::char::any()) {
            let mut original = String::new();
            original.push(chr);

            let escaped = escape_string(&original);
            let unescaped = unescape_string(&escaped).unwrap();
            assert_eq!(original, unescaped);
        }

        // Every supplementary-plane character decodes from its surrogate-pair escape.
        #[test]
        fn unicode_surrogates(chr in proptest::char::range(
            std::char::from_u32(0x10000).unwrap(),
            std::char::from_u32(0x10FFFF).unwrap(),
        )) {
            let mut codepoints = [0; 2];
            chr.encode_utf16(&mut codepoints);

            let escaped = format!("\\u{:04X}\\u{:04X}", codepoints[0], codepoints[1]);
            let unescaped = unescape_string(&escaped).unwrap();

            let expected = format!("{chr}");
            assert_eq!(expected, unescaped);
        }
    }

    #[test]
    #[ignore] // This tests escaping of all codepoints, but can take a long time in debug builds
    fn all_codepoints() {
        // Iterating a `char` range visits exactly the valid Unicode scalar values
        // (surrogates are skipped), rather than probing every `u32` up to `u32::MAX`
        // when only values up to 0x10FFFF can ever be a `char`.
        for chr in '\0'..=char::MAX {
            let string = String::from(chr);
            let escaped = escape_string(&string);
            let serde_escaped = serde_json::to_string(&string).unwrap();
            let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
            assert_eq!(&escaped, serde_escaped);
        }
    }
}