1 + | /*
|
2 + | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
3 + | * SPDX-License-Identifier: Apache-2.0
|
4 + | */
|
5 + |
|
6 + | #![cfg(all(feature = "client", feature = "rt-tokio",))]
|
7 + |
|
8 + | use aws_smithy_runtime::client::dns::CachingDnsResolver;
|
9 + | use aws_smithy_runtime_api::client::dns::ResolveDns;
|
10 + | use std::{
|
11 + | collections::HashMap,
|
12 + | net::{IpAddr, Ipv4Addr},
|
13 + | time::Instant,
|
14 + | };
|
15 + | use tokio::test;
|
16 + |
|
17 + | // Note: I couldn't come up with a way to mock the DNS returns to hickory so these
|
18 + | // tests actually hit the network. We should ideally find a better way to do this.
|
19 + |
|
/// Resolves the same public hostname twice with one resolver and checks that
/// the second lookup returns the same addresses and completes faster than the
/// first, which is expected to have been answered from the cache.
///
/// This test sends real DNS queries over the network.
#[test]
async fn test_dns_caching_performance() {
    const HOST: &str = "example.com";
    let resolver = CachingDnsResolver::default();

    // Cold lookup: nothing is cached yet, so this is expected to hit the network.
    let cold_timer = Instant::now();
    let cold_outcome = resolver.resolve_dns(HOST).await;
    let cold_elapsed = cold_timer.elapsed();

    let cold_ips = cold_outcome.unwrap();
    assert!(!cold_ips.is_empty());

    // Warm lookup: expected to be served from the cache.
    let warm_timer = Instant::now();
    let warm_outcome = resolver.resolve_dns(HOST).await;
    let warm_elapsed = warm_timer.elapsed();

    let warm_ips = warm_outcome.unwrap();

    // The cached answer must match the original one...
    assert_eq!(cold_ips, warm_ips);
    // ...and must arrive sooner than the network round trip did.
    assert!(warm_elapsed < cold_elapsed);
}
|
46 + |
|
/// Checks the behaviour of a resolver whose cache holds a single entry.
///
/// `example.com` is resolved first and is expected to occupy the only cache
/// slot. `aws.com` is resolved next, and the test expects it NOT to displace
/// `example.com`. A repeat lookup of `example.com` should therefore be served
/// from the cache and finish faster than a repeat lookup of `aws.com`.
///
/// Uses the local `test_dns_server::TestDnsServer` instead of the public network.
#[test]
#[tracing_test::traced_test]
async fn test_dns_cache_size_limit() {
    let example_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4));
    let aws_ip = IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8));

    let mut records = HashMap::new();
    records.insert("example.com".to_string(), example_ip);
    records.insert("aws.com".to_string(), aws_ip);

    let dns_server = test_dns_server::TestDnsServer::start(records)
        .await
        .expect("failed to start the local test DNS server");
    let (dns_ip, dns_port) = dns_server.addr();

    let resolver = CachingDnsResolver::builder()
        .nameservers(&[dns_ip], dns_port)
        .cache_size(1)
        .build();

    // Resolve the first hostname; it is expected to take the single cache slot.
    let result1 = resolver
        .resolve_dns("example.com")
        .await
        .expect("example.com should resolve via the test server");
    assert_eq!(result1, vec![example_ip]);

    // Resolve a second hostname while the only slot is occupied. The test
    // expects this entry not to be kept in the cache (asserted via timing below).
    let result2 = resolver
        .resolve_dns("aws.com")
        .await
        .expect("aws.com should resolve via the test server");
    assert_eq!(result2, vec![aws_ip]);

    let start = Instant::now();
    let result2_again = resolver.resolve_dns("aws.com").await;
    let result2_again_duration = start.elapsed();

    let start = Instant::now();
    let result1_again = resolver.resolve_dns("example.com").await;
    let result1_again_duration = start.elapsed();

    // Repeat lookups must return the same addresses as the first ones.
    assert_eq!(
        result1_again.expect("example.com should still resolve"),
        result1
    );
    assert_eq!(
        result2_again.expect("aws.com should still resolve"),
        result2
    );

    // The cached example.com lookup should beat the aws.com lookup, which has
    // to go back to the test server.
    assert!(
        result1_again_duration < result2_again_duration,
        "expected cached example.com ({result1_again_duration:?}) to resolve faster than aws.com ({result2_again_duration:?})"
    );
}
|
96 + |
|
/// Resolving a hostname under the `.test` TLD (reserved for testing by
/// RFC 2606, so it never exists) must produce an error.
///
/// This test sends real DNS queries over the network.
#[test]
async fn test_dns_error_handling() {
    let bogus_host = "invalid.nonexistent.domain.test";
    let resolver = CachingDnsResolver::default();

    let outcome = resolver.resolve_dns(bogus_host).await;

    assert!(outcome.is_err());
}
|
// Test utility for creating a local DNS server.
//
// The server answers queries over UDP from an in-memory hostname -> IP table.
// It only understands the message header and the first question. That is
// enough for a resolver issuing ordinary A/AAAA lookups.
#[cfg(test)]
mod test_dns_server {
    use std::{
        collections::HashMap,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        sync::Arc,
    };
    use tokio::{net::UdpSocket, sync::Notify, task::JoinHandle};

    /// Size of the fixed DNS message header (RFC 1035 §4.1.1).
    const HEADER_LEN: usize = 12;
    /// Maximum length of a single uncompressed name label (RFC 1035 §2.3.4).
    const MAX_LABEL_LEN: usize = 63;
    /// Record type code for an IPv4 address (RFC 1035).
    const TYPE_A: u16 = 1;
    /// Record type code for an IPv6 address (RFC 3596).
    const TYPE_AAAA: u16 = 28;
    /// The Internet class code.
    const CLASS_IN: u16 = 1;
    /// Header flags for a standard reply: QR, RD and RA set, RCODE = NOERROR.
    const RESPONSE_FLAGS: u16 = 0x8180;
    /// TTL, in seconds, attached to every answer record.
    const ANSWER_TTL: u32 = 300;

    /// A UDP DNS server bound to a random `127.0.0.1` port. It is served by a
    /// spawned tokio task that is stopped when this value is dropped.
    pub struct TestDnsServer {
        handle: JoinHandle<()>,
        addr: SocketAddr,
        shutdown: Arc<Notify>,
    }

    impl TestDnsServer {
        /// Binds `127.0.0.1:0` and spawns a task that answers queries from `records`.
        ///
        /// Names missing from `records` resolve to `127.0.0.1`. Queries that
        /// cannot be parsed are ignored, so the client sees a timeout.
        ///
        /// # Errors
        /// Returns an error if the socket cannot be bound or its local address read.
        pub async fn start(
            records: HashMap<String, IpAddr>,
        ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
            // localhost, random port
            let socket = UdpSocket::bind("127.0.0.1:0").await?;
            let addr = socket.local_addr()?;

            let shutdown = Arc::new(Notify::new());
            let shutdown_clone = shutdown.clone();

            let handle = tokio::spawn(async move {
                let mut buf = [0; 512];
                loop {
                    tokio::select! {
                        _ = shutdown_clone.notified() => break,
                        result = socket.recv_from(&mut buf) => {
                            if let Ok((len, src)) = result {
                                if let Some(response) = create_dns_response(&buf[..len], &records) {
                                    // Best effort: if the send fails, the client
                                    // simply times out or retries.
                                    let _ = socket.send_to(&response, src).await;
                                }
                            }
                        }
                    }
                }
            });

            Ok(TestDnsServer {
                handle,
                addr,
                shutdown,
            })
        }

        /// Returns the IP address and port the server is listening on.
        pub fn addr(&self) -> (IpAddr, u16) {
            (self.addr.ip(), self.addr.port())
        }
    }

    impl Drop for TestDnsServer {
        fn drop(&mut self) {
            self.shutdown.notify_one();
            self.handle.abort();
        }
    }

    /// Builds the wire-format reply to `query`, or `None` if `query` cannot be parsed.
    ///
    /// The question section is echoed verbatim. This lets a resolver that
    /// compares the echoed question with its request accept the reply.
    /// An answer record is included only when the record's address family
    /// matches the query type (A for IPv4, AAAA for IPv6, class IN). Any other
    /// query gets an empty NOERROR/NODATA reply.
    fn create_dns_response(query: &[u8], records: &HashMap<String, IpAddr>) -> Option<Vec<u8>> {
        let parsed = DnsQuery::parse(query)?;
        let ip = records
            .get(parsed.domain.as_str())
            .copied()
            .unwrap_or(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        let answer = match (ip, parsed.query_type, parsed.query_class) {
            (IpAddr::V4(v4), TYPE_A, CLASS_IN) => Some((TYPE_A, v4.octets().to_vec())),
            (IpAddr::V6(v6), TYPE_AAAA, CLASS_IN) => Some((TYPE_AAAA, v6.octets().to_vec())),
            _ => None,
        };

        let response = DnsResponse {
            id: parsed.id,
            flags: RESPONSE_FLAGS,
            question: parsed.question,
            answer,
            ttl: ANSWER_TTL,
        };

        Some(response.serialize())
    }

    /// The parts of an incoming query that the server needs.
    #[derive(Debug)]
    struct DnsQuery {
        id: u16,
        /// Dotted name from the first question, without a trailing dot.
        domain: String,
        query_type: u16,
        query_class: u16,
        /// Raw bytes of the first question (name + type + class), echoed in the reply.
        question: Vec<u8>,
    }

    impl DnsQuery {
        /// Parses the header and first question of `data`.
        ///
        /// Returns `None` in any of these cases:
        /// - the message is truncated;
        /// - it has no questions;
        /// - a label is longer than 63 bytes, which also rejects compression pointers.
        fn parse(data: &[u8]) -> Option<Self> {
            let header = data.get(..HEADER_LEN)?;
            let id = u16::from_be_bytes([header[0], header[1]]);
            let question_count = u16::from_be_bytes([header[4], header[5]]);
            if question_count == 0 {
                return None;
            }

            // The name is a sequence of length-prefixed labels ending in a zero byte.
            let mut pos = HEADER_LEN;
            let mut labels = Vec::new();
            loop {
                let len = usize::from(*data.get(pos)?);
                pos += 1;
                if len == 0 {
                    break;
                }
                if len > MAX_LABEL_LEN {
                    return None;
                }
                let label = data.get(pos..pos + len)?;
                labels.push(String::from_utf8_lossy(label).into_owned());
                pos += len;
            }

            let fixed = data.get(pos..pos + 4)?;
            let query_type = u16::from_be_bytes([fixed[0], fixed[1]]);
            let query_class = u16::from_be_bytes([fixed[2], fixed[3]]);
            pos += 4;

            Some(DnsQuery {
                id,
                domain: labels.join("."),
                query_type,
                query_class,
                // In bounds: `data.get(pos..pos + 4)` above succeeded.
                question: data[HEADER_LEN..pos].to_vec(),
            })
        }
    }

    /// A reply carrying one question and at most one answer record.
    #[derive(Debug)]
    struct DnsResponse {
        id: u16,
        flags: u16,
        /// Question section copied verbatim from the query.
        question: Vec<u8>,
        /// Record type and RDATA of the answer; `None` yields an empty answer section.
        answer: Option<(u16, Vec<u8>)>,
        ttl: u32,
    }

    impl DnsResponse {
        /// Encodes the reply in DNS wire format.
        fn serialize(&self) -> Vec<u8> {
            let answer_count: u16 = if self.answer.is_some() { 1 } else { 0 };
            let mut response = Vec::with_capacity(HEADER_LEN + self.question.len() + 28);

            // Header (12 bytes)
            response.extend_from_slice(&self.id.to_be_bytes());
            response.extend_from_slice(&self.flags.to_be_bytes());
            response.extend_from_slice(&1u16.to_be_bytes()); // Questions: 1
            response.extend_from_slice(&answer_count.to_be_bytes()); // Answers
            response.extend_from_slice(&0u16.to_be_bytes()); // Authority: 0
            response.extend_from_slice(&0u16.to_be_bytes()); // Additional: 0

            // Question section, echoed from the query.
            response.extend_from_slice(&self.question);

            // Answer section
            if let Some((record_type, rdata)) = &self.answer {
                // Compression pointer to the question name, which starts right after the header.
                response.extend_from_slice(&[0xc0, 0x0c]);
                response.extend_from_slice(&record_type.to_be_bytes());
                response.extend_from_slice(&CLASS_IN.to_be_bytes());
                response.extend_from_slice(&self.ttl.to_be_bytes());
                let rdata_len =
                    u16::try_from(rdata.len()).expect("address RDATA is at most 16 bytes");
                response.extend_from_slice(&rdata_len.to_be_bytes());
                response.extend_from_slice(rdata);
            }

            response
        }
    }
}
|