// aws_smithy_http_server/request/request_id.rs
use std::future::Future;
50use std::{
51    fmt::Display,
52    task::{Context, Poll},
53};
54
55use futures_util::TryFuture;
56use http::request::Parts;
57use http::{header::HeaderName, HeaderValue, Response};
58use thiserror::Error;
59use tower::{Layer, Service};
60use uuid::Uuid;
61
62use crate::{body::BoxBody, response::IntoResponse};
63
64use super::{internal_server_error, FromParts};
65
66#[derive(Clone, Debug)]
70pub struct ServerRequestId {
71    id: Uuid,
72}
73
74#[non_exhaustive]
76#[derive(Debug, Error)]
77#[error("the `ServerRequestId` is not present in the `http::Request`")]
78pub struct MissingServerRequestId;
79
80impl ServerRequestId {
81    pub fn new() -> Self {
82        Self { id: Uuid::new_v4() }
83    }
84
85    pub(crate) fn to_header(&self) -> HeaderValue {
86        HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
87    }
88}
89
90impl Display for ServerRequestId {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        self.id.fmt(f)
93    }
94}
95
96impl<P> FromParts<P> for ServerRequestId {
97    type Rejection = MissingServerRequestId;
98
99    fn from_parts(parts: &mut Parts) -> Result<Self, Self::Rejection> {
100        parts.extensions.remove().ok_or(MissingServerRequestId)
101    }
102}
103
104impl Default for ServerRequestId {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110#[derive(Clone)]
111pub struct ServerRequestIdProvider<S> {
112    inner: S,
113    header_key: Option<HeaderName>,
114}
115
116#[derive(Debug)]
118#[non_exhaustive]
119pub struct ServerRequestIdProviderLayer {
120    header_key: Option<HeaderName>,
121}
122
123impl ServerRequestIdProviderLayer {
124    pub fn new() -> Self {
127        Self { header_key: None }
128    }
129
130    pub fn new_with_response_header(header_key: HeaderName) -> Self {
132        Self {
133            header_key: Some(header_key),
134        }
135    }
136}
137
138impl Default for ServerRequestIdProviderLayer {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl<S> Layer<S> for ServerRequestIdProviderLayer {
145    type Service = ServerRequestIdProvider<S>;
146
147    fn layer(&self, inner: S) -> Self::Service {
148        ServerRequestIdProvider {
149            inner,
150            header_key: self.header_key.clone(),
151        }
152    }
153}
154
155impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
156where
157    S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
158    S::Future: std::marker::Send + 'static,
159{
160    type Response = S::Response;
161    type Error = S::Error;
162    type Future = ServerRequestIdResponseFuture<S::Future>;
163
164    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        self.inner.poll_ready(cx)
166    }
167
168    fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
169        let request_id = ServerRequestId::new();
170        match &self.header_key {
171            Some(header_key) => {
172                req.extensions_mut().insert(request_id.clone());
173                ServerRequestIdResponseFuture {
174                    response_package: Some(ResponsePackage {
175                        request_id,
176                        header_key: header_key.clone(),
177                    }),
178                    fut: self.inner.call(req),
179                }
180            }
181            None => {
182                req.extensions_mut().insert(request_id);
183                ServerRequestIdResponseFuture {
184                    response_package: None,
185                    fut: self.inner.call(req),
186                }
187            }
188        }
189    }
190}
191
192impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
193    fn into_response(self) -> http::Response<BoxBody> {
194        internal_server_error()
195    }
196}
197
198struct ResponsePackage {
199    request_id: ServerRequestId,
200    header_key: HeaderName,
201}
202
203pin_project_lite::pin_project! {
204    pub struct ServerRequestIdResponseFuture<Fut> {
205        response_package: Option<ResponsePackage>,
206        #[pin]
207        fut: Fut,
208    }
209}
210
211impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
212where
213    Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
214{
215    type Output = Result<Fut::Ok, Fut::Error>;
216
217    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218        let this = self.project();
219        let fut = this.fut;
220        let response_package = this.response_package;
221        fut.try_poll(cx).map_ok(|mut res| {
222            if let Some(response_package) = response_package.take() {
223                res.headers_mut()
224                    .insert(response_package.header_key, response_package.request_id.to_header());
225            }
226            res
227        })
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::{Body, BoxBody};
    use crate::request::Request;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    /// Converting a freshly generated ID into a header value must not panic.
    #[test]
    fn test_request_id_parsed_by_header_value_infallible() {
        ServerRequestId::new().to_header();
    }

    /// With a response header configured, the response carries the request ID
    /// under that header and the value is a well-formed UUID.
    #[tokio::test]
    async fn test_request_id_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new_with_response_header(
                HeaderName::from_static("x-request-id"),
            ))
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();
        let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

        // The string was just read out of a `HeaderValue`, so converting it back
        // to one would always succeed. Check instead that it is the UUID that
        // `ServerRequestId` is built from.
        assert!(Uuid::parse_str(request_id).is_ok());
    }

    /// Without a response header configured, the layer adds nothing to the response.
    #[tokio::test]
    async fn test_request_id_not_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new())
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert!(res.headers().is_empty());
    }
}