aws_smithy_http_server_python/
server.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::net::TcpListener as StdTcpListener;
9use std::ops::Deref;
10use std::process;
11use std::sync::{mpsc, Arc};
12use std::thread;
13
14use aws_smithy_http_server::{
15    body::{Body, BoxBody},
16    routing::IntoMakeService,
17};
18use http::{Request, Response};
19use hyper::server::conn::AddrIncoming;
20use parking_lot::Mutex;
21use pyo3::{prelude::*, types::IntoPyDict};
22use signal_hook::{consts::*, iterator::Signals};
23use socket2::Socket;
24use tokio::{net::TcpListener, runtime};
25use tokio_rustls::TlsAcceptor;
26use tower::{util::BoxCloneService, ServiceBuilder};
27
28use crate::{
29    context::{layer::AddPyContextLayer, PyContext},
30    tls::{listener::Listener as TlsListener, PyTlsConfig},
31    util::{error::rich_py_err, func_metadata},
32    PySocket,
33};
34
/// A Python handler function representation.
///
/// The Python business logic implementation needs to carry some information
/// to be executed properly, like the number of its arguments and whether it is
/// a coroutine. Instances are created by [`PyApp::register_operation`].
#[pyclass]
#[derive(Debug, Clone)]
pub struct PyHandler {
    /// The registered Python callable.
    pub func: PyObject,
    /// Number of arguments accepted by `func`.
    ///
    /// Number of args is needed to decide whether handler accepts context as an argument.
    pub args: usize,
    /// Whether `func` is a coroutine function, i.e. whether its result needs to be awaited.
    pub is_coroutine: bool,
}
48
49impl Deref for PyHandler {
50    type Target = PyObject;
51
52    fn deref(&self) -> &Self::Target {
53        &self.func
54    }
55}
56
/// A boxed, cloneable `tower` service ([`BoxCloneService`]) with the default `Request`,
/// `Response` and `Error` types used by this module; it is what [`PyApp::build_service`]
/// returns and what the workers serve.
type Service = BoxCloneService<Request<Body>, Response<BoxBody>, Infallible>;
59
60/// Trait defining a Python application.
61///
62/// A Python application requires handling of multiple processes, signals and allows to register Python
63/// function that will be executed as business logic by the code generated Rust handlers.
64/// To properly function, the application requires some state:
65/// * `workers`: the list of child Python worker processes, protected by a Mutex.
66/// * `context`: the optional Python object that should be passed inside the Rust state struct.
67/// * `handlers`: the mapping between an operation name and its [PyHandler] representation.
68///
69/// Since the Python application is spawning multiple workers, it also requires signal handling to allow the gracefull
70/// termination of multiple Hyper servers. The main Rust process is registering signal and using them to understand when it
71/// it time to loop through all the active workers and terminate them. Workers registers their own signal handlers and attaches
72/// them to the Python event loop, ensuring all coroutines are cancelled before terminating a worker.
73///
74/// This trait will be implemented by the code generated by the `PythonApplicationGenerator` Kotlin class.
75pub trait PyApp: Clone + pyo3::IntoPy<PyObject> {
76    /// List of active Python workers registered with this application.
77    fn workers(&self) -> &Mutex<Vec<PyObject>>;
78
79    /// Optional Python context object that will be passed as part of the Rust state.
80    fn context(&self) -> &Option<PyObject>;
81
82    /// Mapping between operation names and their `PyHandler` representation.
83    fn handlers(&mut self) -> &mut HashMap<String, PyHandler>;
84
85    /// Build the app's `Service` using given `event_loop`.
86    fn build_service(&mut self, event_loop: &pyo3::PyAny) -> pyo3::PyResult<Service>;
87
88    /// Handle the graceful termination of Python workers by looping through all the
89    /// active workers and calling `terminate()` on them. If termination fails, this
90    /// method will try to `kill()` any failed worker.
91    fn graceful_termination(&self, workers: &Mutex<Vec<PyObject>>) -> ! {
92        let workers = workers.lock();
93        for (idx, worker) in workers.iter().enumerate() {
94            let idx = idx + 1;
95            Python::with_gil(|py| {
96                let pid: isize = worker
97                    .getattr(py, "pid")
98                    .map(|pid| pid.extract(py).unwrap_or(-1))
99                    .unwrap_or(-1);
100                tracing::debug!(idx, pid, "terminating worker");
101                match worker.call_method0(py, "terminate") {
102                    Ok(_) => {}
103                    Err(e) => {
104                        tracing::error!(error = ?rich_py_err(e), idx, pid, "error terminating worker");
105                        worker
106                            .call_method0(py, "kill")
107                            .map_err(|e| {
108                                tracing::error!(
109                                    error = ?rich_py_err(e), idx, pid, "unable to kill kill worker"
110                                );
111                            })
112                            .unwrap();
113                    }
114                }
115            });
116        }
117        process::exit(0);
118    }
119
120    /// Handler the immediate termination of Python workers by looping through all the
121    /// active workers and calling `kill()` on them.
122    fn immediate_termination(&self, workers: &Mutex<Vec<PyObject>>) -> ! {
123        let workers = workers.lock();
124        for (idx, worker) in workers.iter().enumerate() {
125            let idx = idx + 1;
126            Python::with_gil(|py| {
127                let pid: isize = worker
128                    .getattr(py, "pid")
129                    .map(|pid| pid.extract(py).unwrap_or(-1))
130                    .unwrap_or(-1);
131                tracing::debug!(idx, pid, "killing worker");
132                worker
133                    .call_method0(py, "kill")
134                    .map_err(|e| {
135                        tracing::error!(error = ?rich_py_err(e), idx, pid, "unable to kill kill worker");
136                    })
137                    .unwrap();
138            });
139        }
140        process::exit(0);
141    }
142
143    /// Register and handler signals of the main Rust thread. Signals not registered
144    /// in this method are ignored.
145    ///
146    /// Signals supported:
147    ///   * SIGTERM|SIGQUIT - graceful termination of all workers.
148    ///   * SIGINT - immediate termination of all workers.
149    ///
150    /// Other signals are NOOP.
151    fn block_on_rust_signals(&self) {
152        let mut signals =
153            Signals::new([SIGINT, SIGHUP, SIGQUIT, SIGTERM, SIGUSR1, SIGUSR2, SIGWINCH])
154                .expect("Unable to register signals");
155        for sig in signals.forever() {
156            match sig {
157                SIGINT => {
158                    tracing::info!(
159                        sig = %sig, "termination signal received, all workers will be immediately terminated"
160                    );
161
162                    self.immediate_termination(self.workers());
163                }
164                SIGTERM | SIGQUIT => {
165                    tracing::info!(
166                        sig = %sig, "termination signal received, all workers will be gracefully terminated"
167                    );
168                    self.graceful_termination(self.workers());
169                }
170                _ => {
171                    tracing::debug!(sig = %sig, "signal is ignored by this application");
172                }
173            }
174        }
175    }
176
177    /// Register and handle termination of all the tasks on the Python asynchronous event loop.
178    /// We only register SIGQUIT and SIGINT since the main signal handling is done by Rust.
179    fn register_python_signals(&self, py: Python, event_loop: PyObject) -> PyResult<()> {
180        let locals = [("event_loop", event_loop)].into_py_dict(py);
181        py.run(
182            r#"
183import asyncio
184import logging
185import functools
186import signal
187
188async def shutdown(sig, event_loop):
189    # reimport asyncio and logging to be sure they are available when
190    # this handler runs on signal catching.
191    import asyncio
192    import logging
193    logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop")
194    tasks = [task for task in asyncio.all_tasks() if task is not
195             asyncio.current_task()]
196    list(map(lambda task: task.cancel(), tasks))
197    results = await asyncio.gather(*tasks, return_exceptions=True)
198    logging.debug(f"Finished awaiting cancelled tasks, results: {results}")
199    event_loop.stop()
200
201event_loop.add_signal_handler(signal.SIGTERM,
202    functools.partial(asyncio.ensure_future, shutdown(signal.SIGTERM, event_loop)))
203event_loop.add_signal_handler(signal.SIGINT,
204    functools.partial(asyncio.ensure_future, shutdown(signal.SIGINT, event_loop)))
205"#,
206            None,
207            Some(locals),
208        )?;
209        Ok(())
210    }
211
212    /// Start a single worker with its own Tokio and Python async runtime and provided shared socket.
213    ///
214    /// Python asynchronous loop needs to be started and handled during the lifetime of the process and
215    /// it is passed to this method by the caller, which can use
216    /// [configure_python_event_loop](#method.configure_python_event_loop) to properly setup it up.
217    ///
218    /// We retrieve the Python context object, if setup by the user calling [PyApp::context] method,
219    /// generate the state structure and build the [aws_smithy_http_server::routing::Router], filling
220    /// it with the functions generated by `PythonServerOperationHandlerGenerator.kt`.
221    /// At last we get a cloned reference to the underlying [socket2::Socket].
222    ///
223    /// Now that all the setup is done, we can start the two runtimes and run the [hyper] server.
224    /// We spawn a thread with a new [tokio::runtime], setup the middlewares and finally block the
225    /// thread on Hyper serve() method.
226    /// The main process continues and at the end it is blocked on Python `loop.run_forever()`.
227    ///
228    /// [uvloop]: https://github.com/MagicStack/uvloop
229    fn start_hyper_worker(
230        &mut self,
231        py: Python,
232        socket: &PyCell<PySocket>,
233        event_loop: &PyAny,
234        service: Service,
235        worker_number: isize,
236        tls: Option<PyTlsConfig>,
237    ) -> PyResult<()> {
238        // Clone the socket.
239        let borrow = socket.try_borrow_mut()?;
240        let held_socket: &PySocket = &borrow;
241        let raw_socket = held_socket.get_socket()?;
242
243        // Register signals on the Python event loop.
244        self.register_python_signals(py, event_loop.to_object(py))?;
245
246        // Spawn a new background [std::thread] to run the application.
247        // This is needed because `asyncio` doesn't work properly if it doesn't control the main thread.
248        // At the end of this function you can see we are calling `event_loop.run_forever()` to
249        // yield execution of main thread to `asyncio` runtime.
250        // For more details: https://docs.rs/pyo3-asyncio/latest/pyo3_asyncio/#pythons-event-loop-and-the-main-thread
251        tracing::trace!("start the tokio runtime in a background task");
252        thread::spawn(move || {
253            // The thread needs a new [tokio] runtime.
254            let rt = runtime::Builder::new_multi_thread()
255                .enable_all()
256                .thread_name(format!("smithy-rs-tokio[{worker_number}]"))
257                .build()
258                .expect("unable to start a new tokio runtime for this process");
259            rt.block_on(async move {
260                let addr = addr_incoming_from_socket(raw_socket);
261
262                if let Some(config) = tls {
263                    let (acceptor, acceptor_rx) = tls_config_reloader(config);
264                    let listener = TlsListener::new(acceptor, addr, acceptor_rx);
265                    let server =
266                        hyper::Server::builder(listener).serve(IntoMakeService::new(service));
267
268                    tracing::trace!("started tls hyper server from shared socket");
269                    // Run forever-ish...
270                    if let Err(err) = server.await {
271                        tracing::error!(error = ?err, "server error");
272                    }
273                } else {
274                    let server = hyper::Server::builder(addr).serve(IntoMakeService::new(service));
275
276                    tracing::trace!("started hyper server from shared socket");
277                    // Run forever-ish...
278                    if let Err(err) = server.await {
279                        tracing::error!(error = ?err, "server error");
280                    }
281                }
282            });
283        });
284        // Block on the event loop forever.
285        tracing::trace!("run and block on the python event loop until a signal is received");
286        event_loop.call_method0("run_forever")?;
287        Ok(())
288    }
289
290    /// Register a Python function to be executed inside the Smithy Rust handler.
291    ///
292    /// There are some information needed to execute the Python code from a Rust handler,
293    /// such has if the registered function needs to be awaited (if it is a coroutine) and
294    /// the number of arguments available, which tells us if the handler wants the state to be
295    /// passed or not.
296    fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> {
297        let func_metadata = func_metadata(py, &func)?;
298        let handler = PyHandler {
299            func,
300            is_coroutine: func_metadata.is_coroutine,
301            args: func_metadata.num_args,
302        };
303        tracing::info!(
304            name,
305            is_coroutine = handler.is_coroutine,
306            args = handler.args,
307            "registering handler function",
308        );
309        // Insert the handler in the handlers map.
310        self.handlers().insert(name.to_string(), handler);
311        Ok(())
312    }
313
314    /// Configure the Python asyncio event loop.
315    ///
316    /// First of all we install [uvloop] as the main Python event loop. Thanks to libuv, uvloop
317    /// performs ~20% better than Python standard event loop in most benchmarks, while being 100%
318    /// compatible. If [uvloop] is not available as a dependency, we just fall back to the standard
319    /// Python event loop.
320    ///
321    /// [uvloop]: https://github.com/MagicStack/uvloop
322    fn configure_python_event_loop<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
323        let asyncio = py.import("asyncio")?;
324        match py.import("uvloop") {
325            Ok(uvloop) => {
326                uvloop.call_method0("install")?;
327                tracing::trace!("setting up uvloop for current process");
328            }
329            Err(_) => {
330                tracing::warn!("uvloop not found, using python standard event loop, which could have worse performance than uvloop");
331            }
332        }
333        let event_loop = asyncio.call_method0("new_event_loop")?;
334        asyncio.call_method1("set_event_loop", (event_loop,))?;
335        Ok(event_loop)
336    }
337
338    /// Main entrypoint: start the server on multiple workers.
339    ///
340    /// The multiprocessing server is achieved using the ability of a Python interpreter
341    /// to clone and start itself as a new process.
342    /// The shared sockets is created and Using the [multiprocessing::Process] module, multiple
343    /// workers with the method `self.start_worker()` as target are started.
344    ///
345    /// NOTE: this method ends up calling `self.start_worker` from the Python context, forcing
346    /// the struct implementing this trait to also implement a `start_worker` method.
347    /// This is done to ensure the Python event loop is started in the right child process space before being
348    /// passed to `start_hyper_worker`.
349    ///
350    /// `PythonApplicationGenerator.kt` generates the `start_worker` method:
351    ///
352    /// ```no_run
353    ///     use std::convert::Infallible;
354    ///     use std::collections::HashMap;
355    ///     use pyo3::prelude::*;
356    ///     use aws_smithy_http_server_python::{PyApp, PyHandler};
357    ///     use aws_smithy_http_server::body::{Body, BoxBody};
358    ///     use parking_lot::Mutex;
359    ///     use http::{Request, Response};
360    ///     use tower::util::BoxCloneService;
361    ///
362    ///     #[pyclass]
363    ///     #[derive(Debug, Clone)]
364    ///     pub struct App {};
365    ///
366    ///     impl PyApp for App {
367    ///         fn workers(&self) -> &Mutex<Vec<PyObject>> { todo!() }
368    ///         fn context(&self) -> &Option<PyObject> { todo!() }
369    ///         fn handlers(&mut self) -> &mut HashMap<String, PyHandler> { todo!() }
370    ///         fn build_service(&mut self, event_loop: &PyAny) -> PyResult<BoxCloneService<Request<Body>, Response<BoxBody>, Infallible>> { todo!() }
371    ///     }
372    ///
373    ///     #[pymethods]
374    ///     impl App {
375    ///     #[pyo3(text_signature = "($self, socket, worker_number, tls)")]
376    ///         pub fn start_worker(
377    ///             &mut self,
378    ///             py: pyo3::Python,
379    ///             socket: &pyo3::PyCell<aws_smithy_http_server_python::PySocket>,
380    ///             worker_number: isize,
381    ///             tls: Option<aws_smithy_http_server_python::tls::PyTlsConfig>,
382    ///         ) -> pyo3::PyResult<()> {
383    ///             let event_loop = self.configure_python_event_loop(py)?;
384    ///             let service = self.build_service(event_loop)?;
385    ///             self.start_hyper_worker(py, socket, event_loop, service, worker_number, tls)
386    ///         }
387    ///     }
388    /// ```
389    ///
390    /// [multiprocessing::Process]: https://docs.python.org/3/library/multiprocessing.html
391    fn run_server(
392        &mut self,
393        py: Python,
394        address: Option<String>,
395        port: Option<i32>,
396        backlog: Option<i32>,
397        workers: Option<usize>,
398        tls: Option<PyTlsConfig>,
399    ) -> PyResult<()> {
400        // Setup multiprocessing environment, allowing connections and socket
401        // sharing between processes.
402        let mp = py.import("multiprocessing")?;
403        // https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/multiprocessing/context.py#L164
404        mp.call_method0("allow_connection_pickling")?;
405
406        // Starting from Python 3.8, on macOS, the spawn start method is now the default. See bpo-33725.
407        // This forces the `PyApp` class to be pickled when it is shared between different process,
408        // which is currently not supported by PyO3 classes.
409        //
410        // Forcing the multiprocessing start method to fork is a workaround for it.
411        // https://github.com/pytest-dev/pytest-flask/issues/104#issuecomment-577908228
412        #[cfg(target_os = "macos")]
413        mp.call_method(
414            "set_start_method",
415            ("fork",),
416            // We need to pass `force=True` to prevent `context has already been set` exception,
417            // see https://github.com/pytorch/pytorch/issues/3492
418            Some(vec![("force", true)].into_py_dict(py)),
419        )?;
420
421        let address = address.unwrap_or_else(|| String::from("127.0.0.1"));
422        let port = port.unwrap_or(13734);
423        let socket = PySocket::new(address, port, backlog)?;
424        // Lock the workers mutex.
425        let mut active_workers = self.workers().lock();
426        // Register the main signal handler.
427        // TODO(move from num_cpus to thread::available_parallelism after MSRV is 1.60)
428        // Start all the workers as new Python processes and store the in the `workers` attribute.
429        for idx in 1..workers.unwrap_or_else(num_cpus::get) + 1 {
430            let sock = socket.try_clone()?;
431            let tls = tls.clone();
432            let process = mp.getattr("Process")?;
433            let handle = process.call1((
434                py.None(),
435                self.clone().into_py(py).getattr(py, "start_worker")?,
436                format!("smithy-rs-worker[{idx}]"),
437                (sock.into_py(py), idx, tls.into_py(py)),
438            ))?;
439            handle.call_method0("start")?;
440            active_workers.push(handle.to_object(py));
441        }
442        // Unlock the workers mutex.
443        drop(active_workers);
444        tracing::trace!("rust python server started successfully");
445        self.block_on_rust_signals();
446        Ok(())
447    }
448
449    /// Lambda main entrypoint: start the handler on Lambda.
450    fn run_lambda_handler(&mut self, py: Python) -> PyResult<()> {
451        use aws_smithy_http_server::routing::LambdaHandler;
452
453        let event_loop = self.configure_python_event_loop(py)?;
454        // Register signals on the Python event loop.
455        self.register_python_signals(py, event_loop.to_object(py))?;
456
457        let service = self.build_and_configure_service(py, event_loop)?;
458
459        // Spawn a new background [std::thread] to run the application.
460        // This is needed because `asyncio` doesn't work properly if it doesn't control the main thread.
461        // At the end of this function you can see we are calling `event_loop.run_forever()` to
462        // yield execution of main thread to `asyncio` runtime.
463        // For more details: https://docs.rs/pyo3-asyncio/latest/pyo3_asyncio/#pythons-event-loop-and-the-main-thread
464        tracing::trace!("start the tokio runtime in a background task");
465        thread::spawn(move || {
466            let rt = runtime::Builder::new_multi_thread()
467                .enable_all()
468                .build()
469                .expect("unable to start a new tokio runtime for this process");
470            rt.block_on(async move {
471                let handler = LambdaHandler::new(service);
472                let lambda = lambda_http::run(handler);
473                tracing::debug!("starting lambda handler");
474                if let Err(err) = lambda.await {
475                    tracing::error!(error = %err, "unable to start lambda handler");
476                }
477            });
478        });
479        // Block on the event loop forever.
480        tracing::trace!("run and block on the python event loop until a signal is received");
481        event_loop.call_method0("run_forever")?;
482        Ok(())
483    }
484
485    // Builds the `Service` and adds necessary layers to it.
486    fn build_and_configure_service(
487        &mut self,
488        py: Python,
489        event_loop: &pyo3::PyAny,
490    ) -> pyo3::PyResult<Service> {
491        let service = self.build_service(event_loop)?;
492        let context = PyContext::new(self.context().clone().unwrap_or_else(|| py.None()))?;
493        let service = ServiceBuilder::new()
494            .boxed_clone()
495            .layer(AddPyContextLayer::new(context))
496            .service(service);
497        Ok(service)
498    }
499}
500
501fn addr_incoming_from_socket(socket: Socket) -> AddrIncoming {
502    let std_listener: StdTcpListener = socket.into();
503    // StdTcpListener::from_std doesn't set O_NONBLOCK
504    std_listener
505        .set_nonblocking(true)
506        .expect("unable to set `O_NONBLOCK=true` on `std::net::TcpListener`");
507    let listener = TcpListener::from_std(std_listener)
508        .expect("unable to create `tokio::net::TcpListener` from `std::net::TcpListener`");
509    AddrIncoming::from_listener(listener)
510        .expect("unable to create `AddrIncoming` from `TcpListener`")
511}
512
513// Builds `TlsAcceptor` from given `config` and also creates a background task
514// to reload certificates and returns a channel to receive new `TlsAcceptor`s.
515fn tls_config_reloader(config: PyTlsConfig) -> (TlsAcceptor, mpsc::Receiver<TlsAcceptor>) {
516    let reload_dur = config.reload_duration();
517    let (tx, rx) = mpsc::channel();
518    let acceptor = TlsAcceptor::from(Arc::new(config.build().expect("invalid tls config")));
519
520    tokio::spawn(async move {
521        tracing::trace!(dur = ?reload_dur, "starting timer to reload tls config");
522        loop {
523            tokio::time::sleep(reload_dur).await;
524            tracing::trace!("reloading tls config");
525            match config.build() {
526                Ok(config) => {
527                    let new_config = TlsAcceptor::from(Arc::new(config));
528                    // Note on expect: `tx.send` can only fail if the receiver is dropped,
529                    // it probably a bug if that happens
530                    tx.send(new_config).expect("could not send new tls config")
531                }
532                Err(err) => {
533                    tracing::error!(error = ?err, "could not reload tls config because it is invalid");
534                }
535            }
536        }
537    });
538
539    (acceptor, rx)
540}