aws_smithy_http_server/request/
request_id.rs1use std::future::Future;
50use std::{
51 fmt::Display,
52 task::{Context, Poll},
53};
54
55use futures_util::TryFuture;
56use http::request::Parts;
57use http::{header::HeaderName, HeaderValue, Response};
58use thiserror::Error;
59use tower::{Layer, Service};
60use uuid::Uuid;
61
62use crate::{body::BoxBody, response::IntoResponse};
63
64use super::{internal_server_error, FromParts};
65
66#[derive(Clone, Debug)]
70pub struct ServerRequestId {
71 id: Uuid,
72}
73
74#[non_exhaustive]
76#[derive(Debug, Error)]
77#[error("the `ServerRequestId` is not present in the `http::Request`")]
78pub struct MissingServerRequestId;
79
80impl ServerRequestId {
81 pub fn new() -> Self {
82 Self { id: Uuid::new_v4() }
83 }
84
85 pub(crate) fn to_header(&self) -> HeaderValue {
86 HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
87 }
88}
89
90impl Display for ServerRequestId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 self.id.fmt(f)
93 }
94}
95
96impl<P> FromParts<P> for ServerRequestId {
97 type Rejection = MissingServerRequestId;
98
99 fn from_parts(parts: &mut Parts) -> Result<Self, Self::Rejection> {
100 parts.extensions.remove().ok_or(MissingServerRequestId)
101 }
102}
103
104impl Default for ServerRequestId {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110#[derive(Clone)]
111pub struct ServerRequestIdProvider<S> {
112 inner: S,
113 header_key: Option<HeaderName>,
114}
115
116#[derive(Debug)]
118#[non_exhaustive]
119pub struct ServerRequestIdProviderLayer {
120 header_key: Option<HeaderName>,
121}
122
123impl ServerRequestIdProviderLayer {
124 pub fn new() -> Self {
127 Self { header_key: None }
128 }
129
130 pub fn new_with_response_header(header_key: HeaderName) -> Self {
132 Self {
133 header_key: Some(header_key),
134 }
135 }
136}
137
138impl Default for ServerRequestIdProviderLayer {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl<S> Layer<S> for ServerRequestIdProviderLayer {
145 type Service = ServerRequestIdProvider<S>;
146
147 fn layer(&self, inner: S) -> Self::Service {
148 ServerRequestIdProvider {
149 inner,
150 header_key: self.header_key.clone(),
151 }
152 }
153}
154
155impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
156where
157 S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
158 S::Future: std::marker::Send + 'static,
159{
160 type Response = S::Response;
161 type Error = S::Error;
162 type Future = ServerRequestIdResponseFuture<S::Future>;
163
164 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165 self.inner.poll_ready(cx)
166 }
167
168 fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
169 let request_id = ServerRequestId::new();
170 match &self.header_key {
171 Some(header_key) => {
172 req.extensions_mut().insert(request_id.clone());
173 ServerRequestIdResponseFuture {
174 response_package: Some(ResponsePackage {
175 request_id,
176 header_key: header_key.clone(),
177 }),
178 fut: self.inner.call(req),
179 }
180 }
181 None => {
182 req.extensions_mut().insert(request_id);
183 ServerRequestIdResponseFuture {
184 response_package: None,
185 fut: self.inner.call(req),
186 }
187 }
188 }
189 }
190}
191
192impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
193 fn into_response(self) -> http::Response<BoxBody> {
194 internal_server_error()
195 }
196}
197
198struct ResponsePackage {
199 request_id: ServerRequestId,
200 header_key: HeaderName,
201}
202
203pin_project_lite::pin_project! {
204 pub struct ServerRequestIdResponseFuture<Fut> {
205 response_package: Option<ResponsePackage>,
206 #[pin]
207 fut: Fut,
208 }
209}
210
211impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
212where
213 Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
214{
215 type Output = Result<Fut::Ok, Fut::Error>;
216
217 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218 let this = self.project();
219 let fut = this.fut;
220 let response_package = this.response_package;
221 fut.try_poll(cx).map_ok(|mut res| {
222 if let Some(response_package) = response_package.take() {
223 res.headers_mut()
224 .insert(response_package.header_key, response_package.request_id.to_header());
225 }
226 res
227 })
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::{Body, BoxBody};
    use crate::request::Request;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    /// Rendering a request ID as a header value must never hit the `expect` in
    /// `to_header`.
    #[test]
    fn test_request_id_parsed_by_header_value_infallible() {
        ServerRequestId::new().to_header();
    }

    /// With a configured header name, the response carries the generated request ID.
    #[tokio::test]
    async fn test_request_id_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new_with_response_header(
                HeaderName::from_static("x-request-id"),
            ))
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();
        let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();

        // Checking that the string is a valid header value would be vacuous (it was just
        // read from a header); verify instead that it is the textual form of a UUID.
        assert!(Uuid::parse_str(request_id).is_ok());
    }

    /// Without a configured header name, the response headers are left untouched.
    #[tokio::test]
    async fn test_request_id_not_in_response_header() {
        let svc = ServiceBuilder::new()
            .layer(&ServerRequestIdProviderLayer::new())
            .service(service_fn(|_req: Request<Body>| async move {
                Ok::<_, Infallible>(Response::new(BoxBody::default()))
            }));

        let req = Request::new(Body::empty());

        let res = svc.oneshot(req).await.unwrap();

        assert!(res.headers().is_empty());
    }
}