aws_smithy_http_server_python/
tls.rs1use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::PathBuf;
14use std::time::Duration;
15
16use pyo3::{pyclass, pymethods};
17use thiserror::Error;
18use tokio_rustls::rustls::{Certificate, Error as RustTlsError, PrivateKey, ServerConfig};
19
20pub mod listener;
21
22#[pyclass(name = "TlsConfig")]
29#[derive(Clone)]
30pub struct PyTlsConfig {
31 key_path: PathBuf,
35
36 cert_path: PathBuf,
40
41 reload_secs: u64,
45}
46
47impl PyTlsConfig {
48 pub fn build(&self) -> Result<ServerConfig, PyTlsConfigError> {
50 let cert_chain = self.cert_chain()?;
51 let key_der = self.key_der()?;
52 let mut config = ServerConfig::builder()
53 .with_safe_defaults()
54 .with_no_client_auth()
55 .with_single_cert(cert_chain, key_der)?;
56 config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
57 Ok(config)
58 }
59
60 pub fn reload_duration(&self) -> Duration {
62 Duration::from_secs(self.reload_secs)
63 }
64
65 fn cert_chain(&self) -> Result<Vec<Certificate>, PyTlsConfigError> {
67 let file = File::open(&self.cert_path).map_err(PyTlsConfigError::CertParse)?;
68 let mut cert_rdr = BufReader::new(file);
69 Ok(rustls_pemfile::certs(&mut cert_rdr)
70 .map_err(PyTlsConfigError::CertParse)?
71 .into_iter()
72 .map(Certificate)
73 .collect())
74 }
75
76 fn key_der(&self) -> Result<PrivateKey, PyTlsConfigError> {
78 let mut key_vec = Vec::new();
79 File::open(&self.key_path)
80 .and_then(|mut f| f.read_to_end(&mut key_vec))
81 .map_err(PyTlsConfigError::KeyParse)?;
82 if key_vec.is_empty() {
83 return Err(PyTlsConfigError::EmptyKey);
84 }
85
86 let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice())
87 .map_err(PyTlsConfigError::Pkcs8Parse)?;
88 if !pkcs8.is_empty() {
89 return Ok(PrivateKey(pkcs8.remove(0)));
90 }
91
92 let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice())
93 .map_err(PyTlsConfigError::RsaParse)?;
94 if !rsa.is_empty() {
95 return Ok(PrivateKey(rsa.remove(0)));
96 }
97
98 Err(PyTlsConfigError::EmptyKey)
99 }
100}
101
102#[pymethods]
103impl PyTlsConfig {
104 #[new]
105 #[pyo3(text_signature = "($self, *, key_path, cert_path, reload_secs=86400)")]
106 #[pyo3(signature = (key_path, cert_path, reload_secs=86400))]
107 fn py_new(key_path: PathBuf, cert_path: PathBuf, reload_secs: u64) -> Self {
108 Self {
110 key_path,
111 cert_path,
112 reload_secs,
113 }
114 }
115}
116
117#[derive(Error, Debug)]
119pub enum PyTlsConfigError {
120 #[error("could not parse certificate")]
121 CertParse(io::Error),
122 #[error("could not parse key")]
123 KeyParse(io::Error),
124 #[error("empty key")]
125 EmptyKey,
126 #[error("could not parse pkcs8 keys")]
127 Pkcs8Parse(io::Error),
128 #[error("could not parse rsa keys")]
129 RsaParse(io::Error),
130 #[error("rusttls protocol error")]
131 RustTlsError(#[from] RustTlsError),
132}
133
#[cfg(test)]
mod tests {
    use pyo3::{
        prelude::*,
        types::{IntoPyDict, PyDict},
    };

    use super::*;

    /// Private key fixture, resolved relative to this crate's manifest directory.
    const TEST_KEY: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.key"
    );
    /// Certificate fixture, resolved relative to this crate's manifest directory.
    const TEST_CERT: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.crt"
    );

    /// Constructs a `TlsConfig` from Python code, extracts it back into Rust,
    /// checks that every field round-trips and that a server config can be
    /// built from the fixture files.
    #[test]
    fn creating_tls_config_in_python() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let tls_config = Python::with_gil(|py| -> PyResult<PyTlsConfig> {
            // Names made available to the Python snippet below.
            let globals = [
                ("TEST_CERT", TEST_CERT.to_object(py)),
                ("TEST_KEY", TEST_KEY.to_object(py)),
                ("TlsConfig", py.get_type::<PyTlsConfig>().to_object(py)),
            ]
            .into_py_dict(py);
            let locals = PyDict::new(py);

            py.run(
                r#"
config = TlsConfig(key_path=TEST_KEY, cert_path=TEST_CERT, reload_secs=1000)
"#,
                Some(globals),
                Some(locals),
            )?;

            // The snippet binds `config` in `locals`; pull it back out as a Rust value.
            let py_config = locals
                .get_item("config")
                .expect("Python exception occurred during dictionary lookup")
                .unwrap();
            py_config.extract::<PyTlsConfig>()
        })?;

        assert_eq!(PathBuf::from(TEST_KEY), tls_config.key_path);
        assert_eq!(PathBuf::from(TEST_CERT), tls_config.cert_path);
        assert_eq!(1000, tls_config.reload_secs);

        // The fixtures must yield a usable rustls server configuration.
        tls_config.build().unwrap();

        Ok(())
    }
}