aws_smithy_http_server/request/
request_id.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! # Request IDs
7//!
8//! `aws-smithy-http-server` provides the [`ServerRequestId`].
9//!
10//! ## `ServerRequestId`
11//!
12//! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request.
13//! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and
14//! data related to a single operation.
15//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
16//!
17//! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service.
18//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
19//!
20//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
21//!
22//! ## Examples
23//!
24//! Your handler can now optionally take as input a [`ServerRequestId`].
25//!
26//! ```rust,ignore
27//! pub async fn handler(
28//!     _input: Input,
29//!     server_request_id: ServerRequestId,
30//! ) -> Output {
31//!     /* Use server_request_id */
32//!     todo!()
33//! }
34//!
35//! let config = ServiceConfig::builder()
36//!     // Generate a server request ID and add it to the response header.
37//!     .layer(ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")))
38//!     .build();
39//! let app = Service::builder(config)
40//!     .operation(handler)
41//!     .build().unwrap();
42//!
43//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
44//!     .parse()
45//!     .expect("unable to parse the server bind address and port");
46//! let server = hyper::Server::bind(&bind).serve(app.into_make_service());
47//! ```
48
49use std::future::Future;
50use std::{
51    fmt::Display,
52    task::{Context, Poll},
53};
54
55use futures_util::TryFuture;
56use http::request::Parts;
57use http::{header::HeaderName, HeaderValue, Response};
58use thiserror::Error;
59use tower::{Layer, Service};
60use uuid::Uuid;
61
62use crate::{body::BoxBody, response::IntoResponse};
63
64use super::{internal_server_error, FromParts};
65
66/// Opaque type for Server Request IDs.
67///
68/// If it is missing, the request will be rejected with a `500 Internal Server Error` response.
69#[derive(Clone, Debug)]
70pub struct ServerRequestId {
71    id: Uuid,
72}
73
74/// The server request ID has not been added to the [`Request`](http::Request) or has been previously removed.
75#[non_exhaustive]
76#[derive(Debug, Error)]
77#[error("the `ServerRequestId` is not present in the `http::Request`")]
78pub struct MissingServerRequestId;
79
80impl ServerRequestId {
81    pub fn new() -> Self {
82        Self { id: Uuid::new_v4() }
83    }
84
85    pub(crate) fn to_header(&self) -> HeaderValue {
86        HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
87    }
88}
89
90impl Display for ServerRequestId {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        self.id.fmt(f)
93    }
94}
95
96impl<P> FromParts<P> for ServerRequestId {
97    type Rejection = MissingServerRequestId;
98
99    fn from_parts(parts: &mut Parts) -> Result<Self, Self::Rejection> {
100        parts.extensions.remove().ok_or(MissingServerRequestId)
101    }
102}
103
104impl Default for ServerRequestId {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110#[derive(Clone)]
111pub struct ServerRequestIdProvider<S> {
112    inner: S,
113    header_key: Option<HeaderName>,
114}
115
116/// A layer that provides services with a unique request ID instance
117#[derive(Debug)]
118#[non_exhaustive]
119pub struct ServerRequestIdProviderLayer {
120    header_key: Option<HeaderName>,
121}
122
123impl ServerRequestIdProviderLayer {
124    /// Generate a new unique request ID and do not add it as a response header
125    /// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header
126    pub fn new() -> Self {
127        Self { header_key: None }
128    }
129
130    /// Generate a new unique request ID and add it as a response header
131    pub fn new_with_response_header(header_key: HeaderName) -> Self {
132        Self {
133            header_key: Some(header_key),
134        }
135    }
136}
137
138impl Default for ServerRequestIdProviderLayer {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl<S> Layer<S> for ServerRequestIdProviderLayer {
145    type Service = ServerRequestIdProvider<S>;
146
147    fn layer(&self, inner: S) -> Self::Service {
148        ServerRequestIdProvider {
149            inner,
150            header_key: self.header_key.clone(),
151        }
152    }
153}
154
155impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
156where
157    S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
158    S::Future: std::marker::Send + 'static,
159{
160    type Response = S::Response;
161    type Error = S::Error;
162    type Future = ServerRequestIdResponseFuture<S::Future>;
163
164    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        self.inner.poll_ready(cx)
166    }
167
168    fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
169        let request_id = ServerRequestId::new();
170        match &self.header_key {
171            Some(header_key) => {
172                req.extensions_mut().insert(request_id.clone());
173                ServerRequestIdResponseFuture {
174                    response_package: Some(ResponsePackage {
175                        request_id,
176                        header_key: header_key.clone(),
177                    }),
178                    fut: self.inner.call(req),
179                }
180            }
181            None => {
182                req.extensions_mut().insert(request_id);
183                ServerRequestIdResponseFuture {
184                    response_package: None,
185                    fut: self.inner.call(req),
186                }
187            }
188        }
189    }
190}
191
192impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
193    fn into_response(self) -> http::Response<BoxBody> {
194        internal_server_error()
195    }
196}
197
198struct ResponsePackage {
199    request_id: ServerRequestId,
200    header_key: HeaderName,
201}
202
203pin_project_lite::pin_project! {
204    pub struct ServerRequestIdResponseFuture<Fut> {
205        response_package: Option<ResponsePackage>,
206        #[pin]
207        fut: Fut,
208    }
209}
210
211impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
212where
213    Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
214{
215    type Output = Result<Fut::Ok, Fut::Error>;
216
217    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218        let this = self.project();
219        let fut = this.fut;
220        let response_package = this.response_package;
221        fut.try_poll(cx).map_ok(|mut res| {
222            if let Some(response_package) = response_package.take() {
223                res.headers_mut()
224                    .insert(response_package.header_key, response_package.request_id.to_header());
225            }
226            res
227        })
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::{Body, BoxBody};
    use crate::request::Request;
    use http::HeaderValue;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    #[test]
    fn test_request_id_parsed_by_header_value_infallible() {
        ServerRequestId::new().to_header();
    }

    #[test]
    fn test_request_id_header_matches_display() {
        // The header value and the `Display` output must describe the same ID.
        let request_id = ServerRequestId::new();
        assert_eq!(request_id.to_header().to_str().unwrap(), request_id.to_string());
    }

    #[tokio::test]
    async fn test_request_id_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new_with_response_header(
                HeaderName::from_static("x-request-id"),
            ))
            .service(service_fn(|req: Request<Body>| async move {
                // The layer must hand the ID to the inner service.
                assert!(req.extensions().get::<ServerRequestId>().is_some());
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();
        let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

        assert!(HeaderValue::from_str(request_id).is_ok());
        // The header must carry the generated UUID, not just any valid header text.
        assert!(Uuid::parse_str(request_id).is_ok());
    }

    #[tokio::test]
    async fn test_request_id_not_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new())
            .service(service_fn(|req: Request<Body>| async move {
                // Even without a response header, the ID is still provided to the inner service.
                assert!(req.extensions().get::<ServerRequestId>().is_some());
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert!(res.headers().is_empty());
    }
}