use std::{path::PathBuf, str::FromStr};
use pyo3::prelude::*;
#[cfg(not(test))]
use tracing::span;
use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{
fmt::{self, writer::MakeWriterExt},
layer::SubscriberExt,
util::SubscriberInitExt,
Layer,
};
use crate::error::PyException;
/// Output format used by the tracing formatter layer.
///
/// `Compact` is the default when no format (or an unknown one) is requested.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum Format {
    Json,
    Pretty,
    #[default]
    Compact,
}

/// Error returned by [`Format::from_str`] when the name is not a known format.
#[derive(Debug, PartialEq, Eq)]
struct InvalidFormatError;

impl std::fmt::Display for InvalidFormatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid format, expected one of 'json', 'pretty' or 'compact'")
    }
}

impl std::error::Error for InvalidFormatError {}

impl FromStr for Format {
    type Err = InvalidFormatError;

    /// Parses a format name, ignoring ASCII case.
    ///
    /// Accepts `"pretty"`, `"json"` and `"compact"`; anything else yields
    /// [`InvalidFormatError`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("pretty") {
            Ok(Self::Pretty)
        } else if s.eq_ignore_ascii_case("json") {
            Ok(Self::Json)
        } else if s.eq_ignore_ascii_case("compact") {
            Ok(Self::Compact)
        } else {
            Err(InvalidFormatError)
        }
    }
}
fn setup_tracing_subscriber(
level: Option<u8>,
logfile: Option<PathBuf>,
format: Option<String>,
) -> PyResult<Option<WorkerGuard>> {
let format = match format {
Some(format) => Format::from_str(&format).unwrap_or_else(|_| {
tracing::error!("unknown format '{format}', falling back to default formatter");
Format::default()
}),
None => Format::default(),
};
let appender = match logfile {
Some(logfile) => {
let parent = logfile.parent().ok_or_else(|| {
PyException::new_err(format!(
"Tracing setup failed: unable to extract dirname from path {}",
logfile.display()
))
})?;
let filename = logfile.file_name().ok_or_else(|| {
PyException::new_err(format!(
"Tracing setup failed: unable to extract basename from path {}",
logfile.display()
))
})?;
let file_appender = tracing_appender::rolling::hourly(parent, filename);
let (appender, guard) = tracing_appender::non_blocking(file_appender);
Some((appender, guard))
}
None => None,
};
let tracing_level = match level {
Some(40u8) => Level::ERROR,
Some(30u8) => Level::WARN,
Some(20u8) => Level::INFO,
Some(10u8) => Level::DEBUG,
None => Level::INFO,
_ => Level::TRACE,
};
let formatter = fmt::Layer::new().with_line_number(true).with_level(true);
match appender {
Some((appender, guard)) => {
let formatter = formatter.with_writer(appender.with_max_level(tracing_level));
let formatter = match format {
Format::Json => formatter.json().boxed(),
Format::Compact => formatter.compact().boxed(),
Format::Pretty => formatter.pretty().boxed(),
};
tracing_subscriber::registry().with(formatter).init();
Ok(Some(guard))
}
None => {
let formatter = formatter.with_writer(std::io::stdout.with_max_level(tracing_level));
let formatter = match format {
Format::Json => formatter.json().boxed(),
Format::Compact => formatter.compact().boxed(),
Format::Pretty => formatter.pretty().boxed(),
};
tracing_subscriber::registry().with(formatter).init();
Ok(None)
}
}
}
/// Handle exposed to Python as `TracingHandler`.
///
/// Constructing it sets up the Rust tracing subscriber; the instance then only
/// needs to be kept around so that the stored guard is not dropped.
#[pyclass(name = "TracingHandler")]
#[pyo3(text_signature = "($self, level=None, logfile=None, format=None)")]
#[derive(Debug)]
pub struct PyTracingHandler {
    // Guard returned by `setup_tracing_subscriber`; `None` when no logfile was
    // requested. Never read, only held for the lifetime of the handler.
    _guard: Option<WorkerGuard>,
}
#[pymethods]
impl PyTracingHandler {
    /// Sets up the tracing subscriber and aligns the Python root logger level.
    ///
    /// `level`, `logfile` and `format` are forwarded to
    /// `setup_tracing_subscriber`. Errors from the subscriber setup or from the
    /// Python `logging` module are returned to the caller.
    #[new]
    fn newpy(
        py: Python,
        level: Option<u8>,
        logfile: Option<PathBuf>,
        format: Option<String>,
    ) -> PyResult<Self> {
        let _guard = setup_tracing_subscriber(level, logfile, format)?;
        let logging = py.import("logging")?;
        let root = logging.getattr("root")?;
        // Go through `Logger.setLevel` rather than assigning the attribute, and
        // never store `None` in the integer `level` attribute: NOTSET (0) lets
        // every record reach the handler, which is what `None` amounted to.
        root.call_method1("setLevel", (level.unwrap_or(0),))?;
        Ok(Self { _guard })
    }

    /// Creates a Python `logging.Handler` instance that forwards records to
    /// `py_tracing_event`.
    ///
    /// The `py_tracing_event` function and a `TracingHandler` class are
    /// injected into the `logging` module, and the class name is added to
    /// `logging.__all__` once.
    fn handler(&self, py: Python) -> PyResult<Py<PyAny>> {
        let logging = py.import("logging")?;
        logging.setattr(
            "py_tracing_event",
            wrap_pyfunction!(py_tracing_event, logging)?,
        )?;

        // Evaluated with the `logging` module dict as globals, so `Handler`
        // and `py_tracing_event` resolve to the module's own attributes.
        let pycode = r#"
class TracingHandler(Handler):
    """ Python logging to Rust tracing handler. """
    def emit(self, record):
        py_tracing_event(
            record.levelno, record.getMessage(), record.module,
            record.filename, record.lineno, record.process
        )
"#;
        py.run(pycode, Some(logging.dict()), None)?;

        // Avoid piling up duplicate entries when `handler` is called repeatedly.
        let all = logging.index()?;
        let already_exported = all.iter().any(|entry| {
            entry
                .extract::<String>()
                .map_or(false, |name| name == "TracingHandler")
        });
        if !already_exported {
            all.append("TracingHandler")?;
        }

        let handler = logging.getattr("TracingHandler")?;
        Ok(handler.call0()?.into_py(py))
    }
}
#[cfg(not(test))]
#[pyfunction]
#[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")]
pub fn py_tracing_event(
level: u8,
message: &str,
module: &str,
filename: &str,
lineno: usize,
pid: usize,
) -> PyResult<()> {
let span = span!(
Level::TRACE,
"python",
pid = pid,
module = module,
filename = filename,
lineno = lineno
);
let _guard = span.enter();
match level {
40 => tracing::error!("{message}"),
30 => tracing::warn!("{message}"),
20 => tracing::info!("{message}"),
10 => tracing::debug!("{message}"),
_ => tracing::trace!("{message}"),
};
Ok(())
}
#[cfg(test)]
/// Test double for `py_tracing_event`.
///
/// Instead of emitting a tracing event it asserts that the forwarded message
/// is exactly `"a message"`; every other argument is ignored.
#[pyfunction]
#[pyo3(text_signature = "(level, message, module, filename, lineno, pid)")]
pub fn py_tracing_event(
    _level: u8,
    message: &str,
    _module: &str,
    _filename: &str,
    _line: usize,
    _pid: usize,
) -> PyResult<()> {
    pretty_assertions::assert_eq!(message, "a message");
    Ok(())
}
#[cfg(test)]
mod tests {
    use pyo3::types::PyDict;

    use super::*;

    /// Registers the handler through `logging.basicConfig` and logs one record;
    /// the test build of `py_tracing_event` checks the message it receives.
    #[test]
    fn tracing_handler_is_injected_in_python() {
        crate::tests::initialize();
        Python::with_gil(|py| {
            let tracing_handler = PyTracingHandler::newpy(py, Some(10), None, None).unwrap();
            let config = PyDict::new(py);
            let py_handler = tracing_handler.handler(py).unwrap();
            config.set_item("handlers", vec![py_handler]).unwrap();

            let logging = py.import("logging").unwrap();
            logging
                .getattr("basicConfig")
                .unwrap()
                .call((), Some(config))
                .unwrap();
            logging.call_method1("info", ("a message",)).unwrap();
        });
    }
}