use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures_util::{ready, TryFuture};
use http::{HeaderMap, Request, Response, StatusCode, Uri};
use tower::Service;
use tracing::{debug, debug_span, instrument::Instrumented, Instrument};
use crate::shape_id::ShapeId;
use super::{MakeDebug, MakeDisplay, MakeIdentity};
pin_project_lite::pin_project! {
/// Future that resolves to the inner service's result and, when that result is a
/// successful response, emits a `debug`-level `"response"` event recording the response
/// headers and status code.
struct InnerFuture<Fut, ResponseMakeFmt> {
// The wrapped response future; structurally pinned so it can be polled in place.
#[pin]
inner: Fut,
// Formatter used to render the response headers (via `MakeDebug`) and the status code
// (via `MakeDisplay`) for the `"response"` event.
make: ResponseMakeFmt
}
}
impl<Fut, ResponseMakeFmt, T> Future for InnerFuture<Fut, ResponseMakeFmt>
where
    Fut: TryFuture<Ok = Response<T>>,
    Fut: Future<Output = Result<Fut::Ok, Fut::Error>>,
    for<'a> ResponseMakeFmt: MakeDebug<&'a HeaderMap>,
    for<'a> ResponseMakeFmt: MakeDisplay<StatusCode>,
{
    type Output = Fut::Output;

    /// Drives the wrapped future to completion.
    ///
    /// On a successful response, the headers and status code are formatted with `make` and
    /// logged as a `debug`-level `"response"` event before the response is returned.
    /// `Pending` and errors are forwarded unchanged, and no event is emitted for them.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match ready!(projected.inner.poll(cx)) {
            Ok(response) => {
                // The formatted views borrow `response`. Keep them in their own scope so
                // those borrows end before `response` is moved into the return value.
                {
                    let formatted_headers = projected.make.make_debug(response.headers());
                    let formatted_status = projected.make.make_display(response.status());
                    debug!(
                        headers = ?formatted_headers,
                        status_code = %formatted_status,
                        "response"
                    );
                }
                Poll::Ready(Ok(response))
            }
            Err(error) => Poll::Ready(Err(error)),
        }
    }
}
pin_project_lite::pin_project! {
/// Response future returned by [`InstrumentOperation`]'s `Service::call`.
///
/// Resolves to the inner service's result. A successful response is logged as a
/// `debug`-level `"response"` event, and the future carries the `"request"` span created
/// for the call.
pub struct InstrumentedFuture<Fut, ResponseMakeFmt> {
// The response-logging future, wrapped in `tracing`'s `Instrumented` so that the per-call
// `"request"` span is attached to it.
#[pin]
inner: Instrumented<InnerFuture<Fut, ResponseMakeFmt>>
}
}
impl<Fut, ResponseMakeFmt, T> Future for InstrumentedFuture<Fut, ResponseMakeFmt>
where
    Fut: TryFuture<Ok = Response<T>>,
    Fut: Future<Output = Result<Fut::Ok, Fut::Error>>,
    for<'a> ResponseMakeFmt: MakeDebug<&'a HeaderMap>,
    for<'a> ResponseMakeFmt: MakeDisplay<StatusCode>,
{
    type Output = Fut::Output;

    /// Delegates directly to the span-instrumented inner future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let instrumented = self.project().inner;
        Future::poll(instrumented, cx)
    }
}
/// A [`Service`] wrapper that instruments an operation's requests and responses with
/// `tracing`.
///
/// Each call opens a `debug`-level span named `"request"`. The span records:
/// - the operation's absolute shape ID,
/// - the HTTP method,
/// - the URI, rendered with `RequestMakeFmt`'s `MakeDisplay`,
/// - the headers, rendered with `RequestMakeFmt`'s `MakeDebug`.
///
/// When the inner service produces a successful response, a `debug`-level `"response"`
/// event records its headers and status code, rendered with `ResponseMakeFmt`. Errors from
/// the inner service are passed through without a `"response"` event.
///
/// Both formatters default to [`MakeIdentity`]. Use [`InstrumentOperation::request_fmt`]
/// and [`InstrumentOperation::response_fmt`] to replace them.
#[derive(Debug, Clone)]
pub struct InstrumentOperation<S, RequestMakeFmt = MakeIdentity, ResponseMakeFmt = MakeIdentity> {
// The wrapped service that actually handles requests.
inner: S,
// Identifier of the operation; its absolute form is recorded on the request span.
operation_id: ShapeId,
// Formats the request URI and headers for the request span.
make_request: RequestMakeFmt,
// Formats the response headers and status code; cloned into every response future.
make_response: ResponseMakeFmt,
}
impl<S> InstrumentOperation<S> {
    /// Wraps `inner` so that its calls are instrumented under the operation identified by
    /// `operation_id`.
    ///
    /// The request and response formatters both start out as [`MakeIdentity`]. Replace them
    /// with [`InstrumentOperation::request_fmt`] or [`InstrumentOperation::response_fmt`].
    pub fn new(inner: S, operation_id: ShapeId) -> Self {
        let (make_request, make_response) = (MakeIdentity, MakeIdentity);
        Self {
            inner,
            operation_id,
            make_request,
            make_response,
        }
    }
}
impl<S, RequestMakeFmt, ResponseMakeFmt> InstrumentOperation<S, RequestMakeFmt, ResponseMakeFmt> {
    /// Returns a copy of this wrapper that formats the request URI and headers with
    /// `make_request`, replacing the current request formatter. All other fields are kept.
    ///
    /// The result is only usable as a `Service` when `R` implements `MakeDebug<&HeaderMap>`
    /// and `MakeDisplay<&Uri>`.
    pub fn request_fmt<R>(self, make_request: R) -> InstrumentOperation<S, R, ResponseMakeFmt> {
        let Self {
            inner,
            operation_id,
            make_response,
            ..
        } = self;
        InstrumentOperation {
            inner,
            operation_id,
            make_request,
            make_response,
        }
    }

    /// Returns a copy of this wrapper that formats the response headers and status code with
    /// `make_response`, replacing the current response formatter. All other fields are kept.
    ///
    /// The result is only usable as a `Service` when `R` is `Clone` and implements
    /// `MakeDebug<&HeaderMap>` and `MakeDisplay<StatusCode>`.
    pub fn response_fmt<R>(self, make_response: R) -> InstrumentOperation<S, RequestMakeFmt, R> {
        let Self {
            inner,
            operation_id,
            make_request,
            ..
        } = self;
        InstrumentOperation {
            inner,
            operation_id,
            make_request,
            make_response,
        }
    }
}
impl<S, U, V, RequestMakeFmt, ResponseMakeFmt> Service<Request<U>>
    for InstrumentOperation<S, RequestMakeFmt, ResponseMakeFmt>
where
    S: Service<Request<U>, Response = Response<V>>,
    for<'a> RequestMakeFmt: MakeDebug<&'a HeaderMap>,
    for<'a> RequestMakeFmt: MakeDisplay<&'a Uri>,
    ResponseMakeFmt: Clone,
    for<'a> ResponseMakeFmt: MakeDebug<&'a HeaderMap>,
    for<'a> ResponseMakeFmt: MakeDisplay<StatusCode>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = InstrumentedFuture<S::Future, ResponseMakeFmt>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        <S as Service<Request<U>>>::poll_ready(&mut self.inner, cx)
    }

    /// Opens the `debug`-level `"request"` span for this call and forwards the request to
    /// the wrapped service. Returns the service's response future, wrapped so that it
    /// carries that span and logs a `"response"` event on success.
    fn call(&mut self, request: Request<U>) -> Self::Future {
        // The span is built before `request` is moved into the inner service. The
        // formatted views borrow `request`, so they live in their own scope.
        let request_span = {
            let formatted_headers = self.make_request.make_debug(request.headers());
            let formatted_uri = self.make_request.make_display(request.uri());
            debug_span!(
                "request",
                operation = %self.operation_id.absolute(),
                method = %request.method(),
                uri = %formatted_uri,
                headers = ?formatted_headers
            )
        };
        let logging_future = InnerFuture {
            inner: self.inner.call(request),
            make: self.make_response.clone(),
        };
        InstrumentedFuture {
            inner: logging_future.instrument(request_span),
        }
    }
}