aws_smithy_http_server_python/context/
layer.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! A [tower::Layer] that populates a [PyContext] from each request and injects it into the
//! request's extensions.

8use std::task::{Context, Poll};
9
10use http::{Request, Response};
11use tower::{Layer, Service};
12
13use super::PyContext;
14
15/// AddPyContextLayer is a [tower::Layer] that populates given [PyContext] from the [Request]
16/// and injects [PyContext] to the [Request] as an extension.
17pub struct AddPyContextLayer {
18    ctx: PyContext,
19}
20
21impl AddPyContextLayer {
22    pub fn new(ctx: PyContext) -> Self {
23        Self { ctx }
24    }
25}
26
27impl<S> Layer<S> for AddPyContextLayer {
28    type Service = AddPyContextService<S>;
29
30    fn layer(&self, inner: S) -> Self::Service {
31        AddPyContextService {
32            inner,
33            ctx: self.ctx.clone(),
34        }
35    }
36}
37
38#[derive(Clone)]
39pub struct AddPyContextService<S> {
40    inner: S,
41    ctx: PyContext,
42}
43
44impl<ResBody, ReqBody, S> Service<Request<ReqBody>> for AddPyContextService<S>
45where
46    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
47{
48    type Response = S::Response;
49    type Error = S::Error;
50    type Future = S::Future;
51
52    #[inline]
53    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54        self.inner.poll_ready(cx)
55    }
56
57    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
58        self.ctx.populate_from_extensions(req.extensions());
59        req.extensions_mut().insert(self.ctx.clone());
60        self.inner.call(req)
61    }
62}
63
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use http::{Request, Response};
    use hyper::Body;
    use pyo3::prelude::*;
    use pyo3::types::IntoPyDict;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    use crate::context::testing::{get_context, lambda_ctx};

    use super::*;

    /// Python source defining the user context object the layer is expected to populate.
    const CONTEXT_DEFINITION: &str = r#"
class Context:
    counter: int = 42
    lambda_ctx: typing.Optional[LambdaContext] = None

ctx = Context()
    "#;

    /// Python source run by the handler. It reads the injected Lambda request id and bumps the
    /// counter, so the test can observe that the handler sees a live context.
    const HANDLER_SCRIPT: &str = r#"
req_id = ctx.lambda_ctx.request_id
ctx.counter += 1
counter = ctx.counter
    "#;

    #[tokio::test]
    async fn populates_lambda_context() {
        pyo3::prepare_freethreaded_python();

        let layer = AddPyContextLayer::new(get_context(CONTEXT_DEFINITION));

        let handler = service_fn(|request: Request<Body>| async move {
            let py_ctx = request.extensions().get::<PyContext>().unwrap();
            let observed = Python::with_gil(|py| {
                let scope = [("ctx", py_ctx)].into_py_dict(py);
                py.run(HANDLER_SCRIPT, None, Some(scope)).unwrap();

                let req_id = scope.get_item("req_id").unwrap().to_string();
                let counter = scope.get_item("counter").unwrap().to_string();
                (req_id, counter)
            });
            Ok::<_, Infallible>(Response::new(observed))
        });

        let service = ServiceBuilder::new().layer(layer).service(handler);

        let mut request = Request::new(Body::empty());
        request
            .extensions_mut()
            .insert(lambda_ctx("my-req-id", "178"));

        let (req_id, counter) = service.oneshot(request).await.unwrap().into_body();

        assert_eq!(req_id, "my-req-id");
        assert_eq!(counter, "43");
    }
}