aws_smithy_http_server_python/
tls.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
//! TLS-related types for Python.
//!
//! The [PyTlsConfig] implementation is mostly borrowed from:
9//! <https://github.com/seanmonstar/warp/blob/4e9c4fd6ce238197fd1088061bbc07fa2852cb0f/src/tls.rs>
10
11use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::PathBuf;
14use std::time::Duration;
15
16use pyo3::{pyclass, pymethods};
17use thiserror::Error;
18use tokio_rustls::rustls::{Certificate, Error as RustTlsError, PrivateKey, ServerConfig};
19
20pub mod listener;
21
22/// PyTlsConfig represents TLS configuration created from Python.
23///
24/// :param key_path pathlib.Path:
25/// :param cert_path pathlib.Path:
26/// :param reload_secs int:
27/// :rtype None:
28#[pyclass(name = "TlsConfig")]
29#[derive(Clone)]
30pub struct PyTlsConfig {
31    /// Absolute path of the RSA or PKCS private key.
32    ///
33    /// :type pathlib.Path:
34    key_path: PathBuf,
35
36    /// Absolute path of the x509 certificate.
37    ///
38    /// :type pathlib.Path:
39    cert_path: PathBuf,
40
41    /// Duration to reloading certificates.
42    ///
43    /// :type int:
44    reload_secs: u64,
45}
46
47impl PyTlsConfig {
48    /// Build [ServerConfig] from [PyTlsConfig].
49    pub fn build(&self) -> Result<ServerConfig, PyTlsConfigError> {
50        let cert_chain = self.cert_chain()?;
51        let key_der = self.key_der()?;
52        let mut config = ServerConfig::builder()
53            .with_safe_defaults()
54            .with_no_client_auth()
55            .with_single_cert(cert_chain, key_der)?;
56        config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
57        Ok(config)
58    }
59
60    /// Returns reload duration.
61    pub fn reload_duration(&self) -> Duration {
62        Duration::from_secs(self.reload_secs)
63    }
64
65    /// Reads certificates from `cert_path`.
66    fn cert_chain(&self) -> Result<Vec<Certificate>, PyTlsConfigError> {
67        let file = File::open(&self.cert_path).map_err(PyTlsConfigError::CertParse)?;
68        let mut cert_rdr = BufReader::new(file);
69        Ok(rustls_pemfile::certs(&mut cert_rdr)
70            .map_err(PyTlsConfigError::CertParse)?
71            .into_iter()
72            .map(Certificate)
73            .collect())
74    }
75
76    /// Parses RSA or PKCS private key from `key_path`.
77    fn key_der(&self) -> Result<PrivateKey, PyTlsConfigError> {
78        let mut key_vec = Vec::new();
79        File::open(&self.key_path)
80            .and_then(|mut f| f.read_to_end(&mut key_vec))
81            .map_err(PyTlsConfigError::KeyParse)?;
82        if key_vec.is_empty() {
83            return Err(PyTlsConfigError::EmptyKey);
84        }
85
86        let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice())
87            .map_err(PyTlsConfigError::Pkcs8Parse)?;
88        if !pkcs8.is_empty() {
89            return Ok(PrivateKey(pkcs8.remove(0)));
90        }
91
92        let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice())
93            .map_err(PyTlsConfigError::RsaParse)?;
94        if !rsa.is_empty() {
95            return Ok(PrivateKey(rsa.remove(0)));
96        }
97
98        Err(PyTlsConfigError::EmptyKey)
99    }
100}
101
102#[pymethods]
103impl PyTlsConfig {
104    #[new]
105    #[pyo3(text_signature = "($self, *, key_path, cert_path, reload_secs=86400)")]
106    #[pyo3(signature = (key_path, cert_path, reload_secs=86400))]
107    fn py_new(key_path: PathBuf, cert_path: PathBuf, reload_secs: u64) -> Self {
108        // TODO(BugOnUpstream): `reload: &PyDelta` segfaults, create an issue on PyO3
109        Self {
110            key_path,
111            cert_path,
112            reload_secs,
113        }
114    }
115}
116
117/// Possible TLS configuration errors.
118#[derive(Error, Debug)]
119pub enum PyTlsConfigError {
120    #[error("could not parse certificate")]
121    CertParse(io::Error),
122    #[error("could not parse key")]
123    KeyParse(io::Error),
124    #[error("empty key")]
125    EmptyKey,
126    #[error("could not parse pkcs8 keys")]
127    Pkcs8Parse(io::Error),
128    #[error("could not parse rsa keys")]
129    RsaParse(io::Error),
130    #[error("rusttls protocol error")]
131    RustTlsError(#[from] RustTlsError),
132}
133
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use pyo3::{
        prelude::*,
        types::{IntoPyDict, PyDict},
    };

    use super::*;

    /// Private key fixture shared with the Python Pokémon service example tests.
    const TEST_KEY: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.key"
    );
    /// Certificate fixture matching [TEST_KEY].
    const TEST_CERT: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.crt"
    );

    /// Constructs a `TlsConfig` from Python code, checks the stored fields, and makes
    /// sure a rustls server config can be built from the fixtures.
    #[test]
    fn creating_tls_config_in_python() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let tls_config = Python::with_gil(|py| {
            // Names visible to the Python snippet: the fixture paths and the class.
            let snippet_globals = [
                ("TEST_CERT", TEST_CERT.to_object(py)),
                ("TEST_KEY", TEST_KEY.to_object(py)),
                ("TlsConfig", py.get_type::<PyTlsConfig>().to_object(py)),
            ]
            .into_py_dict(py);
            let snippet_locals = PyDict::new(py);

            py.run(
                r#"
config = TlsConfig(key_path=TEST_KEY, cert_path=TEST_CERT, reload_secs=1000)
"#,
                Some(snippet_globals),
                Some(snippet_locals),
            )?;

            let created = snippet_locals
                .get_item("config")
                .expect("Python exception occurred during dictionary lookup")
                .unwrap();
            created.extract::<PyTlsConfig>()
        })?;

        let expected_key = PathBuf::from_str(TEST_KEY).unwrap();
        let expected_cert = PathBuf::from_str(TEST_CERT).unwrap();
        assert_eq!(tls_config.key_path, expected_key);
        assert_eq!(tls_config.cert_path, expected_cert);
        assert_eq!(tls_config.reload_secs, 1000);

        // Building must succeed with the bundled fixtures.
        let _server_config = tls_config.build().unwrap();

        Ok(())
    }
}