aws_smithy_http_server_python/
tls.rs1use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::PathBuf;
14use std::time::Duration;
15
16use pyo3::{pyclass, pymethods};
17use thiserror::Error;
18use tokio_rustls::rustls::{Certificate, Error as RustTlsError, PrivateKey, ServerConfig};
19
20pub mod listener;
21
22#[pyclass(
29 name = "TlsConfig",
30 text_signature = "($self, *, key_path, cert_path, reload_secs=86400)"
31)]
32#[derive(Clone)]
33pub struct PyTlsConfig {
34 key_path: PathBuf,
38
39 cert_path: PathBuf,
43
44 reload_secs: u64,
48}
49
50impl PyTlsConfig {
51 pub fn build(&self) -> Result<ServerConfig, PyTlsConfigError> {
53 let cert_chain = self.cert_chain()?;
54 let key_der = self.key_der()?;
55 let mut config = ServerConfig::builder()
56 .with_safe_defaults()
57 .with_no_client_auth()
58 .with_single_cert(cert_chain, key_der)?;
59 config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
60 Ok(config)
61 }
62
63 pub fn reload_duration(&self) -> Duration {
65 Duration::from_secs(self.reload_secs)
66 }
67
68 fn cert_chain(&self) -> Result<Vec<Certificate>, PyTlsConfigError> {
70 let file = File::open(&self.cert_path).map_err(PyTlsConfigError::CertParse)?;
71 let mut cert_rdr = BufReader::new(file);
72 Ok(rustls_pemfile::certs(&mut cert_rdr)
73 .map_err(PyTlsConfigError::CertParse)?
74 .into_iter()
75 .map(Certificate)
76 .collect())
77 }
78
79 fn key_der(&self) -> Result<PrivateKey, PyTlsConfigError> {
81 let mut key_vec = Vec::new();
82 File::open(&self.key_path)
83 .and_then(|mut f| f.read_to_end(&mut key_vec))
84 .map_err(PyTlsConfigError::KeyParse)?;
85 if key_vec.is_empty() {
86 return Err(PyTlsConfigError::EmptyKey);
87 }
88
89 let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice())
90 .map_err(PyTlsConfigError::Pkcs8Parse)?;
91 if !pkcs8.is_empty() {
92 return Ok(PrivateKey(pkcs8.remove(0)));
93 }
94
95 let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice())
96 .map_err(PyTlsConfigError::RsaParse)?;
97 if !rsa.is_empty() {
98 return Ok(PrivateKey(rsa.remove(0)));
99 }
100
101 Err(PyTlsConfigError::EmptyKey)
102 }
103}
104
105#[pymethods]
106impl PyTlsConfig {
107 #[new]
108 #[pyo3(signature = (key_path, cert_path, reload_secs=86400))]
109 fn py_new(key_path: PathBuf, cert_path: PathBuf, reload_secs: u64) -> Self {
110 Self {
112 key_path,
113 cert_path,
114 reload_secs,
115 }
116 }
117}
118
119#[derive(Error, Debug)]
121pub enum PyTlsConfigError {
122 #[error("could not parse certificate")]
123 CertParse(io::Error),
124 #[error("could not parse key")]
125 KeyParse(io::Error),
126 #[error("empty key")]
127 EmptyKey,
128 #[error("could not parse pkcs8 keys")]
129 Pkcs8Parse(io::Error),
130 #[error("could not parse rsa keys")]
131 RsaParse(io::Error),
132 #[error("rusttls protocol error")]
133 RustTlsError(#[from] RustTlsError),
134}
135
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use pyo3::{
        prelude::*,
        types::{IntoPyDict, PyDict},
    };

    use super::*;

    const TEST_KEY: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.key"
    );
    const TEST_CERT: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../examples/python/pokemon-service-test/tests/testdata/localhost.crt"
    );
    /// A path that is expected not to exist, used to exercise the I/O error paths.
    const MISSING_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/this-file-does-not-exist.pem"
    );

    /// Run `code` with `TEST_KEY`, `TEST_CERT` and `TlsConfig` in scope, then extract the
    /// `config` variable it defines as a [`PyTlsConfig`].
    fn tls_config_from_python(code: &str) -> PyResult<PyTlsConfig> {
        pyo3::prepare_freethreaded_python();

        Python::with_gil(|py| {
            let globals = [
                ("TEST_CERT", TEST_CERT.to_object(py)),
                ("TEST_KEY", TEST_KEY.to_object(py)),
                ("TlsConfig", py.get_type::<PyTlsConfig>().to_object(py)),
            ]
            .into_py_dict(py);
            let locals = PyDict::new(py);
            py.run(code, Some(globals), Some(locals))?;
            locals
                .get_item("config")
                .expect("the Python snippet must define `config`")
                .extract::<PyTlsConfig>()
        })
    }

    #[test]
    fn creating_tls_config_in_python() -> PyResult<()> {
        let config = tls_config_from_python(
            "config = TlsConfig(key_path=TEST_KEY, cert_path=TEST_CERT, reload_secs=1000)",
        )?;

        assert_eq!(PathBuf::from_str(TEST_KEY).unwrap(), config.key_path);
        assert_eq!(PathBuf::from_str(TEST_CERT).unwrap(), config.cert_path);
        assert_eq!(1000, config.reload_secs);
        assert_eq!(Duration::from_secs(1000), config.reload_duration());

        config.build().unwrap();

        Ok(())
    }

    #[test]
    fn reload_secs_defaults_to_one_day() -> PyResult<()> {
        let config =
            tls_config_from_python("config = TlsConfig(key_path=TEST_KEY, cert_path=TEST_CERT)")?;

        assert_eq!(86400, config.reload_secs);
        assert_eq!(Duration::from_secs(86400), config.reload_duration());

        Ok(())
    }

    #[test]
    fn missing_files_are_reported() {
        let no_cert = PyTlsConfig {
            key_path: PathBuf::from(TEST_KEY),
            cert_path: PathBuf::from(MISSING_FILE),
            reload_secs: 1,
        };
        assert!(matches!(
            no_cert.build(),
            Err(PyTlsConfigError::CertParse(_))
        ));

        let no_key = PyTlsConfig {
            key_path: PathBuf::from(MISSING_FILE),
            cert_path: PathBuf::from(TEST_CERT),
            reload_secs: 1,
        };
        assert!(matches!(no_key.build(), Err(PyTlsConfigError::KeyParse(_))));
    }
}