aws_smithy_http_server/layer/
alb_health_check.rs1use std::borrow::Cow;
26use std::convert::Infallible;
27use std::task::{Context, Poll};
28
29use futures_util::{Future, FutureExt};
30use http::StatusCode;
31use hyper::{Request, Response};
32use pin_project_lite::pin_project;
33use tower::{service_fn, util::Oneshot, Layer, Service, ServiceExt};
34
35use crate::body::BoxBody;
36
37use crate::plugin::either::Either;
38use crate::plugin::either::EitherProj;
39
/// A [`Layer`] that wraps a service in an [`AlbHealthCheckService`].
///
/// Requests whose URI equals `health_check_uri` are answered by
/// `health_check_handler`; all other requests go to the wrapped service.
#[derive(Debug, Clone)]
pub struct AlbHealthCheckLayer<HealthCheckHandler> {
    /// URI that each incoming request's URI is compared against.
    health_check_uri: Cow<'static, str>,
    /// Service that produces the status code for matching requests.
    health_check_handler: HealthCheckHandler,
}
46
47impl AlbHealthCheckLayer<()> {
48 pub fn from_handler<
50 B: http_body::Body,
51 HandlerFuture: Future<Output = StatusCode>,
52 H: Fn(Request<B>) -> HandlerFuture + Clone,
53 >(
54 health_check_uri: impl Into<Cow<'static, str>>,
55 health_check_handler: H,
56 ) -> AlbHealthCheckLayer<
57 impl Service<
58 Request<B>,
59 Response = StatusCode,
60 Error = Infallible,
61 Future = impl Future<Output = Result<StatusCode, Infallible>>,
62 > + Clone,
63 > {
64 let service = service_fn(move |req| health_check_handler(req).map(Ok));
65
66 AlbHealthCheckLayer::new(health_check_uri, service)
67 }
68
69 pub fn new<B, H: Service<Request<B>, Response = StatusCode>>(
71 health_check_uri: impl Into<Cow<'static, str>>,
72 health_check_handler: H,
73 ) -> AlbHealthCheckLayer<H> {
74 AlbHealthCheckLayer {
75 health_check_uri: health_check_uri.into(),
76 health_check_handler,
77 }
78 }
79}
80
81impl<S, H: Clone> Layer<S> for AlbHealthCheckLayer<H> {
82 type Service = AlbHealthCheckService<H, S>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 AlbHealthCheckService {
86 inner,
87 layer: self.clone(),
88 }
89 }
90}
91
92#[derive(Clone, Debug)]
94pub struct AlbHealthCheckService<H, S> {
95 inner: S,
96 layer: AlbHealthCheckLayer<H>,
97}
98
99impl<B, H, S> Service<Request<B>> for AlbHealthCheckService<H, S>
100where
101 S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
102 S::Future: Send + 'static,
103 H: Service<Request<B>, Response = StatusCode, Error = Infallible> + Clone,
104{
105 type Response = S::Response;
106 type Error = S::Error;
107 type Future = AlbHealthCheckFuture<B, H, S>;
108
109 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 Poll::Ready(Ok(()))
112 }
113
114 fn call(&mut self, req: Request<B>) -> Self::Future {
115 if req.uri() == self.layer.health_check_uri.as_ref() {
116 let clone = self.layer.health_check_handler.clone();
117 let service = std::mem::replace(&mut self.layer.health_check_handler, clone);
118 let handler_future = service.oneshot(req);
119
120 AlbHealthCheckFuture::handler_future(handler_future)
121 } else {
122 let clone = self.inner.clone();
123 let service = std::mem::replace(&mut self.inner, clone);
124 let service_future = service.oneshot(req);
125
126 AlbHealthCheckFuture::service_future(service_future)
127 }
128 }
129}
130
131type HealthCheckFutureInner<B, H, S> = Either<Oneshot<H, Request<B>>, Oneshot<S, Request<B>>>;
132
133pin_project! {
134 pub struct AlbHealthCheckFuture<B, H: Service<Request<B>, Response = StatusCode>, S: Service<Request<B>>> {
136 #[pin]
137 inner: HealthCheckFutureInner<B, H, S>
138 }
139}
140
141impl<B, H, S> AlbHealthCheckFuture<B, H, S>
142where
143 H: Service<Request<B>, Response = StatusCode>,
144 S: Service<Request<B>>,
145{
146 fn handler_future(handler_future: Oneshot<H, Request<B>>) -> Self {
147 Self {
148 inner: Either::Left { value: handler_future },
149 }
150 }
151
152 fn service_future(service_future: Oneshot<S, Request<B>>) -> Self {
153 Self {
154 inner: Either::Right { value: service_future },
155 }
156 }
157}
158
159impl<B, H, S> Future for AlbHealthCheckFuture<B, H, S>
160where
161 H: Service<Request<B>, Response = StatusCode, Error = Infallible>,
162 S: Service<Request<B>, Response = Response<BoxBody>>,
163{
164 type Output = Result<S::Response, S::Error>;
165
166 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let either_proj = self.project().inner.project();
168
169 match either_proj {
170 EitherProj::Left { value } => {
171 let polled: Poll<Self::Output> = value.poll(cx).map(|res| {
172 res.map(|status_code| {
173 Response::builder()
174 .status(status_code)
175 .body(crate::body::empty())
176 .unwrap()
177 })
178 .map_err(|never| match never {})
179 });
180 polled
181 }
182 EitherProj::Right { value } => value.poll(cx),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use tower::{service_fn, ServiceExt};

    /// A request to the configured URI yields the handler's status code.
    #[tokio::test]
    async fn test_health_check_handler_responds_to_matching_uri() {
        let health_layer = AlbHealthCheckLayer::from_handler("/health", |_req| async { StatusCode::OK });
        let downstream = service_fn(|_req| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let checked = health_layer.layer(downstream);

        let health_request = Request::builder()
            .method(Method::GET)
            .uri("/health")
            .body(crate::body::empty())
            .unwrap();

        let reply = checked.oneshot(health_request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::OK);
    }

    /// A request to any other URI reaches the wrapped service.
    #[tokio::test]
    async fn test_non_health_check_requests_pass_through() {
        let health_layer = AlbHealthCheckLayer::from_handler("/health", |_req| async { StatusCode::OK });
        // The wrapped service answers with ACCEPTED so it can be told apart from the handler.
        let downstream = service_fn(|_req| async {
            let accepted = Response::builder()
                .status(StatusCode::ACCEPTED)
                .body(crate::body::empty())
                .unwrap();
            Ok::<_, Infallible>(accepted)
        });
        let checked = health_layer.layer(downstream);

        let api_request = Request::builder()
            .method(Method::GET)
            .uri("/api/data")
            .body(crate::body::empty())
            .unwrap();

        let reply = checked.oneshot(api_request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::ACCEPTED);
    }

    /// The handler receives the request and can base its answer on the headers.
    #[tokio::test]
    async fn test_handler_can_read_request_headers() {
        let health_layer = AlbHealthCheckLayer::from_handler("/ping", |req| async move {
            if req.headers().get("x-health-check").is_some() {
                StatusCode::OK
            } else {
                StatusCode::SERVICE_UNAVAILABLE
            }
        });
        let downstream = service_fn(|_req| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let checked = health_layer.layer(downstream);

        // With the marker header present the handler reports OK.
        let with_header = Request::builder()
            .uri("/ping")
            .header("x-health-check", "true")
            .body(crate::body::empty())
            .unwrap();

        let reply = checked.clone().oneshot(with_header).await.unwrap();
        assert_eq!(reply.status(), StatusCode::OK);

        // Without it the handler reports SERVICE_UNAVAILABLE.
        let without_header = Request::builder().uri("/ping").body(crate::body::empty()).unwrap();

        let reply = checked.oneshot(without_header).await.unwrap();
        assert_eq!(reply.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// The layer works when the request body is `Full<Bytes>`.
    #[tokio::test]
    async fn test_works_with_any_body_type() {
        use bytes::Bytes;
        use http_body_util::Full;

        let health_layer =
            AlbHealthCheckLayer::from_handler("/health", |_req: Request<Full<Bytes>>| async { StatusCode::OK });
        let downstream =
            service_fn(|_req: Request<Full<Bytes>>| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let checked = health_layer.layer(downstream);

        let health_request = Request::builder()
            .uri("/health")
            .body(Full::new(Bytes::from("test body")))
            .unwrap();

        let reply = checked.oneshot(health_request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::OK);
    }

    /// The layer works with a hand-written `http_body::Body` implementation.
    #[tokio::test]
    async fn test_works_with_custom_body_type() {
        use bytes::Bytes;
        use http_body::Frame;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// Body that yields its payload as a single data frame and then ends.
        struct CustomBody {
            // `None` once the payload has been handed out.
            data: Option<Bytes>,
        }

        impl CustomBody {
            fn new(data: Bytes) -> Self {
                Self { data: Some(data) }
            }
        }

        impl http_body::Body for CustomBody {
            type Data = Bytes;
            type Error = std::io::Error;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                // The first poll takes the payload; later polls see `None` and end the body.
                Poll::Ready(self.data.take().map(|chunk| Ok(Frame::data(chunk))))
            }
        }

        let health_layer =
            AlbHealthCheckLayer::from_handler("/health", |_req: Request<CustomBody>| async { StatusCode::OK });
        let downstream =
            service_fn(|_req: Request<CustomBody>| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let checked = health_layer.layer(downstream);

        let health_request = Request::builder()
            .uri("/health")
            .body(CustomBody::new(Bytes::from("custom body")))
            .unwrap();

        let reply = checked.oneshot(health_request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::OK);
    }
}