1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! A [tower::Layer] for injecting and populating [PyContext].

use std::task::{Context, Poll};

use http::{Request, Response};
use tower::{Layer, Service};

use super::PyContext;

/// AddPyContextLayer is a [tower::Layer] that populates given [PyContext] from the [Request]
/// and injects [PyContext] to the [Request] as an extension.
///
/// The layer is [Clone], just like the [AddPyContextService] it produces, so a builder stack
/// that contains it can be cloned as well.
#[derive(Clone)]
pub struct AddPyContextLayer {
    /// The context that is cloned into every service created by this layer.
    ctx: PyContext,
}

impl AddPyContextLayer {
    pub fn new(ctx: PyContext) -> Self {
        Self { ctx }
    }
}

impl<S> Layer<S> for AddPyContextLayer {
    type Service = AddPyContextService<S>;

    /// Wraps `service` in an [AddPyContextService] that owns a clone of this layer's
    /// [PyContext].
    fn layer(&self, service: S) -> Self::Service {
        let ctx = self.ctx.clone();
        AddPyContextService {
            inner: service,
            ctx,
        }
    }
}

/// AddPyContextService is the [tower::Service] produced by [AddPyContextLayer].
///
/// On every call it populates its [PyContext] from the [Request]'s extensions, inserts a clone
/// of it into the [Request] as an extension and then forwards the [Request] to the inner service.
#[derive(Clone)]
pub struct AddPyContextService<S> {
    /// The wrapped service that receives the [Request] once the context has been attached.
    inner: S,
    /// The context that is populated from each [Request] and cloned into its extensions.
    ctx: PyContext,
}

impl<ResBody, ReqBody, S> Service<Request<ReqBody>> for AddPyContextService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is decided entirely by the wrapped service.
    #[inline]
    fn poll_ready(&mut self, task_cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(task_cx)
    }

    /// Attaches the [PyContext] to `request` and forwards it to the wrapped service.
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        // Populate the context from the extensions already present on the request
        // before a clone of it is attached to that same request.
        self.ctx.populate_from_extensions(request.extensions());

        let ctx = self.ctx.clone();
        request.extensions_mut().insert(ctx);

        self.inner.call(request)
    }
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use http::{Request, Response};
    use hyper::Body;
    use pyo3::prelude::*;
    use pyo3::types::IntoPyDict;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    use crate::context::testing::{get_context, lambda_ctx};

    use super::*;

    // Checks that a handler placed behind `AddPyContextLayer` can read the Lambda context that
    // was put into the request's extensions through the injected `PyContext`.
    #[tokio::test]
    async fn populates_lambda_context() {
        pyo3::prepare_freethreaded_python();

        // `get_context` is a test helper defined elsewhere; presumably it evaluates this Python
        // snippet and wraps the resulting `ctx` object in a `PyContext` (confirm in
        // `crate::context::testing`).
        let ctx = get_context(
            r#"
class Context:
    counter: int = 42
    lambda_ctx: typing.Optional[LambdaContext] = None

ctx = Context()
    "#,
        );

        // The handler pulls the `PyContext` out of the request extensions and runs a Python
        // snippet against it that reads `lambda_ctx.request_id` and increments `counter`.
        let svc = ServiceBuilder::new()
            .layer(AddPyContextLayer::new(ctx))
            .service(service_fn(|req: Request<Body>| async move {
                // Panics if the layer did not insert a `PyContext` extension.
                let ctx = req.extensions().get::<PyContext>().unwrap();
                let (req_id, counter) = Python::with_gil(|py| {
                    // Expose the context to the Python snippet under the name `ctx`.
                    let locals = [("ctx", ctx)].into_py_dict(py);
                    py.run(
                        r#"
req_id = ctx.lambda_ctx.request_id
ctx.counter += 1
counter = ctx.counter
    "#,
                        None,
                        Some(locals),
                    )
                    .unwrap();

                    // Read the values assigned by the snippet back as strings.
                    (
                        locals.get_item("req_id").unwrap().to_string(),
                        locals.get_item("counter").unwrap().to_string(),
                    )
                });
                // Return both values as the response body so the test can assert on them.
                Ok::<_, Infallible>(Response::new((req_id, counter)))
            }));

        // Attach a Lambda context to the request; the layer is expected to pick it up from
        // the extensions.
        let mut req = Request::new(Body::empty());
        req.extensions_mut().insert(lambda_ctx("my-req-id", "178"));

        let res = svc.oneshot(req).await.unwrap().into_body();

        // The request id must match the inserted Lambda context and the counter must have
        // gone from 42 to 43.
        assert_eq!(("my-req-id".to_string(), "43".to_string()), res);
    }
}