aws_smithy_http_server_python/context/
layer.rs1use std::task::{Context, Poll};
9
10use http::{Request, Response};
11use tower::{Layer, Service};
12
13use super::PyContext;
14
15pub struct AddPyContextLayer {
18 ctx: PyContext,
19}
20
21impl AddPyContextLayer {
22 pub fn new(ctx: PyContext) -> Self {
23 Self { ctx }
24 }
25}
26
27impl<S> Layer<S> for AddPyContextLayer {
28 type Service = AddPyContextService<S>;
29
30 fn layer(&self, inner: S) -> Self::Service {
31 AddPyContextService {
32 inner,
33 ctx: self.ctx.clone(),
34 }
35 }
36}
37
38#[derive(Clone)]
39pub struct AddPyContextService<S> {
40 inner: S,
41 ctx: PyContext,
42}
43
44impl<ResBody, ReqBody, S> Service<Request<ReqBody>> for AddPyContextService<S>
45where
46 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
47{
48 type Response = S::Response;
49 type Error = S::Error;
50 type Future = S::Future;
51
52 #[inline]
53 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54 self.inner.poll_ready(cx)
55 }
56
57 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
58 self.ctx.populate_from_extensions(req.extensions());
59 req.extensions_mut().insert(self.ctx.clone());
60 self.inner.call(req)
61 }
62}
63
#[cfg(test)]
mod tests {
    //! Checks that [`AddPyContextLayer`] makes the Lambda context attached to a request
    //! visible to Python through the injected [`PyContext`].

    use std::convert::Infallible;

    use http::{Request, Response};
    use hyper::Body;
    use pyo3::prelude::*;
    use pyo3::types::IntoPyDict;
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    use crate::context::testing::{get_context, lambda_ctx};

    use super::*;

    /// Runs a short Python script against `ctx` that reads the Lambda request id and
    /// increments `counter`, returning both values rendered as strings.
    fn inspect_in_python(ctx: &PyContext) -> (String, String) {
        Python::with_gil(|py| {
            let locals = [("ctx", ctx)].into_py_dict(py);
            py.run(
                r#"
req_id = ctx.lambda_ctx.request_id
ctx.counter += 1
counter = ctx.counter
            "#,
                None,
                Some(locals),
            )
            .unwrap();

            let lookup = |name: &str| locals.get_item(name).unwrap().to_string();
            (lookup("req_id"), lookup("counter"))
        })
    }

    /// Inner service: fetches the `PyContext` the layer attached and reports what
    /// Python observes through it.
    async fn handler(req: Request<Body>) -> Result<Response<(String, String)>, Infallible> {
        let ctx = req.extensions().get::<PyContext>().unwrap();
        Ok(Response::new(inspect_in_python(ctx)))
    }

    #[tokio::test]
    async fn populates_lambda_context() {
        pyo3::prepare_freethreaded_python();

        let py_ctx = get_context(
            r#"
class Context:
    counter: int = 42
    lambda_ctx: typing.Optional[LambdaContext] = None

ctx = Context()
    "#,
        );

        let service = ServiceBuilder::new()
            .layer(AddPyContextLayer::new(py_ctx))
            .service(service_fn(handler));

        let mut request = Request::new(Body::empty());
        request.extensions_mut().insert(lambda_ctx("my-req-id", "178"));

        let body = service.oneshot(request).await.unwrap().into_body();

        // 42 is the class default; the handler script increments it exactly once.
        assert_eq!(body, ("my-req-id".to_string(), "43".to_string()));
    }
}