1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | #![cfg(all(feature = "client", feature = "rt-tokio",))]
|
7 + |
|
8 + | use aws_smithy_runtime::client::dns::CachingDnsResolver;
|
9 + | use aws_smithy_runtime_api::client::dns::ResolveDns;
|
10 + | use std::{
|
11 + | net::{IpAddr, Ipv4Addr},
|
12 + | time::{Duration, Instant},
|
13 + | };
|
14 + | use tokio::test;
|
15 + |
|
/// Resolving the same hostname twice must serve the second lookup from the cache:
/// both lookups return the same (expected) addresses, and the cached one is faster.
#[test]
async fn test_dns_caching_performance() {
    let server = test_dns_server::setup_dns_server().await;
    let (server_ip, server_port) = server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[server_ip], server_port)
        .cache_size(1)
        .build();

    let host = "example.com";

    // Cold lookup: nothing cached yet, so this goes out to the test server.
    let cold_timer = Instant::now();
    let cold_outcome = resolver.resolve_dns(host).await;
    let cold_elapsed = cold_timer.elapsed();

    let cold_ips = cold_outcome.unwrap();
    assert!(!cold_ips.is_empty());

    // Warm lookup: the answer should now come straight from the cache.
    let warm_timer = Instant::now();
    let warm_outcome = resolver.resolve_dns(host).await;
    let warm_elapsed = warm_timer.elapsed();

    let warm_ips = warm_outcome.unwrap();

    // Both lookups agree, and they return the address the test server holds for the host.
    assert_eq!(cold_ips, warm_ips);
    assert_eq!(vec![IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))], cold_ips);

    // Skipping the network round trip must make the cached lookup faster.
    assert!(warm_elapsed < cold_elapsed);
}
|
52 + |
|
// Ignored because the cache is only eventually consistent with respect to its size limit,
// which makes it hard to get an exact measurement. The logged timings are still useful for
// manually checking performance, so the test is kept rather than deleted.
|
/// Exercises a one-entry cache with two hostnames and prints how long each lookup takes.
///
/// The final assertion expects the first-cached hostname (`example.com`) to be served
/// faster than `aws.com` on the second round of lookups.
#[test]
#[ignore = "Cache is eventually consistent w.r.t. size."]
async fn test_dns_cache_size_limit() {
    let server = test_dns_server::setup_dns_server().await;
    let (server_ip, server_port) = server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[server_ip], server_port)
        .cache_size(1)
        .build();

    // First hostname goes into the (empty) cache.
    let (result1, result1_duration) = {
        let timer = Instant::now();
        let ips = resolver.resolve_dns("example.com").await.unwrap();
        (ips, timer.elapsed())
    };
    assert_eq!(vec![IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))], result1);

    // Second hostname competes for the single cache slot.
    let (result2, result2_duration) = {
        let timer = Instant::now();
        let ips = resolver.resolve_dns("aws.com").await.unwrap();
        (ips, timer.elapsed())
    };
    assert_eq!(vec![IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))], result2);

    println!("RESULT1 DURATION: {result1_duration:#?}");
    println!("RESULT2 DURATION: {result2_duration:#?}");

    // Look both names up again, in the same order, timing each one.
    let result1_again_duration = {
        let timer = Instant::now();
        let _ = resolver.resolve_dns("example.com").await.unwrap();
        timer.elapsed()
    };
    let result2_again_duration = {
        let timer = Instant::now();
        let _ = resolver.resolve_dns("aws.com").await.unwrap();
        timer.elapsed()
    };

    println!("RESULT1 AGAIN DURATION: {result1_again_duration:#?}");
    println!("RESULT2 AGAIN DURATION: {result2_again_duration:#?}");
    assert!(result1_again_duration < result2_again_duration);
}
|
94 + |
|
/// Resolving a hostname the test server has no record for must produce an error.
///
/// The resolver timeout is deliberately well above the test server's simulated 1s
/// latency. This gives the server a chance to answer; with a shorter timeout the lookup
/// always fails on a client-side timeout before any reply can arrive, and the test only
/// checks timeout handling.
#[test]
async fn test_dns_error_handling() {
    let dns_server = test_dns_server::setup_dns_server().await;
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .timeout(Duration::from_secs(5))
        .attempts(1)
        .build();

    // Try to resolve a hostname the test server does not know about.
    let result = resolver
        .resolve_dns("invalid.nonexistent.domain.test")
        .await;
    assert!(
        result.is_err(),
        "expected an error for an unknown hostname, got {result:?}"
    );
}
|
112 + |
|
// Kind of janky minimal test utility for creating a local DNS server.
//
// It understands only as much of the DNS wire format (RFC 1035) as the tests in this file
// need: one uncompressed question per query, answered from a fixed hostname -> IP map.
#[cfg(test)]
mod test_dns_server {
    use std::{
        collections::HashMap,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::Arc,
        time::Duration,
    };
    use tokio::{net::UdpSocket, sync::Notify, task::JoinHandle};

    /// Delay applied before answering every query. It simulates network latency so that
    /// cache hits are measurably faster than lookups that reach this server.
    const SIMULATED_LATENCY: Duration = Duration::from_millis(1000);

    /// DNS record type code for IPv4 host addresses.
    const TYPE_A: u16 = 1;
    /// DNS record type code for IPv6 host addresses.
    const TYPE_AAAA: u16 = 28;
    /// DNS class code for the Internet.
    const CLASS_IN: u16 = 1;

    /// Response header flags: QR (response), RD and RA set, RCODE 0 (NOERROR).
    const FLAGS_NOERROR: u16 = 0x8180;
    /// Same as `FLAGS_NOERROR` but with RCODE 3 (NXDOMAIN: the name does not exist).
    const FLAGS_NXDOMAIN: u16 = 0x8183;

    /// Starts a test server on localhost that knows three hostnames:
    /// `example.com` -> 1.2.3.4, `aws.com` -> 5.6.7.8 and `foo.com` -> 9.10.11.12.
    ///
    /// # Panics
    /// Panics if the server cannot be started (see [`TestDnsServer::start`]).
    pub async fn setup_dns_server() -> TestDnsServer {
        let records = HashMap::from([
            (
                "example.com".to_string(),
                IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
            ),
            ("aws.com".to_string(), IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            (
                "foo.com".to_string(),
                IpAddr::V4(Ipv4Addr::new(9, 10, 11, 12)),
            ),
        ]);

        TestDnsServer::start(records).await.unwrap()
    }

    /// Handle to a running test DNS server. The server task is stopped when this is dropped.
    pub struct TestDnsServer {
        handle: JoinHandle<()>,
        addr: SocketAddr,
        shutdown: Arc<Notify>,
    }

    impl TestDnsServer {
        /// Binds a UDP socket on an OS-assigned localhost port and spawns a task that answers
        /// queries from `records`.
        ///
        /// Hostnames are matched after ASCII-lowercasing the query name, so `records` keys
        /// should be lowercase. Queries are handled one at a time, each delayed by
        /// `SIMULATED_LATENCY`.
        ///
        /// # Errors
        /// Returns an error if the socket cannot be bound or its local address read.
        pub async fn start(
            records: HashMap<String, IpAddr>,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            // localhost, random port
            let socket = UdpSocket::bind("127.0.0.1:0").await?;
            let addr = socket.local_addr()?;

            let shutdown = Arc::new(Notify::new());
            let shutdown_signal = Arc::clone(&shutdown);

            let handle = tokio::spawn(async move {
                // 512 bytes is the classic DNS-over-UDP message limit; larger datagrams
                // may be truncated by `recv_from`.
                let mut buf = [0; 512];
                loop {
                    tokio::select! {
                        _ = shutdown_signal.notified() => break,
                        result = socket.recv_from(&mut buf) => {
                            if let Ok((len, src)) = result {
                                tokio::time::sleep(SIMULATED_LATENCY).await;
                                // Unparseable queries get no reply (the client will time out)
                                // instead of crashing the server task.
                                if let Some(response) = create_dns_response(&buf[..len], &records) {
                                    // Best effort: a failed send looks like packet loss to the client.
                                    let _ = socket.send_to(&response, src).await;
                                }
                            }
                        }
                    }
                }
            });

            Ok(TestDnsServer {
                handle,
                addr,
                shutdown,
            })
        }

        /// Returns the IP address and port the server is listening on.
        pub fn addr(&self) -> (IpAddr, u16) {
            (self.addr.ip(), self.addr.port())
        }
    }

    impl Drop for TestDnsServer {
        fn drop(&mut self) {
            self.shutdown.notify_one();
            // Aborting as well stops the task even while it is sleeping before a reply.
            self.handle.abort();
        }
    }

    /// Builds the wire-format reply to `query`, or returns `None` if the query cannot be parsed.
    ///
    /// - An `A` query for a known IPv4 record, or an `AAAA` query for a known IPv6 record,
    ///   gets one answer.
    /// - A name missing from `records` gets NXDOMAIN.
    /// - A known name queried for some other type gets an empty NOERROR reply (NODATA).
    fn create_dns_response(query: &[u8], records: &HashMap<String, IpAddr>) -> Option<Vec<u8>> {
        let parsed = DnsQuery::parse(query)?;

        // DNS names are case-insensitive.
        let record = records.get(&parsed.domain.to_ascii_lowercase()).copied();
        let (flags, answer_ip) = match record {
            None => (FLAGS_NXDOMAIN, None),
            Some(ip @ IpAddr::V4(_)) if parsed.query_type == TYPE_A => (FLAGS_NOERROR, Some(ip)),
            Some(ip @ IpAddr::V6(_)) if parsed.query_type == TYPE_AAAA => {
                (FLAGS_NOERROR, Some(ip))
            }
            Some(_) => (FLAGS_NOERROR, None),
        };

        let response = DnsResponse {
            id: parsed.id,
            flags,
            question: parsed.domain,
            query_type: parsed.query_type,
            query_class: parsed.query_class,
            answer_ip,
            ttl: 300,
        };

        Some(response.to_bytes())
    }

    /// The header fields and first question of an incoming DNS query.
    #[derive(Debug)]
    #[allow(dead_code)] // `flags` and `question_count` are parsed for debugging but never read.
    struct DnsQuery {
        id: u16,
        flags: u16,
        question_count: u16,
        domain: String,
        query_type: u16,
        query_class: u16,
    }

    impl DnsQuery {
        /// Parses the 12-byte header and the first question of a DNS query.
        ///
        /// Returns `None` if the message is truncated, has no questions, uses name
        /// compression or over-long labels, or contains a label that is not UTF-8.
        fn parse(data: &[u8]) -> Option<Self> {
            let header = data.get(..12)?;
            let id = u16::from_be_bytes([header[0], header[1]]);
            let flags = u16::from_be_bytes([header[2], header[3]]);
            let question_count = u16::from_be_bytes([header[4], header[5]]);

            if question_count == 0 {
                return None;
            }

            // The question name follows the header as length-prefixed labels, terminated
            // by a zero-length label.
            let mut pos = 12;
            let mut domain = String::new();
            loop {
                let len = usize::from(*data.get(pos)?);
                pos += 1;
                if len == 0 {
                    break;
                }
                // Lengths above 63 are compression pointers (top two bits set) or reserved
                // encodings. Neither is supported here.
                if len > 63 {
                    return None;
                }

                let label = std::str::from_utf8(data.get(pos..pos + len)?).ok()?;
                if !domain.is_empty() {
                    domain.push('.');
                }
                domain.push_str(label);
                pos += len;
            }

            let fixed = data.get(pos..pos + 4)?;
            let query_type = u16::from_be_bytes([fixed[0], fixed[1]]);
            let query_class = u16::from_be_bytes([fixed[2], fixed[3]]);

            Some(DnsQuery {
                id,
                flags,
                question_count,
                domain,
                query_type,
                query_class,
            })
        }
    }

    /// A single-question DNS response with at most one answer record.
    #[derive(Debug)]
    struct DnsResponse {
        id: u16,
        flags: u16,
        question: String,
        query_type: u16,
        query_class: u16,
        /// Address for the answer record; `None` produces a reply with no answers.
        answer_ip: Option<IpAddr>,
        ttl: u32,
    }

    impl DnsResponse {
        /// Serializes the response into DNS wire format.
        fn to_bytes(&self) -> Vec<u8> {
            // Header (12) + encoded name (at most name length + 2) + QTYPE/QCLASS (4)
            // + the largest answer record (2 + 2 + 2 + 4 + 2 + 16 = 28).
            let mut response = Vec::with_capacity(12 + self.question.len() + 2 + 4 + 28);

            // Header (12 bytes). Everything other than id, flags and the answer count is fixed.
            let answer_count = u16::from(self.answer_ip.is_some());
            response.extend_from_slice(&self.id.to_be_bytes());
            response.extend_from_slice(&self.flags.to_be_bytes());
            response.extend_from_slice(&1u16.to_be_bytes()); // Questions: 1
            response.extend_from_slice(&answer_count.to_be_bytes()); // Answers: 0 or 1
            response.extend_from_slice(&0u16.to_be_bytes()); // Authority: 0
            response.extend_from_slice(&0u16.to_be_bytes()); // Additional: 0

            // Question section: echo the name, type and class of the query.
            // Empty labels are skipped so the root name encodes as a lone terminator.
            for label in self.question.split('.').filter(|label| !label.is_empty()) {
                // `DnsQuery::parse` rejects labels longer than 63 bytes, so this cannot truncate.
                response.push(label.len() as u8);
                response.extend_from_slice(label.as_bytes());
            }
            response.push(0); // End of name
            response.extend_from_slice(&self.query_type.to_be_bytes());
            response.extend_from_slice(&self.query_class.to_be_bytes());

            // Answer section (at most one record).
            match self.answer_ip {
                Some(IpAddr::V4(ip)) => push_answer(&mut response, TYPE_A, self.ttl, &ip.octets()),
                Some(IpAddr::V6(ip)) => {
                    push_answer(&mut response, TYPE_AAAA, self.ttl, &ip.octets())
                }
                None => {}
            }

            response
        }
    }

    /// Appends one IN-class answer record whose owner name points back at the question name.
    fn push_answer(response: &mut Vec<u8>, record_type: u16, ttl: u32, rdata: &[u8]) {
        response.extend_from_slice(&[0xc0, 0x0c]); // Name pointer to offset 12 (the question name)
        response.extend_from_slice(&record_type.to_be_bytes());
        response.extend_from_slice(&CLASS_IN.to_be_bytes());
        response.extend_from_slice(&ttl.to_be_bytes()); // TTL
        // Callers pass 4 (IPv4) or 16 (IPv6) bytes, so the length fits in a u16.
        response.extend_from_slice(&(rdata.len() as u16).to_be_bytes()); // Data length
        response.extend_from_slice(rdata);
    }
}
|