1use std::borrow::Cow;
7use std::fmt;
8
/// The specific reason a JSON string could not be unescaped.
#[derive(Debug, PartialEq, Eq)]
enum EscapeErrorKind {
    /// A `\u` escape was expected (e.g. the low half of a surrogate pair) but
    /// other text was found; holds those six bytes, lossily decoded.
    ExpectedSurrogatePair(String),
    /// A backslash was followed by a character that is not a JSON escape.
    InvalidEscapeCharacter(char),
    /// A high surrogate was followed by a `\u` escape that is not a low surrogate.
    InvalidSurrogatePair(u16, u16),
    /// The payload of a `\u` escape is not a valid code point.
    InvalidUnicodeEscape(String),
    /// Bytes that are not valid UTF-8 were encountered while unescaping.
    InvalidUtf8,
    /// The input ended in the middle of an escape sequence.
    UnexpectedEndOfString,
}

/// Error returned when a JSON string contains a malformed escape sequence.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct EscapeError {
    kind: EscapeErrorKind,
}

impl std::error::Error for EscapeError {}

impl fmt::Display for EscapeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.kind {
            EscapeErrorKind::ExpectedSurrogatePair(low) => write!(
                f,
                "expected a UTF-16 surrogate pair, but got {low} as the low word"
            ),
            EscapeErrorKind::InvalidEscapeCharacter(chr) => {
                write!(f, "invalid JSON escape: \\{chr}")
            }
            EscapeErrorKind::InvalidSurrogatePair(high, low) => {
                write!(f, "invalid surrogate pair: \\u{high:04X}\\u{low:04X}")
            }
            EscapeErrorKind::InvalidUnicodeEscape(escape) => {
                write!(f, "invalid JSON Unicode escape: \\u{escape}")
            }
            // Fixed messages need no formatting machinery.
            EscapeErrorKind::InvalidUtf8 => f.write_str("invalid UTF-8 codepoint in JSON string"),
            EscapeErrorKind::UnexpectedEndOfString => f.write_str("unexpected end of string"),
        }
    }
}

impl From<EscapeErrorKind> for EscapeError {
    fn from(kind: EscapeErrorKind) -> Self {
        EscapeError { kind }
    }
}
53
54pub(crate) fn escape_string(value: &str) -> Cow<'_, str> {
56 let bytes = value.as_bytes();
57 for (index, byte) in bytes.iter().enumerate() {
58 match byte {
59 0..=0x1F | b'"' | b'\\' => {
60 return Cow::Owned(escape_string_inner(&bytes[0..index], &bytes[index..]))
61 }
62 _ => {}
63 }
64 }
65 Cow::Borrowed(value)
66}
67
/// Builds the JSON-escaped form of a string that was split before its first
/// byte requiring escaping.
///
/// `start` is copied verbatim. Each byte of `rest` is handled as follows:
/// - `"` and `\` are backslash-escaped;
/// - the control characters with short forms use them (`\b`, `\f`, `\n`,
///   `\r`, `\t`);
/// - any other byte below 0x20 is written as a lowercase `\u00xx` escape;
/// - all remaining bytes are copied unchanged.
///
/// Contract: `start` followed by `rest` must be the UTF-8 bytes of a `&str`
/// split on a character boundary, as `escape_string` does. The `unsafe`
/// conversion below relies on this.
fn escape_string_inner(start: &[u8], rest: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut escaped = Vec::with_capacity(start.len() + rest.len() + 1);
    escaped.extend_from_slice(start);

    for &byte in rest {
        match byte {
            b'"' => escaped.extend_from_slice(b"\\\""),
            b'\\' => escaped.extend_from_slice(b"\\\\"),
            0x08 => escaped.extend_from_slice(b"\\b"),
            0x0C => escaped.extend_from_slice(b"\\f"),
            b'\n' => escaped.extend_from_slice(b"\\n"),
            b'\r' => escaped.extend_from_slice(b"\\r"),
            b'\t' => escaped.extend_from_slice(b"\\t"),
            // Emit `\u00xx` directly rather than allocating a `String` via
            // `format!` for every control byte.
            0..=0x1F => escaped.extend_from_slice(&[
                b'\\',
                b'u',
                b'0',
                b'0',
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0F)],
            ]),
            _ => escaped.push(byte),
        }
    }

    debug_assert!(std::str::from_utf8(&escaped).is_ok());
    // SAFETY: by this function's contract, `start` + `rest` are valid UTF-8
    // split on a character boundary. Every byte is either copied unchanged or,
    // if it is ASCII (< 0x80), replaced by a sequence of ASCII bytes. ASCII
    // bytes never occur inside a multi-byte UTF-8 sequence, so no sequence is
    // broken and the output remains valid UTF-8.
    unsafe { String::from_utf8_unchecked(escaped) }
}
93
94pub(crate) fn unescape_string(value: &str) -> Result<Cow<'_, str>, EscapeError> {
97 let bytes = value.as_bytes();
98 for (index, byte) in bytes.iter().enumerate() {
99 if *byte == b'\\' {
100 return unescape_string_inner(&bytes[0..index], &bytes[index..]).map(Cow::Owned);
101 }
102 }
103 Ok(Cow::Borrowed(value))
104}
105
106fn unescape_string_inner(start: &[u8], rest: &[u8]) -> Result<String, EscapeError> {
107 let mut unescaped = Vec::with_capacity(start.len() + rest.len());
108 unescaped.extend(start);
109
110 let mut index = 0;
111 while index < rest.len() {
112 match rest[index] {
113 b'\\' => {
114 index += 1;
115 if index == rest.len() {
116 return Err(EscapeErrorKind::UnexpectedEndOfString.into());
117 }
118 match rest[index] {
119 b'u' => {
120 index -= 1;
121 index += read_unicode_escapes(&rest[index..], &mut unescaped)?;
122 }
123 byte => {
124 match byte {
125 b'\\' => unescaped.push(b'\\'),
126 b'/' => unescaped.push(b'/'),
127 b'"' => unescaped.push(b'"'),
128 b'b' => unescaped.push(0x08),
129 b'f' => unescaped.push(0x0C),
130 b'n' => unescaped.push(b'\n'),
131 b'r' => unescaped.push(b'\r'),
132 b't' => unescaped.push(b'\t'),
133 _ => {
134 return Err(
135 EscapeErrorKind::InvalidEscapeCharacter(byte.into()).into()
136 )
137 }
138 }
139 index += 1;
140 }
141 }
142 }
143 byte => {
144 unescaped.push(byte);
145 index += 1
146 }
147 }
148 }
149
150 String::from_utf8(unescaped).map_err(|_| EscapeErrorKind::InvalidUtf8.into())
151}
152
/// Returns `true` if `codepoint` is a UTF-16 low (trailing) surrogate,
/// i.e. lies in `0xDC00..=0xDFFF`.
fn is_utf16_low_surrogate(codepoint: u16) -> bool {
    (0xDC00..=0xDFFF).contains(&codepoint)
}
156
/// Returns `true` if `codepoint` is a UTF-16 high (leading) surrogate,
/// i.e. lies in `0xD800..=0xDBFF`.
fn is_utf16_high_surrogate(codepoint: u16) -> bool {
    (0xD800..=0xDBFF).contains(&codepoint)
}
160
161fn read_codepoint(rest: &[u8]) -> Result<u16, EscapeError> {
162 if rest.len() < 6 {
163 return Err(EscapeErrorKind::UnexpectedEndOfString.into());
164 }
165 if &rest[0..2] != b"\\u" {
166 return Err(EscapeErrorKind::ExpectedSurrogatePair(
169 String::from_utf8_lossy(&rest[0..6]).into(),
170 )
171 .into());
172 }
173
174 let codepoint_str =
175 std::str::from_utf8(&rest[2..6]).map_err(|_| EscapeErrorKind::InvalidUtf8)?;
176
177 if codepoint_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
179 return Err(EscapeErrorKind::InvalidUnicodeEscape(codepoint_str.into()).into());
180 }
181 Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value"))
182}
183
184fn read_unicode_escapes(bytes: &[u8], into: &mut Vec<u8>) -> Result<usize, EscapeError> {
187 let high = read_codepoint(bytes)?;
188 let (bytes_read, chr) = if is_utf16_high_surrogate(high) {
189 let low = read_codepoint(&bytes[6..])?;
190 if !is_utf16_low_surrogate(low) {
191 return Err(EscapeErrorKind::InvalidSurrogatePair(high, low).into());
192 }
193
194 let codepoint =
195 std::char::from_u32(0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32)
196 .ok_or(EscapeErrorKind::InvalidSurrogatePair(high, low))?;
197 (12, codepoint)
198 } else {
199 let codepoint = std::char::from_u32(high as u32).ok_or_else(|| {
200 EscapeErrorKind::InvalidUnicodeEscape(String::from_utf8_lossy(&bytes[0..6]).into())
201 })?;
202 (6, codepoint)
203 };
204
205 match chr.len_utf8() {
206 1 => into.push(chr as u8),
207 _ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()),
208 }
209 Ok(bytes_read)
210}
211
#[cfg(test)]
mod test {
    use super::escape_string;
    use crate::escape::{unescape_string, EscapeErrorKind};
    use std::borrow::Cow;

    // `escape_string` escapes quotes, backslashes, and control characters
    // (using short forms where JSON has them) and leaves other text untouched.
    #[test]
    fn escape() {
        assert_eq!("", escape_string("").as_ref());
        assert_eq!("foo", escape_string("foo").as_ref());
        assert_eq!("foo\\r\\n", escape_string("foo\r\n").as_ref());
        assert_eq!("foo\\r\\nbar", escape_string("foo\r\nbar").as_ref());
        assert_eq!(r"foo\\bar", escape_string(r"foo\bar").as_ref());
        assert_eq!(r"\\foobar", escape_string(r"\foobar").as_ref());
        assert_eq!(
            r"\bf\fo\to\r\n",
            escape_string("\u{08}f\u{0C}o\to\r\n").as_ref()
        );
        assert_eq!("\\\"test\\\"", escape_string("\"test\"").as_ref());
        // Control characters without a short form use lowercase `\u00xx`.
        assert_eq!("\\u0000", escape_string("\u{0}").as_ref());
        assert_eq!("\\u001f", escape_string("\u{1f}").as_ref());
    }

    // Input without any backslash is returned borrowed rather than copied.
    #[test]
    fn unescape_no_escapes() {
        let unescaped = unescape_string("test test").unwrap();
        assert_eq!("test test", unescaped);
        assert!(matches!(unescaped, Cow::Borrowed(_)));
    }

    #[test]
    fn unescape() {
        // Valid escapes, including a UTF-16 surrogate pair.
        assert_eq!(
            "\x08f\x0Co\to\r\n",
            unescape_string(r"\bf\fo\to\r\n").unwrap()
        );
        assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap());
        assert_eq!("\x00", unescape_string("\\u0000").unwrap());
        assert_eq!("\x1f", unescape_string("\\u001f").unwrap());
        assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap());
        assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap());
        assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap());
        assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap());

        // Input truncated inside an escape sequence.
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u00")
        );
        // Unknown escape character.
        assert_eq!(
            Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()),
            unescape_string("\\z")
        );

        // Malformed surrogate pairs: missing, truncated, or wrong low half.
        assert_eq!(
            Err(EscapeErrorKind::ExpectedSurrogatePair("\\nasdf".into()).into()),
            unescape_string("\\uD801\\nasdf")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\uD801\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidSurrogatePair(0xD801, 0xC501).into()),
            unescape_string("\\uD801\\uC501")
        );

        // A sign character is not a hex digit.
        assert_eq!(
            Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()),
            unescape_string("\\u+04D")
        );
    }

    use proptest::proptest;
    proptest! {
        // Escaping agrees with serde_json's string escaping; serde_json's
        // output is stripped of its surrounding quotes before comparing.
        #[test]
        fn matches_serde_json(s in ".*") {
            let serde_escaped = serde_json::to_string(&s).unwrap();
            let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
            assert_eq!(serde_escaped, escape_string(&s))
        }

        // Any single `char` survives an escape/unescape round trip.
        #[test]
        fn round_trip(chr in proptest::char::any()) {
            let mut original = String::new();
            original.push(chr);

            let escaped = escape_string(&original);
            let unescaped = unescape_string(&escaped).unwrap();
            assert_eq!(original, unescaped);
        }

        // Every supplementary-plane character, written as an uppercase-hex
        // `\uXXXX\uXXXX` surrogate pair, unescapes to that character.
        #[test]
        fn unicode_surrogates(chr in proptest::char::range(
            std::char::from_u32(0x10000).unwrap(),
            std::char::from_u32(0x10FFFF).unwrap(),
        )) {
            let mut codepoints = [0; 2];
            chr.encode_utf16(&mut codepoints);

            let escaped = format!("\\u{:04X}\\u{:04X}", codepoints[0], codepoints[1]);
            let unescaped = unescape_string(&escaped).unwrap();

            let expected = format!("{chr}");
            assert_eq!(expected, unescaped);
        }
    }

    // Exhaustive comparison against serde_json for every valid `char`. It
    // walks all of `0..u32::MAX`, presumably too slow for regular runs, so it
    // is `#[ignore]`d; run it explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn all_codepoints() {
        for value in 0..u32::MAX {
            if let Some(chr) = char::from_u32(value) {
                let string = String::from(chr);
                let escaped = escape_string(&string);
                let serde_escaped = serde_json::to_string(&string).unwrap();
                let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
                assert_eq!(&escaped, serde_escaped);
            }
        }
    }
}