1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | #![cfg(all(feature = "client", feature = "rt-tokio", not(target_family = "wasm")))]
|
7 + |
|
8 + | use aws_smithy_runtime::client::dns::CachingDnsResolver;
|
9 + | use aws_smithy_runtime_api::client::dns::ResolveDns;
|
10 + | use std::{
|
11 + | net::{IpAddr, Ipv4Addr},
|
12 + | time::Duration,
|
13 + | };
|
14 + | use tokio::test;
|
15 + |
|
/// A repeated lookup of the same hostname is answered from the resolver's cache instead of
/// sending a second query to the DNS server.
#[test]
async fn test_dns_caching() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .cache_size(1)
        .build();

    let hostname = "example.com";

    // First resolution should hit the server
    let first_ips = resolver
        .resolve_dns(hostname)
        .await
        .expect("first resolution of example.com should succeed");

    // Verify correct IP returned and server hit to get it
    assert_eq!(vec![IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))], first_ips);
    assert_eq!(
        1,
        dns_server.query_count(),
        "first resolution should query the server"
    );

    // Second resolution should hit the cache
    let second_ips = resolver
        .resolve_dns(hostname)
        .await
        .expect("second resolution of example.com should succeed");

    // Verify second resolution hit cache, not server, and returned the same IPs
    assert_eq!(
        1,
        dns_server.query_count(),
        "second resolution should be served from the cache"
    );
    assert_eq!(first_ips, second_ips);
}
|
46 + |
|
/// With a cache capacity of one entry, resolving a new hostname displaces the previously cached
/// one, so only the most recently resolved hostname is answered without querying the server.
#[test]
async fn test_dns_cache_size_limit() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .cache_size(1)
        .build();

    // (hostname, expected IPv4 answer, expected total server queries after the lookup)
    let steps = [
        // Not cached yet: hits the server
        ("example.com", [1, 2, 3, 4], 1),
        // Not cached: hits the server and displaces example.com
        ("aws.com", [5, 6, 7, 8], 2),
        // Not cached: hits the server and displaces aws.com
        ("foo.com", [9, 10, 11, 12], 3),
        // Most recent entry: served from the cache, server not hit
        ("foo.com", [9, 10, 11, 12], 3),
        // example.com was removed from the cache, so the server is hit again
        ("example.com", [1, 2, 3, 4], 4),
    ];

    for (hostname, octets, expected_queries) in steps {
        let ips = resolver
            .resolve_dns(hostname)
            .await
            .unwrap_or_else(|err| panic!("resolving {hostname} failed: {err:?}"));
        assert_eq!(
            vec![IpAddr::V4(Ipv4Addr::from(octets))],
            ips,
            "unexpected answer for {hostname}"
        );
        assert_eq!(
            expected_queries,
            dns_server.query_count(),
            "unexpected server query count after resolving {hostname}"
        );
    }
}
|
87 + |
|
/// Looking up a hostname the test server has no record for must produce an error rather than
/// an address.
#[test]
async fn test_dns_error_handling() {
    let server = test_dns_server::setup_dns_server().await;
    let (nameserver_ip, nameserver_port) = server.addr();

    // A 100ms timeout with a single attempt bounds how long the failing lookup can take.
    let resolver = CachingDnsResolver::builder()
        .nameservers(&[nameserver_ip], nameserver_port)
        .timeout(Duration::from_millis(100))
        .attempts(1)
        .build();

    // The test server has no record for this name. Its simulated latency (1s) also exceeds the
    // timeout configured above, so the lookup is expected to fail either way.
    let outcome = resolver
        .resolve_dns("invalid.nonexistent.domain.test")
        .await;

    assert!(outcome.is_err());
}
|
105 + |
|
// Minimal local DNS server for these tests. It speaks just enough of the DNS wire format
// (RFC 1035) to answer single-question queries from a fixed table of address records; it is
// not a general-purpose DNS implementation.
#[cfg(test)]
mod test_dns_server {
    use std::{
        collections::HashMap,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };
    use tokio::{net::UdpSocket, sync::Notify, task::JoinHandle};

    /// Delay applied before every reply, to simulate network latency.
    const SIMULATED_LATENCY: Duration = Duration::from_millis(1000);
    /// TTL, in seconds, advertised on every answer record.
    const ANSWER_TTL_SECS: u32 = 300;
    /// Length of the fixed DNS message header.
    const HEADER_LEN: usize = 12;

    /// Record type code for an IPv4 host address (A).
    const TYPE_A: u16 = 1;
    /// Record type code for an IPv6 host address (AAAA, RFC 3596).
    const TYPE_AAAA: u16 = 28;
    /// Class code for the Internet (IN).
    const CLASS_IN: u16 = 1;

    /// Header flag: message is a response (QR).
    const FLAG_QR: u16 = 0x8000;
    /// Header flag: recursion desired (RD); copied from the query into the response.
    const FLAG_RD: u16 = 0x0100;
    /// Header flag: recursion available (RA).
    const FLAG_RA: u16 = 0x0080;

    /// Response code: no error.
    const RCODE_NOERROR: u16 = 0;
    /// Response code: the queried name does not exist (NXDOMAIN).
    const RCODE_NXDOMAIN: u16 = 3;

    /// Starts a server on `127.0.0.1` preloaded with the records the tests rely on:
    /// `example.com` -> 1.2.3.4, `aws.com` -> 5.6.7.8, `foo.com` -> 9.10.11.12.
    ///
    /// # Panics
    /// Panics if the server cannot be started (e.g. the UDP socket cannot be bound).
    pub async fn setup_dns_server() -> TestDnsServer {
        let records = HashMap::from([
            (
                "example.com".to_string(),
                IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
            ),
            ("aws.com".to_string(), IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            (
                "foo.com".to_string(),
                IpAddr::V4(Ipv4Addr::new(9, 10, 11, 12)),
            ),
        ]);

        TestDnsServer::start(records)
            .await
            .expect("failed to start the test DNS server")
    }

    /// Handle to a running test DNS server. The background task is stopped when this is dropped.
    pub struct TestDnsServer {
        handle: JoinHandle<()>,
        addr: SocketAddr,
        shutdown: Arc<Notify>,
        query_count: Arc<AtomicUsize>,
    }

    impl TestDnsServer {
        /// Binds a UDP socket on `127.0.0.1` (OS-assigned port) and spawns a task answering
        /// queries from `records`, whose keys are matched case-insensitively.
        ///
        /// Datagrams are handled one at a time: each is counted, delayed by
        /// [`SIMULATED_LATENCY`], then answered. Known names get their record (when the
        /// queried type matches the address family), unknown names get NXDOMAIN, and
        /// datagrams that do not parse as a query are dropped without a reply.
        ///
        /// # Errors
        /// Returns an error if the socket cannot be bound or its local address cannot be read.
        pub async fn start(
            records: HashMap<String, IpAddr>,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            // localhost, random port
            let socket = UdpSocket::bind("127.0.0.1:0").await?;
            let addr = socket.local_addr()?;

            // Lookups lowercase the queried name, so store keys lowercased as well.
            let records: HashMap<String, IpAddr> = records
                .into_iter()
                .map(|(name, ip)| (name.to_ascii_lowercase(), ip))
                .collect();

            let shutdown = Arc::new(Notify::new());
            let query_count = Arc::new(AtomicUsize::new(0));

            let handle = tokio::spawn({
                let shutdown = Arc::clone(&shutdown);
                let query_count = Arc::clone(&query_count);
                async move {
                    // 512 bytes is the classic size limit for DNS messages over UDP (RFC 1035).
                    let mut buf = [0u8; 512];
                    loop {
                        tokio::select! {
                            _ = shutdown.notified() => break,
                            result = socket.recv_from(&mut buf) => {
                                // Receive errors are ignored; the loop waits for the next datagram.
                                if let Ok((len, src)) = result {
                                    // A standalone counter: nothing else is published through it.
                                    query_count.fetch_add(1, Ordering::Relaxed);
                                    tokio::time::sleep(SIMULATED_LATENCY).await;
                                    if let Some(response) = create_dns_response(&buf[..len], &records) {
                                        // Best effort: a failed send looks like packet loss to the client.
                                        let _ = socket.send_to(&response, src).await;
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(TestDnsServer {
                handle,
                addr,
                shutdown,
                query_count,
            })
        }

        /// Returns the IP address and UDP port the server is listening on.
        pub fn addr(&self) -> (IpAddr, u16) {
            (self.addr.ip(), self.addr.port())
        }

        /// Number of datagrams received so far. Datagrams are counted on receipt (before the
        /// simulated latency), including ones that are not answered.
        pub fn query_count(&self) -> usize {
            self.query_count.load(Ordering::Relaxed)
        }
    }

    impl Drop for TestDnsServer {
        fn drop(&mut self) {
            // Ask the loop to exit, and abort in case it is mid-query (sleeping or sending),
            // when the shutdown notification is not being awaited.
            self.shutdown.notify_one();
            self.handle.abort();
        }
    }

    /// Record type code that carries `ip`: A for IPv4, AAAA for IPv6.
    fn record_type(ip: IpAddr) -> u16 {
        match ip {
            IpAddr::V4(_) => TYPE_A,
            IpAddr::V6(_) => TYPE_AAAA,
        }
    }

    /// Builds the wire-format reply to `query`, or `None` if `query` does not parse as a DNS
    /// query (no reply is sent in that case).
    ///
    /// Unknown names are answered with NXDOMAIN. Known names are answered with NOERROR; the
    /// answer section holds the record only when its type matches the queried type, otherwise
    /// it is empty (the name exists but has no data of that type).
    fn create_dns_response(query: &[u8], records: &HashMap<String, IpAddr>) -> Option<Vec<u8>> {
        let parsed = DnsQuery::parse(query)?;

        // DNS names are case-insensitive; record keys are stored lowercased by `start`.
        let (rcode, answer_ip) = match records.get(&parsed.domain.to_ascii_lowercase()) {
            None => (RCODE_NXDOMAIN, None),
            Some(&ip) => (
                RCODE_NOERROR,
                (record_type(ip) == parsed.query_type).then_some(ip),
            ),
        };

        let response = DnsResponse {
            id: parsed.id,
            // With RD set in the query and NOERROR this is 0x8180, the standard response flags.
            flags: FLAG_QR | FLAG_RA | (parsed.flags & FLAG_RD) | rcode,
            question: parsed.raw_question,
            answer_ip,
            ttl: ANSWER_TTL_SECS,
        };

        Some(response.to_bytes())
    }

    /// Header and first question of an incoming DNS query.
    #[derive(Debug)]
    // `question_count` and `query_class` are only kept for `Debug` output.
    #[allow(dead_code)]
    struct DnsQuery {
        id: u16,
        flags: u16,
        question_count: u16,
        /// Question name as dot-separated labels, without a trailing dot.
        domain: String,
        query_type: u16,
        query_class: u16,
        /// The first question exactly as received (encoded name, QTYPE, QCLASS), so replies can
        /// echo it byte-for-byte.
        raw_question: Vec<u8>,
    }

    impl DnsQuery {
        /// Parses the header and first question of a DNS message.
        ///
        /// Returns `None` if the message is truncated or declares no questions. Name
        /// compression pointers are not supported in the question name.
        fn parse(data: &[u8]) -> Option<Self> {
            let header = data.get(..HEADER_LEN)?;
            let id = u16::from_be_bytes([header[0], header[1]]);
            let flags = u16::from_be_bytes([header[2], header[3]]);
            let question_count = u16::from_be_bytes([header[4], header[5]]);

            if question_count == 0 {
                return None;
            }

            // The question name is a run of length-prefixed labels ending in a zero byte.
            let mut pos = HEADER_LEN;
            let mut labels = Vec::new();
            loop {
                let len = usize::from(*data.get(pos)?);
                pos += 1;
                if len == 0 {
                    break;
                }
                labels.push(String::from_utf8_lossy(data.get(pos..pos + len)?));
                pos += len;
            }

            // QTYPE and QCLASS follow the name.
            let question_end = pos + 4;
            let type_and_class = data.get(pos..question_end)?;

            Some(DnsQuery {
                id,
                flags,
                question_count,
                domain: labels.join("."),
                query_type: u16::from_be_bytes([type_and_class[0], type_and_class[1]]),
                query_class: u16::from_be_bytes([type_and_class[2], type_and_class[3]]),
                raw_question: data[HEADER_LEN..question_end].to_vec(),
            })
        }
    }

    /// A DNS response with one question and at most one answer.
    #[derive(Debug)]
    struct DnsResponse {
        id: u16,
        flags: u16,
        /// Question section to echo, already in wire format.
        question: Vec<u8>,
        /// Address for the answer section; `None` leaves the answer section empty.
        answer_ip: Option<IpAddr>,
        ttl: u32,
    }

    impl DnsResponse {
        /// Serializes the response in DNS wire format.
        fn to_bytes(&self) -> Vec<u8> {
            // Header + echoed question + largest answer (12 fixed bytes + 16-byte IPv6 address).
            let mut response = Vec::with_capacity(HEADER_LEN + self.question.len() + 12 + 16);

            // Header (12 bytes)
            let answer_count = u16::from(self.answer_ip.is_some());
            response.extend_from_slice(&self.id.to_be_bytes());
            response.extend_from_slice(&self.flags.to_be_bytes());
            response.extend_from_slice(&1u16.to_be_bytes()); // Questions: 1
            response.extend_from_slice(&answer_count.to_be_bytes()); // Answers: 0 or 1
            response.extend_from_slice(&0u16.to_be_bytes()); // Authority: 0
            response.extend_from_slice(&0u16.to_be_bytes()); // Additional: 0

            // Question section, echoed exactly as the query sent it
            response.extend_from_slice(&self.question);

            // Answer section
            if let Some(ip) = self.answer_ip {
                // Compression pointer to the question name at offset 12, right after the header.
                response.extend_from_slice(&[0xc0, 0x0c]);
                response.extend_from_slice(&record_type(ip).to_be_bytes());
                response.extend_from_slice(&CLASS_IN.to_be_bytes());
                response.extend_from_slice(&self.ttl.to_be_bytes());
                match ip {
                    IpAddr::V4(v4) => {
                        response.extend_from_slice(&4u16.to_be_bytes()); // Data length
                        response.extend_from_slice(&v4.octets());
                    }
                    IpAddr::V6(v6) => {
                        response.extend_from_slice(&16u16.to_be_bytes()); // Data length
                        response.extend_from_slice(&v6.octets());
                    }
                }
            }

            response
        }
    }
}
|