1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | #![cfg(all(
|
7 + | feature = "client",
|
8 + | feature = "hickory-dns",
|
9 + | not(target_family = "wasm")
|
10 + | ))]
|
11 + |
|
12 + | use aws_smithy_runtime::client::dns::CachingDnsResolver;
|
13 + | use aws_smithy_runtime_api::client::dns::ResolveDns;
|
14 + | use std::{
|
15 + | net::{IpAddr, Ipv4Addr},
|
16 + | time::Duration,
|
17 + | };
|
18 + | use tokio::test;
|
19 + |
|
/// Resolving the same hostname twice must only reach the DNS server once: the
/// second lookup is expected to be served from the resolver's cache and to
/// return exactly the same addresses as the first.
#[test]
async fn test_dns_caching() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .cache_size(1)
        .build();

    let hostname = "example.com";

    // First resolution should hit the server.
    let first_ips = resolver
        .resolve_dns(hostname)
        .await
        .expect("first resolution of example.com should succeed");

    // Verify the correct IP was returned and the server was queried to get it.
    assert_eq!(vec![IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))], first_ips);
    assert_eq!(
        dns_server.query_count(),
        1,
        "first resolution should query the server exactly once"
    );

    // Second resolution should hit the cache.
    let second_ips = resolver
        .resolve_dns(hostname)
        .await
        .expect("second resolution of example.com should succeed");

    // Verify the second resolution was answered from the cache, not the server...
    assert_eq!(
        dns_server.query_count(),
        1,
        "second resolution should be served from the cache"
    );
    // ...and that the cached answer matches the original one.
    assert_eq!(first_ips, second_ips);
}
|
50 + |
|
/// With a cache capacity of one entry, each new hostname displaces the previous
/// one. The test walks through a sequence of lookups and checks, after every
/// step, both the returned address and how many queries the server has seen.
#[test]
async fn test_dns_cache_size_limit() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .cache_size(1)
        .build();

    // (hostname, expected address, expected cumulative server query count after the lookup)
    let steps = [
        // Cache is empty: must hit the server.
        ("example.com", Ipv4Addr::new(1, 2, 3, 4), 1),
        // Different name: must hit the server (and displaces example.com).
        ("aws.com", Ipv4Addr::new(5, 6, 7, 8), 2),
        // Different name: must hit the server (and displaces aws.com).
        ("foo.com", Ipv4Addr::new(9, 10, 11, 12), 3),
        // foo.com is the single cached entry: served from the cache.
        ("foo.com", Ipv4Addr::new(9, 10, 11, 12), 3),
        // example.com was displaced from the one-entry cache: must hit the server again.
        ("example.com", Ipv4Addr::new(1, 2, 3, 4), 4),
    ];

    for (hostname, expected_ip, expected_queries) in steps {
        let ips = resolver
            .resolve_dns(hostname)
            .await
            .unwrap_or_else(|err| panic!("failed to resolve {hostname}: {err:?}"));

        assert_eq!(
            vec![IpAddr::V4(expected_ip)],
            ips,
            "unexpected addresses for {hostname}"
        );
        assert_eq!(
            dns_server.query_count(),
            expected_queries,
            "unexpected server query count after resolving {hostname}"
        );
    }
}
|
91 + |
|
/// Resolving a name the test server has no record for must surface an error.
///
/// The resolver is configured with a 100ms timeout and a single attempt, which
/// is shorter than the test server's artificial response delay, so the lookup
/// cannot succeed.
#[test]
async fn test_dns_error_handling() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .timeout(Duration::from_millis(100))
        .attempts(1)
        .build();

    // Try to resolve a hostname that is not in the server's record table.
    let result = resolver
        .resolve_dns("invalid.nonexistent.domain.test")
        .await;
    assert!(
        result.is_err(),
        "expected resolution of an unknown name to fail, got {result:?}"
    );
}
|
109 + |
|
/// Minimal, intentionally simplistic DNS server for the tests in this file.
///
/// It speaks just enough of the RFC 1035 UDP wire format to answer a single
/// question from a fixed in-memory record table. It is not a general-purpose
/// DNS implementation.
#[cfg(test)]
mod test_dns_server {
    use std::{
        collections::HashMap,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };
    use tokio::{net::UdpSocket, sync::Notify, task::JoinHandle};

    /// Delay applied before every response to simulate network latency.
    const SIMULATED_LATENCY: Duration = Duration::from_millis(1000);
    /// Largest DNS message carried over plain UDP (RFC 1035 section 4.2.1).
    const MAX_UDP_MESSAGE_LEN: usize = 512;
    /// Size of the fixed DNS message header.
    const HEADER_LEN: usize = 12;
    /// Longest legal label. Larger length bytes are compression pointers or reserved.
    const MAX_LABEL_LEN: usize = 63;
    const TYPE_A: u16 = 1;
    const TYPE_AAAA: u16 = 28;
    const CLASS_IN: u16 = 1;
    /// QR=1 (response), RD=1, RA=1, RCODE=0 (NOERROR).
    const FLAGS_NOERROR: u16 = 0x8180;
    /// QR=1 (response), RD=1, RA=1, RCODE=3 (NXDOMAIN).
    const FLAGS_NXDOMAIN: u16 = 0x8183;
    /// TTL, in seconds, attached to every answer record.
    const ANSWER_TTL: u32 = 300;

    /// Starts a [`TestDnsServer`] preloaded with the records the tests rely on:
    /// `example.com -> 1.2.3.4`, `aws.com -> 5.6.7.8`, `foo.com -> 9.10.11.12`.
    ///
    /// # Panics
    /// Panics if the server fails to start (e.g. the UDP socket cannot be bound).
    pub async fn setup_dns_server() -> TestDnsServer {
        let records = [
            ("example.com", Ipv4Addr::new(1, 2, 3, 4)),
            ("aws.com", Ipv4Addr::new(5, 6, 7, 8)),
            ("foo.com", Ipv4Addr::new(9, 10, 11, 12)),
        ]
        .into_iter()
        .map(|(name, ip)| (name.to_string(), IpAddr::V4(ip)))
        .collect();

        TestDnsServer::start(records)
            .await
            .expect("failed to start test DNS server")
    }

    /// Handle to a running test DNS server. The server task is stopped when
    /// this handle is dropped.
    pub struct TestDnsServer {
        handle: JoinHandle<()>,
        addr: SocketAddr,
        shutdown: Arc<Notify>,
        query_count: Arc<AtomicUsize>,
    }

    impl TestDnsServer {
        /// Binds a UDP socket on `127.0.0.1` with an OS-assigned port and spawns
        /// a task that answers queries from `records`.
        ///
        /// Every received datagram increments the query counter, including
        /// ones that cannot be parsed. Each reply is delayed by
        /// [`SIMULATED_LATENCY`].
        ///
        /// # Errors
        /// Returns an error if the socket cannot be bound or its local address
        /// cannot be read.
        ///
        /// # Panics
        /// Panics if called outside a Tokio runtime (required by `tokio::spawn`).
        pub async fn start(
            records: HashMap<String, IpAddr>,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            // localhost, random port
            let socket = UdpSocket::bind("127.0.0.1:0").await?;
            let addr = socket.local_addr()?;

            let shutdown = Arc::new(Notify::new());
            let query_count = Arc::new(AtomicUsize::new(0));

            let handle = tokio::spawn({
                let shutdown = Arc::clone(&shutdown);
                let query_count = Arc::clone(&query_count);
                async move {
                    let mut buf = [0u8; MAX_UDP_MESSAGE_LEN];
                    loop {
                        tokio::select! {
                            _ = shutdown.notified() => break,
                            received = socket.recv_from(&mut buf) => {
                                // Receive errors are ignored; the loop keeps serving.
                                if let Ok((len, src)) = received {
                                    query_count.fetch_add(1, Ordering::Relaxed);
                                    tokio::time::sleep(SIMULATED_LATENCY).await;
                                    // Unparseable queries get no reply at all.
                                    if let Some(response) =
                                        create_dns_response(&buf[..len], &records)
                                    {
                                        // Best effort: a failed send looks like packet
                                        // loss to the client.
                                        let _ = socket.send_to(&response, src).await;
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(TestDnsServer {
                handle,
                addr,
                shutdown,
                query_count,
            })
        }

        /// Returns the server's IP address and port.
        pub fn addr(&self) -> (IpAddr, u16) {
            (self.addr.ip(), self.addr.port())
        }

        /// Returns how many datagrams the server has received so far.
        pub fn query_count(&self) -> usize {
            self.query_count.load(Ordering::Relaxed)
        }
    }

    impl Drop for TestDnsServer {
        fn drop(&mut self) {
            // Signal the loop to exit, then abort in case it is mid-sleep or mid-send.
            self.shutdown.notify_one();
            self.handle.abort();
        }
    }

    /// Builds the reply for a raw DNS query.
    ///
    /// - Returns `None` when the query cannot be parsed.
    /// - Names absent from `records` get NXDOMAIN with no answers.
    /// - Known names get NOERROR. The answer record is included only when the
    ///   requested type/class matches the stored address (A for IPv4, AAAA for
    ///   IPv6, class IN); otherwise the answer section is empty.
    ///
    /// The name lookup ignores ASCII case because DNS names are case-insensitive.
    fn create_dns_response(query: &[u8], records: &HashMap<String, IpAddr>) -> Option<Vec<u8>> {
        let query = DnsQuery::parse(query)?;

        let record = records
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(&query.domain))
            .map(|(_, ip)| *ip);

        let (flags, answer_ip) = match record {
            None => (FLAGS_NXDOMAIN, None),
            Some(ip) => {
                let type_matches = query.query_class == CLASS_IN
                    && match ip {
                        IpAddr::V4(_) => query.query_type == TYPE_A,
                        IpAddr::V6(_) => query.query_type == TYPE_AAAA,
                    };
                (FLAGS_NOERROR, type_matches.then_some(ip))
            }
        };

        let response = DnsResponse {
            id: query.id,
            flags,
            question: query.domain,
            query_type: query.query_type,
            query_class: query.query_class,
            answer_ip,
            ttl: ANSWER_TTL,
        };

        Some(response.to_bytes())
    }

    /// The header fields and first question of an incoming DNS query.
    #[derive(Debug)]
    // `flags` and `question_count` are parsed for completeness but never read.
    #[allow(dead_code)]
    struct DnsQuery {
        id: u16,
        flags: u16,
        question_count: u16,
        domain: String,
        query_type: u16,
        query_class: u16,
    }

    impl DnsQuery {
        /// Parses the header and the first question of a DNS query.
        ///
        /// Returns `None` if this server does not understand the message:
        /// - the message is shorter than a header;
        /// - it contains zero questions;
        /// - the name is truncated or compressed;
        /// - a label is not valid UTF-8;
        /// - QTYPE/QCLASS are missing.
        fn parse(data: &[u8]) -> Option<Self> {
            // Bounds-checked big-endian u16 read.
            let read_u16 = |at: usize| -> Option<u16> {
                Some(u16::from_be_bytes([*data.get(at)?, *data.get(at + 1)?]))
            };

            if data.len() < HEADER_LEN {
                return None;
            }

            let id = read_u16(0)?;
            let flags = read_u16(2)?;
            let question_count = read_u16(4)?;

            if question_count == 0 {
                return None;
            }

            // The question's name starts right after the header as a sequence of
            // length-prefixed labels terminated by a zero byte.
            let mut pos = HEADER_LEN;
            let mut labels = Vec::new();
            loop {
                let len = usize::from(*data.get(pos)?);
                pos += 1;
                if len == 0 {
                    break;
                }
                // Length bytes above 63 are compression pointers or reserved values.
                // Refuse them rather than misreading them as label lengths.
                if len > MAX_LABEL_LEN {
                    return None;
                }
                let label = data.get(pos..pos + len)?;
                labels.push(std::str::from_utf8(label).ok()?);
                pos += len;
            }

            let query_type = read_u16(pos)?;
            let query_class = read_u16(pos + 2)?;

            Some(DnsQuery {
                id,
                flags,
                question_count,
                domain: labels.join("."),
                query_type,
                query_class,
            })
        }
    }

    /// A single-question DNS response with at most one answer record.
    #[derive(Debug)]
    struct DnsResponse {
        id: u16,
        flags: u16,
        /// Dotted name echoed in the question section.
        question: String,
        /// QTYPE echoed from the query.
        query_type: u16,
        /// QCLASS echoed from the query.
        query_class: u16,
        /// Address for the single answer record; `None` means an empty answer section.
        answer_ip: Option<IpAddr>,
        ttl: u32,
    }

    impl DnsResponse {
        /// Serializes the response into DNS wire format.
        fn to_bytes(&self) -> Vec<u8> {
            // Breakdown of the buffer size:
            // - header: 12 bytes;
            // - question name: labels + length bytes + root, at most question.len() + 2;
            // - QTYPE/QCLASS: 4 bytes;
            // - answer: 12 bytes of fixed fields + up to 16 bytes of RDATA.
            let mut out = Vec::with_capacity(HEADER_LEN + self.question.len() + 2 + 4 + 12 + 16);
            let answer_count = u16::from(self.answer_ip.is_some());

            // Header
            out.extend_from_slice(&self.id.to_be_bytes());
            out.extend_from_slice(&self.flags.to_be_bytes());
            out.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT
            out.extend_from_slice(&answer_count.to_be_bytes()); // ANCOUNT
            out.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT
            out.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT

            // Question section, echoing the query's name, type and class.
            // Empty pieces are skipped so they cannot emit a premature terminator.
            for label in self.question.split('.').filter(|label| !label.is_empty()) {
                // Labels originate from parsed wire labels, so they are at most 63 bytes.
                out.push(label.len() as u8);
                out.extend_from_slice(label.as_bytes());
            }
            out.push(0); // End of name
            out.extend_from_slice(&self.query_type.to_be_bytes());
            out.extend_from_slice(&self.query_class.to_be_bytes());

            // Answer section
            if let Some(ip) = self.answer_ip {
                let (record_type, rdata) = match ip {
                    IpAddr::V4(v4) => (TYPE_A, v4.octets().to_vec()),
                    IpAddr::V6(v6) => (TYPE_AAAA, v6.octets().to_vec()),
                };
                // Compression pointer to the question name at offset 12.
                out.extend_from_slice(&[0xc0, 0x0c]);
                out.extend_from_slice(&record_type.to_be_bytes());
                out.extend_from_slice(&CLASS_IN.to_be_bytes());
                out.extend_from_slice(&self.ttl.to_be_bytes());
                // RDATA is 4 or 16 bytes, so the cast cannot truncate.
                out.extend_from_slice(&(rdata.len() as u16).to_be_bytes());
                out.extend_from_slice(&rdata);
            }

            out
        }
    }
}
|