aws_smithy_http_server_python/
context.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Python context definition.
7
8use http::Extensions;
9use pyo3::{PyObject, PyResult, Python, ToPyObject};
10
11mod lambda;
12pub mod layer;
13#[cfg(test)]
14mod testing;
15
16/// PyContext is a wrapper for context object provided by the user.
17/// It injects some values (currently only [super::lambda::PyLambdaContext]) that is type-hinted by the user.
18///
19/// PyContext is initialised during the startup, it inspects the provided context object for fields
20/// that are type-hinted to inject some values provided by the framework (see [PyContext::new()]).
21///
22/// After finding fields that needs to be injected, [layer::AddPyContextLayer], a [tower::Layer],
23/// populates request-scoped values from incoming request.
24///
25/// And finally PyContext implements [ToPyObject] (so it can by passed to Python handlers)
26/// that provides [PyObject] provided by the user with the additional values injected by the framework.
27#[derive(Clone)]
28pub struct PyContext {
29    inner: PyObject,
30    // TODO(Refactor): We should ideally keep record of injectable fields in a hashmap like:
31    // `injectable_fields: HashMap<Field, Box<dyn Injectable>>` where `Injectable` provides a method to extract a `PyObject` from a `Request`,
32    // but I couldn't find a way to extract a trait object from a Python object.
33    // We could introduce a registry to keep track of every injectable type but I'm not sure that is the best way to do it,
34    // so until we found a good way to achive that, I didn't want to introduce any abstraction here and
35    // keep it simple because we only have one field that is injectable.
36    lambda_ctx: lambda::PyContextLambda,
37}
38
39impl PyContext {
40    pub fn new(inner: PyObject) -> PyResult<Self> {
41        Ok(Self {
42            lambda_ctx: lambda::PyContextLambda::new(inner.clone())?,
43            inner,
44        })
45    }
46
47    pub fn populate_from_extensions(&self, _ext: &Extensions) {
48        self.lambda_ctx
49            .populate_from_extensions(self.inner.clone(), _ext);
50    }
51}
52
53impl ToPyObject for PyContext {
54    fn to_object(&self, _py: Python<'_>) -> PyObject {
55        self.inner.clone()
56    }
57}
58
#[cfg(test)]
mod tests {
    use http::Extensions;
    use pyo3::{prelude::*, py_run};

    use super::testing::get_context;

    /// A user-defined context class is exposed to Python unchanged, and
    /// `populate_from_extensions` preserves modifications made to it from Python.
    #[test]
    fn py_context() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let ctx = get_context(
            r#"
class Context:
    foo: int = 0
    bar: str = 'qux'

ctx = Context()
ctx.foo = 42
"#,
        );
        // Check the initial values from Python, then mutate them.
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.foo == 42
assert ctx.bar == 'qux'
# Make some modifications
ctx.foo += 1
ctx.bar = 'baz'
"#
            );
        });

        // Populating from (empty) extensions must not discard the modifications made above.
        ctx.populate_from_extensions(&Extensions::new());

        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
# Make sure we are preserving any modifications
assert ctx.foo == 43
assert ctx.bar == 'baz'
"#
            );
        });

        Ok(())
    }

    /// A `None` context must be accepted and exposed to Python as `None`.
    #[test]
    fn works_with_none() -> PyResult<()> {
        // Users can set the context to `None` explicitly, or implicitly by not providing a custom
        // context class; it shouldn't fail in that case.

        pyo3::prepare_freethreaded_python();

        let ctx = get_context("ctx = None");
        Python::with_gil(|py| {
            py_run!(py, ctx, "assert ctx is None");
        });

        Ok(())
    }
}