aws_smithy_http_server/layer/
alb_health_check.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
//! Middleware for handling [ALB health
//! checks](https://docs.aws.amazon.com/elasticloadbalancing/latest/application/target-group-health-checks.html).
//!
//! # Example
//!
//! ```no_run
//! use aws_smithy_http_server::layer::alb_health_check::AlbHealthCheckLayer;
//! use hyper::StatusCode;
//! use tower::Layer;
//!
//! // Handle all `/ping` health check requests by returning a `200 OK`.
//! let ping_layer = AlbHealthCheckLayer::from_handler("/ping", |_req| async {
//!     StatusCode::OK
//! });
//! # async fn handle() { }
//! let app = tower::service_fn(handle);
//! let app = ping_layer.layer(app);
//! ```
24
25use std::borrow::Cow;
26use std::convert::Infallible;
27use std::task::{Context, Poll};
28
29use futures_util::{Future, FutureExt};
30use http::StatusCode;
31use hyper::{Body, Request, Response};
32use pin_project_lite::pin_project;
33use tower::{service_fn, util::Oneshot, Layer, Service, ServiceExt};
34
35use crate::body::BoxBody;
36
37use crate::plugin::either::Either;
38use crate::plugin::either::EitherProj;
39
/// A [`tower::Layer`] that wraps an inner service in an [`AlbHealthCheckService`].
///
/// The layer carries the URI to intercept and the handler that answers requests sent to it.
#[derive(Clone, Debug)]
pub struct AlbHealthCheckLayer<HealthCheckHandler> {
    /// URI a request must be equal to in order to be routed to the health check handler.
    health_check_uri: Cow<'static, str>,
    /// Handler that answers health check requests in place of the wrapped service.
    health_check_handler: HealthCheckHandler,
}
46
47impl AlbHealthCheckLayer<()> {
48    /// Handle health check requests at `health_check_uri` with the specified handler.
49    pub fn from_handler<HandlerFuture: Future<Output = StatusCode>, H: Fn(Request<Body>) -> HandlerFuture + Clone>(
50        health_check_uri: impl Into<Cow<'static, str>>,
51        health_check_handler: H,
52    ) -> AlbHealthCheckLayer<
53        impl Service<
54                Request<Body>,
55                Response = StatusCode,
56                Error = Infallible,
57                Future = impl Future<Output = Result<StatusCode, Infallible>>,
58            > + Clone,
59    > {
60        let service = service_fn(move |req| health_check_handler(req).map(Ok));
61
62        AlbHealthCheckLayer::new(health_check_uri, service)
63    }
64
65    /// Handle health check requests at `health_check_uri` with the specified service.
66    pub fn new<H: Service<Request<Body>, Response = StatusCode>>(
67        health_check_uri: impl Into<Cow<'static, str>>,
68        health_check_handler: H,
69    ) -> AlbHealthCheckLayer<H> {
70        AlbHealthCheckLayer {
71            health_check_uri: health_check_uri.into(),
72            health_check_handler,
73        }
74    }
75}
76
77impl<S, H: Clone> Layer<S> for AlbHealthCheckLayer<H> {
78    type Service = AlbHealthCheckService<H, S>;
79
80    fn layer(&self, inner: S) -> Self::Service {
81        AlbHealthCheckService {
82            inner,
83            layer: self.clone(),
84        }
85    }
86}
87
88/// A middleware [`Service`] responsible for handling health check requests.
89#[derive(Clone, Debug)]
90pub struct AlbHealthCheckService<H, S> {
91    inner: S,
92    layer: AlbHealthCheckLayer<H>,
93}
94
95impl<H, S> Service<Request<Body>> for AlbHealthCheckService<H, S>
96where
97    S: Service<Request<Body>, Response = Response<BoxBody>> + Clone,
98    S::Future: Send + 'static,
99    H: Service<Request<Body>, Response = StatusCode, Error = Infallible> + Clone,
100{
101    type Response = S::Response;
102    type Error = S::Error;
103    type Future = AlbHealthCheckFuture<H, S>;
104
105    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        // The check that the service is ready is done by `Oneshot` below.
107        Poll::Ready(Ok(()))
108    }
109
110    fn call(&mut self, req: Request<Body>) -> Self::Future {
111        if req.uri() == self.layer.health_check_uri.as_ref() {
112            let clone = self.layer.health_check_handler.clone();
113            let service = std::mem::replace(&mut self.layer.health_check_handler, clone);
114            let handler_future = service.oneshot(req);
115
116            AlbHealthCheckFuture::handler_future(handler_future)
117        } else {
118            let clone = self.inner.clone();
119            let service = std::mem::replace(&mut self.inner, clone);
120            let service_future = service.oneshot(req);
121
122            AlbHealthCheckFuture::service_future(service_future)
123        }
124    }
125}
126
127type HealthCheckFutureInner<H, S> = Either<Oneshot<H, Request<Body>>, Oneshot<S, Request<Body>>>;
128
129pin_project! {
130    /// Future for [`AlbHealthCheckService`].
131    pub struct AlbHealthCheckFuture<H: Service<Request<Body>, Response = StatusCode>, S: Service<Request<Body>>> {
132        #[pin]
133        inner: HealthCheckFutureInner<H, S>
134    }
135}
136
137impl<H, S> AlbHealthCheckFuture<H, S>
138where
139    H: Service<Request<Body>, Response = StatusCode>,
140    S: Service<Request<Body>>,
141{
142    fn handler_future(handler_future: Oneshot<H, Request<Body>>) -> Self {
143        Self {
144            inner: Either::Left { value: handler_future },
145        }
146    }
147
148    fn service_future(service_future: Oneshot<S, Request<Body>>) -> Self {
149        Self {
150            inner: Either::Right { value: service_future },
151        }
152    }
153}
154
155impl<H, S> Future for AlbHealthCheckFuture<H, S>
156where
157    H: Service<Request<Body>, Response = StatusCode, Error = Infallible>,
158    S: Service<Request<Body>, Response = Response<BoxBody>>,
159{
160    type Output = Result<S::Response, S::Error>;
161
162    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
163        let either_proj = self.project().inner.project();
164
165        match either_proj {
166            EitherProj::Left { value } => {
167                let polled: Poll<Self::Output> = value.poll(cx).map(|res| {
168                    res.map(|status_code| {
169                        Response::builder()
170                            .status(status_code)
171                            .body(crate::body::empty())
172                            .unwrap()
173                    })
174                    .map_err(|never| match never {})
175                });
176                polled
177            }
178            EitherProj::Right { value } => value.poll(cx),
179        }
180    }
181}