aws_smithy_http_server/layer/
alb_health_check.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
//! Middleware for handling [ALB health
//! checks](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/target-group-health-checks.html).
//!
//! Requests whose URI equals the configured health check URI are answered by a health check
//! handler that produces a status code; every other request is forwarded to the wrapped service.
//!
//! # Example
//!
//! ```no_run
//! use aws_smithy_http_server::layer::alb_health_check::AlbHealthCheckLayer;
//! use http::StatusCode;
//! use tower::Layer;
//!
//! // Handle all `/ping` health check requests by returning a `200 OK`.
//! let ping_layer = AlbHealthCheckLayer::from_handler("/ping", |_req: hyper::Request<hyper::body::Incoming>| async {
//!     StatusCode::OK
//! });
//! # async fn handle() { }
//! let app = tower::service_fn(handle);
//! let app = ping_layer.layer(app);
//! ```
24
25use std::borrow::Cow;
26use std::convert::Infallible;
27use std::task::{Context, Poll};
28
29use futures_util::{Future, FutureExt};
30use http::StatusCode;
31use hyper::{Request, Response};
32use pin_project_lite::pin_project;
33use tower::{service_fn, util::Oneshot, Layer, Service, ServiceExt};
34
35use crate::body::BoxBody;
36
37use crate::plugin::either::Either;
38use crate::plugin::either::EitherProj;
39
/// A [`tower::Layer`] that wraps a service in an [`AlbHealthCheckService`].
///
/// It stores the URI that identifies health check requests together with the handler
/// (a [`tower::Service`]) that answers them.
#[derive(Clone, Debug)]
pub struct AlbHealthCheckLayer<H> {
    /// URI that a request must equal to be treated as a health check.
    health_check_uri: Cow<'static, str>,
    /// Handler invoked for health check requests.
    health_check_handler: H,
}
46
47impl AlbHealthCheckLayer<()> {
48    /// Handle health check requests at `health_check_uri` with the specified handler.
49    pub fn from_handler<
50        B: http_body::Body,
51        HandlerFuture: Future<Output = StatusCode>,
52        H: Fn(Request<B>) -> HandlerFuture + Clone,
53    >(
54        health_check_uri: impl Into<Cow<'static, str>>,
55        health_check_handler: H,
56    ) -> AlbHealthCheckLayer<
57        impl Service<
58                Request<B>,
59                Response = StatusCode,
60                Error = Infallible,
61                Future = impl Future<Output = Result<StatusCode, Infallible>>,
62            > + Clone,
63    > {
64        let service = service_fn(move |req| health_check_handler(req).map(Ok));
65
66        AlbHealthCheckLayer::new(health_check_uri, service)
67    }
68
69    /// Handle health check requests at `health_check_uri` with the specified service.
70    pub fn new<B, H: Service<Request<B>, Response = StatusCode>>(
71        health_check_uri: impl Into<Cow<'static, str>>,
72        health_check_handler: H,
73    ) -> AlbHealthCheckLayer<H> {
74        AlbHealthCheckLayer {
75            health_check_uri: health_check_uri.into(),
76            health_check_handler,
77        }
78    }
79}
80
81impl<S, H: Clone> Layer<S> for AlbHealthCheckLayer<H> {
82    type Service = AlbHealthCheckService<H, S>;
83
84    fn layer(&self, inner: S) -> Self::Service {
85        AlbHealthCheckService {
86            inner,
87            layer: self.clone(),
88        }
89    }
90}
91
92/// A middleware [`Service`] responsible for handling health check requests.
93#[derive(Clone, Debug)]
94pub struct AlbHealthCheckService<H, S> {
95    inner: S,
96    layer: AlbHealthCheckLayer<H>,
97}
98
99impl<B, H, S> Service<Request<B>> for AlbHealthCheckService<H, S>
100where
101    S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
102    S::Future: Send + 'static,
103    H: Service<Request<B>, Response = StatusCode, Error = Infallible> + Clone,
104{
105    type Response = S::Response;
106    type Error = S::Error;
107    type Future = AlbHealthCheckFuture<B, H, S>;
108
109    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        // The check that the service is ready is done by `Oneshot` below.
111        Poll::Ready(Ok(()))
112    }
113
114    fn call(&mut self, req: Request<B>) -> Self::Future {
115        if req.uri() == self.layer.health_check_uri.as_ref() {
116            let clone = self.layer.health_check_handler.clone();
117            let service = std::mem::replace(&mut self.layer.health_check_handler, clone);
118            let handler_future = service.oneshot(req);
119
120            AlbHealthCheckFuture::handler_future(handler_future)
121        } else {
122            let clone = self.inner.clone();
123            let service = std::mem::replace(&mut self.inner, clone);
124            let service_future = service.oneshot(req);
125
126            AlbHealthCheckFuture::service_future(service_future)
127        }
128    }
129}
130
131type HealthCheckFutureInner<B, H, S> = Either<Oneshot<H, Request<B>>, Oneshot<S, Request<B>>>;
132
pin_project! {
    /// Future for [`AlbHealthCheckService`].
    ///
    /// Resolves either to a response built from the health check handler's status code or to
    /// the inner service's response, depending on where the request was routed.
    pub struct AlbHealthCheckFuture<B, H: Service<Request<B>, Response = StatusCode>, S: Service<Request<B>>> {
        // The handler's `Oneshot` (`Left`) or the inner service's `Oneshot` (`Right`).
        // Pinned so that `poll` can project through to whichever one is active.
        #[pin]
        inner: HealthCheckFutureInner<B, H, S>
    }
}
140
141impl<B, H, S> AlbHealthCheckFuture<B, H, S>
142where
143    H: Service<Request<B>, Response = StatusCode>,
144    S: Service<Request<B>>,
145{
146    fn handler_future(handler_future: Oneshot<H, Request<B>>) -> Self {
147        Self {
148            inner: Either::Left { value: handler_future },
149        }
150    }
151
152    fn service_future(service_future: Oneshot<S, Request<B>>) -> Self {
153        Self {
154            inner: Either::Right { value: service_future },
155        }
156    }
157}
158
159impl<B, H, S> Future for AlbHealthCheckFuture<B, H, S>
160where
161    H: Service<Request<B>, Response = StatusCode, Error = Infallible>,
162    S: Service<Request<B>, Response = Response<BoxBody>>,
163{
164    type Output = Result<S::Response, S::Error>;
165
166    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167        let either_proj = self.project().inner.project();
168
169        match either_proj {
170            EitherProj::Left { value } => {
171                let polled: Poll<Self::Output> = value.poll(cx).map(|res| {
172                    res.map(|status_code| {
173                        Response::builder()
174                            .status(status_code)
175                            .body(crate::body::empty())
176                            .unwrap()
177                    })
178                    .map_err(|never| match never {})
179                });
180                polled
181            }
182            EitherProj::Right { value } => value.poll(cx),
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use tower::{service_fn, ServiceExt};

    // A request to the configured URI must be answered by the health check handler.
    #[tokio::test]
    async fn test_health_check_handler_responds_to_matching_uri() {
        let health_layer = AlbHealthCheckLayer::from_handler("/health", |_req| async { StatusCode::OK });
        let backend = service_fn(|_req| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let svc = health_layer.layer(backend);

        let req = Request::builder()
            .method(Method::GET)
            .uri("/health")
            .body(crate::body::empty())
            .unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Any other URI must reach the inner service, recognisable by its distinctive status.
    #[tokio::test]
    async fn test_non_health_check_requests_pass_through() {
        let health_layer = AlbHealthCheckLayer::from_handler("/health", |_req| async { StatusCode::OK });
        let backend = service_fn(|_req| async {
            let accepted = Response::builder()
                .status(StatusCode::ACCEPTED)
                .body(crate::body::empty())
                .unwrap();
            Ok::<_, Infallible>(accepted)
        });
        let svc = health_layer.layer(backend);

        let req = Request::builder()
            .method(Method::GET)
            .uri("/api/data")
            .body(crate::body::empty())
            .unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::ACCEPTED);
    }

    // The handler sees the full request and can base its verdict on headers.
    #[tokio::test]
    async fn test_handler_can_read_request_headers() {
        let health_layer = AlbHealthCheckLayer::from_handler("/ping", |req| async move {
            match req.headers().get("x-health-check") {
                Some(_) => StatusCode::OK,
                None => StatusCode::SERVICE_UNAVAILABLE,
            }
        });
        let backend = service_fn(|_req| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let svc = health_layer.layer(backend);

        // Header present: the handler reports healthy.
        let with_header = Request::builder()
            .uri("/ping")
            .header("x-health-check", "true")
            .body(crate::body::empty())
            .unwrap();
        let resp = svc.clone().oneshot(with_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // Header absent: the handler reports unavailable.
        let without_header = Request::builder().uri("/ping").body(crate::body::empty()).unwrap();
        let resp = svc.oneshot(without_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // The layer is generic over the request body type, e.g. `Full<Bytes>`.
    #[tokio::test]
    async fn test_works_with_any_body_type() {
        use bytes::Bytes;
        use http_body_util::Full;

        let health_layer =
            AlbHealthCheckLayer::from_handler("/health", |_req: Request<Full<Bytes>>| async { StatusCode::OK });
        let backend =
            service_fn(|_req: Request<Full<Bytes>>| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let svc = health_layer.layer(backend);

        let req = Request::builder()
            .uri("/health")
            .body(Full::new(Bytes::from("test body")))
            .unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // A hand-written `http_body::Body` implementation is accepted as well.
    #[tokio::test]
    async fn test_works_with_custom_body_type() {
        use bytes::Bytes;
        use http_body::Frame;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        // Minimal body that yields its payload as a single data frame, then ends.
        struct CustomBody {
            data: Option<Bytes>,
        }

        impl CustomBody {
            fn new(data: Bytes) -> Self {
                Self { data: Some(data) }
            }
        }

        impl http_body::Body for CustomBody {
            type Data = Bytes;
            type Error = std::io::Error;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                Poll::Ready(self.data.take().map(|chunk| Ok(Frame::data(chunk))))
            }
        }

        let health_layer =
            AlbHealthCheckLayer::from_handler("/health", |_req: Request<CustomBody>| async { StatusCode::OK });
        let backend =
            service_fn(|_req: Request<CustomBody>| async { Ok::<_, Infallible>(Response::new(crate::body::empty())) });
        let svc = health_layer.layer(backend);

        let req = Request::builder()
            .uri("/health")
            .body(CustomBody::new(Bytes::from("custom body")))
            .unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}