1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | use crate::query::fmt_string as percent_encode_query;
|
7 + | use http_02x::uri::InvalidUri;
|
8 + | use http_02x::Uri;
|
9 + |
|
10 + | /// Utility for updating the query string in a [`Uri`].
|
11 + | #[allow(missing_debug_implementations)]
|
12 + | pub struct QueryWriter {
|
13 + | base_uri: Uri,
|
14 + | new_path_and_query: String,
|
15 + | prefix: Option<char>,
|
16 + | }
|
17 + |
|
18 + | impl QueryWriter {
|
19 + | /// Creates a new `QueryWriter` from a string
|
20 + | pub fn new_from_string(uri: &str) -> Result<Self, InvalidUri> {
|
21 + | Ok(Self::new(&Uri::try_from(uri)?))
|
22 + | }
|
23 + |
|
24 + | /// Creates a new `QueryWriter` based off the given `uri`.
|
25 + | pub fn new(uri: &Uri) -> Self {
|
26 + | let new_path_and_query = uri
|
27 + | .path_and_query()
|
28 + | .map(|pq| pq.to_string())
|
29 + | .unwrap_or_default();
|
30 + | let prefix = if uri.query().is_none() {
|
31 + | Some('?')
|
32 + | } else if !uri.query().unwrap_or_default().is_empty() {
|
33 + | Some('&')
|
34 + | } else {
|
35 + | None
|
36 + | };
|
37 + | QueryWriter {
|
38 + | base_uri: uri.clone(),
|
39 + | new_path_and_query,
|
40 + | prefix,
|
41 + | }
|
42 + | }
|
43 + |
|
44 + | /// Clears all query parameters.
|
45 + | pub fn clear_params(&mut self) {
|
46 + | if let Some(index) = self.new_path_and_query.find('?') {
|
47 + | self.new_path_and_query.truncate(index);
|
48 + | self.prefix = Some('?');
|
49 + | }
|
50 + | }
|
51 + |
|
52 + | /// Inserts a new query parameter. The key and value are percent encoded
|
53 + | /// by `QueryWriter`. Passing in percent encoded values will result in double encoding.
|
54 + | pub fn insert(&mut self, k: &str, v: &str) {
|
55 + | self.insert_encoded(&percent_encode_query(k), &percent_encode_query(v));
|
56 + | }
|
57 + |
|
58 + | /// Inserts a new already encoded query parameter. The key and value will be inserted
|
59 + | /// as is.
|
60 + | pub fn insert_encoded(&mut self, encoded_k: &str, encoded_v: &str) {
|
61 + | if let Some(prefix) = self.prefix {
|
62 + | self.new_path_and_query.push(prefix);
|
63 + | }
|
64 + | self.prefix = Some('&');
|
65 + | self.new_path_and_query.push_str(encoded_k);
|
66 + | self.new_path_and_query.push('=');
|
67 + | self.new_path_and_query.push_str(encoded_v)
|
68 + | }
|
69 + |
|
70 + | /// Returns just the built query string.
|
71 + | pub fn build_query(self) -> String {
|
72 + | self.build_uri().query().unwrap_or_default().to_string()
|
73 + | }
|
74 + |
|
75 + | /// Returns a full [`Uri`] with the query string updated.
|
76 + | pub fn build_uri(self) -> Uri {
|
77 + | let mut parts = self.base_uri.into_parts();
|
78 + | parts.path_and_query = Some(
|
79 + | self.new_path_and_query
|
80 + | .parse()
|
81 + | .expect("adding query should not invalidate URI"),
|
82 + | );
|
83 + | Uri::from_parts(parts).expect("a valid URL in should always produce a valid URL out")
|
84 + | }
|
85 + | }
|
86 + |
|
#[cfg(test)]
mod test {
    use super::QueryWriter;
    use http_02x::Uri;

    #[test]
    fn empty_uri() {
        let uri = Uri::from_static("http://www.example.com");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com?key=val%25ue&another=value")
        );
    }

    #[test]
    fn uri_with_path() {
        let uri = Uri::from_static("http://www.example.com/path");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?key=val%25ue&another=value")
        );
    }

    #[test]
    fn uri_with_path_and_query() {
        let uri = Uri::from_static("http://www.example.com/path?original=here");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static(
                "http://www.example.com/path?original=here&key=val%25ue&another=value"
            )
        );
    }

    #[test]
    // A bare trailing `?` must be reused, not doubled, when the first parameter is added.
    fn uri_with_empty_query() {
        let uri = Uri::from_static("http://www.example.com/path?");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?key=val%25ue&another=value")
        );
    }

    #[test]
    fn new_from_string() {
        let mut query_writer =
            QueryWriter::new_from_string("http://www.example.com/path?a=b").unwrap();
        query_writer.insert("key", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?a=b&key=value")
        );
        assert!(QueryWriter::new_from_string("not a uri").is_err());
    }

    #[test]
    fn build_query() {
        let uri = Uri::from_static("http://www.example.com");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("ano%ther", "value");
        assert_eq!("key=val%25ue&ano%25ther=value", query_writer.build_query());
    }

    #[test]
    // Already-encoded input must be written verbatim rather than encoded a second time.
    fn insert_encoded_is_not_re_encoded() {
        let uri = Uri::from_static("http://www.example.com");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert_encoded("key", "val%25ue");
        assert_eq!("key=val%25ue", query_writer.build_query());
    }

    #[test]
    // This test ensures that the percent encoding applied to queries always produces a valid URI if
    // the starting URI is valid. It sweeps every scalar value up to U+07FF, so both ASCII and
    // multi-byte UTF-8 input are covered.
    fn doesnt_panic_when_adding_query_to_valid_uri() {
        let uri = Uri::from_static("http://www.example.com");

        let problematic_chars: Vec<char> = (0u32..=0x7FF)
            .filter_map(char::from_u32)
            .filter(|c| {
                let mut buf = [0u8; 4];
                let mut query_writer = QueryWriter::new(&uri);
                query_writer.insert("key", c.encode_utf8(&mut buf));
                std::panic::catch_unwind(|| query_writer.build_uri()).is_err()
            })
            .collect();

        if !problematic_chars.is_empty() {
            panic!("we got some bad bytes here: {problematic_chars:#?}")
        }
    }

    #[test]
    fn clear_params() {
        let uri = Uri::from_static("http://www.example.com/path?original=here&foo=1");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.clear_params();
        query_writer.insert("new", "value");
        assert_eq!("new=value", query_writer.build_query());
    }

    #[test]
    fn clear_params_without_existing_query() {
        let uri = Uri::from_static("http://www.example.com/path");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.clear_params();
        query_writer.insert("new", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?new=value")
        );
    }

    #[test]
    fn clear_params_with_nothing_inserted() {
        let uri = Uri::from_static("http://www.example.com/path?original=here");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.clear_params();
        assert_eq!("", query_writer.build_query());
    }
}
|