1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Provides utilities for Python errors.

use std::fmt;

use pyo3::{PyErr, Python};

/// Wraps [PyErr] with a richer debug output that includes traceback and cause.
///
/// The wrapped error is private; construct values with [rich_py_err]. The [fmt::Debug]
/// output is a `RichPyErr { .. }` struct with `type` and `value` fields. It also has a
/// `traceback` field when a traceback is attached and can be formatted, and a `cause`
/// field (itself rendered as a `RichPyErr`) when [PyErr::cause] returns one.
pub struct RichPyErr(PyErr);

impl fmt::Debug for RichPyErr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        Python::with_gil(|py| {
            let mut debug_struct = f.debug_struct("RichPyErr");
            debug_struct
                .field("type", self.0.get_type(py))
                .field("value", self.0.value(py));

            if let Some(traceback) = self.0.traceback(py) {
                if let Ok(traceback) = traceback.format() {
                    debug_struct.field("traceback", &traceback);
                }
            }

            if let Some(cause) = self.0.cause(py) {
                debug_struct.field("cause", &rich_py_err(cause));
            }

            debug_struct.finish()
        })
    }
}

/// Wrap `err` with [RichPyErr] to have a richer debug output.
///
/// The wrapper exists for its `{:?}` output, so discarding the result is almost certainly
/// a mistake (hence `#[must_use]`).
#[must_use]
pub fn rich_py_err(err: PyErr) -> RichPyErr {
    RichPyErr(err)
}

/// Lets a `PyResult` be adapted with `.map_err(RichPyErr::from)`, or with `?` in functions
/// returning `Result<_, RichPyErr>`.
impl From<PyErr> for RichPyErr {
    fn from(err: PyErr) -> Self {
        rich_py_err(err)
    }
}

#[cfg(test)]
mod tests {
    use pyo3::prelude::*;

    use super::*;

    /// Raises an explicitly chained `ValueError` from three nested Python frames.
    const CHAINED_ERROR_SCRIPT: &str = r#"
def foo():
    base_err = ValueError("base error")
    raise ValueError("some python error") from base_err

def bar():
    foo()

def baz():
    bar()

baz()
"#;

    #[test]
    fn rich_python_errors() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let err = Python::with_gil(|py| py.run(CHAINED_ERROR_SCRIPT, None, None).unwrap_err());
        let rendered = format!("{:?}", rich_py_err(err));

        // In order: the error message, each traceback frame, and the chained cause.
        let expected = ["some python error", "foo", "bar", "baz", "base error"];
        for &needle in &expected {
            assert!(rendered.contains(needle));
        }

        Ok(())
    }
}