1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use std::env;
use std::io::Write;
use std::sync::{Arc, Mutex};
use tracing::subscriber::DefaultGuard;
use tracing::Level;
use tracing_subscriber::fmt::TestWriter;

/// A guard that resets log capturing upon being dropped.
///
/// It wraps the [`DefaultGuard`] returned by `tracing::subscriber::set_default`; dropping this
/// guard drops that inner guard, which ends the scope in which the capturing subscriber is the
/// default.
#[derive(Debug)]
// The inner guard is never read; it is held only so that its `Drop` runs when this guard drops,
// hence the `dead_code` allowance.
pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard);

/// Enables output of test logs to stdout at trace level by default.
///
/// The env filter can be changed with the `RUST_LOG` environment variable.
///
/// Logs are emitted until the returned [`LogCaptureGuard`] is dropped. Output is written through
/// [`TestWriter`], which cooperates with libtest's output capturing.
#[must_use]
pub fn show_test_logs() -> LogCaptureGuard {
    let env_var = env::var("RUST_LOG").ok();
    let env_filter = env_var.as_deref().unwrap_or("trace");
    eprintln!(
        "Enabled verbose test logging with env filter {env_filter:?}. \
        You can change the env filter with the RUST_LOG environment variable."
    );

    // Write straight to the test writer. Nothing reads captured logs in this mode, so routing
    // output through an in-memory capture buffer would only accumulate every log line in memory
    // for as long as the guard is alive.
    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_test_writer()
        .finish();
    let guard = tracing::subscriber::set_default(subscriber);
    LogCaptureGuard(guard)
}

/// Capture logs from this test.
///
/// Every event at trace level and above is recorded into an in-memory buffer, which can be read
/// through the returned [`Rx`]. The logs will be captured until the returned
/// [`LogCaptureGuard`] is dropped.
///
/// If the `VERBOSE_TEST_LOGS` environment variable is set (to any value), the captured logs are
/// additionally echoed through libtest's [`TestWriter`].
///
/// *Why use this instead of traced_test?*
/// This captures _all_ logs, not just logs produced by the current crate.
#[must_use] // log capturing ceases the instant the returned `LogCaptureGuard` is dropped
pub fn capture_test_logs() -> (LogCaptureGuard, Rx) {
    // it may be helpful to upstream this at some point
    let (mut writer, rx) = Tee::stdout();
    // `var_os` so that a set-but-non-UTF-8 value still counts as "set"; `env::var` would report
    // it as an error and silently keep verbose logging off.
    if env::var_os("VERBOSE_TEST_LOGS").is_some() {
        eprintln!("Enabled verbose test logging.");
        writer.loud();
    } else {
        eprintln!("To see full logs from this test set VERBOSE_TEST_LOGS=true");
    }
    let subscriber = tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .with_writer(Mutex::new(writer))
        .finish();
    let guard = tracing::subscriber::set_default(subscriber);
    (LogCaptureGuard(guard), rx)
}

/// Receiver for the captured logs.
///
/// Holds a shared handle to the capture buffer, so [`Rx::contents`] returns everything written
/// to that buffer so far.
#[derive(Debug)]
pub struct Rx(Arc<Mutex<Vec<u8>>>);

impl Rx {
    /// Returns the captured logs as a string.
    ///
    /// # Panics
    /// This will panic if the logs are not valid UTF-8, or if the capture buffer's lock was
    /// poisoned by a writer that panicked while holding it.
    pub fn contents(&self) -> String {
        let bytes = self
            .0
            .lock()
            .expect("log capture buffer lock poisoned")
            .clone();
        String::from_utf8(bytes).expect("captured logs are not valid UTF-8")
    }
}

/// A writer that records everything written to it in a shared buffer and, when not quiet, also
/// forwards the same bytes to an inner writer.
struct Tee<W> {
    /// Capture buffer shared with the [`Rx`] handed out alongside this writer.
    buf: Arc<Mutex<Vec<u8>>>,
    /// `true` means "capture only"; `false` means "capture and forward to `inner`".
    quiet: bool,
    /// Writer that receives a copy of the output while `quiet` is `false`.
    inner: W,
}

impl Tee<TestWriter> {
    /// Creates a quiet tee backed by a [`TestWriter`], paired with the [`Rx`] that reads the
    /// bytes it captures.
    fn stdout() -> (Self, Rx) {
        let shared: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let receiver = Rx(Arc::clone(&shared));
        let tee = Tee {
            inner: TestWriter::new(),
            quiet: true,
            buf: shared,
        };
        (tee, receiver)
    }
}

impl<W> Tee<W> {
    /// Turns forwarding on: subsequent writes are also sent to the inner writer.
    fn loud(&mut self) {
        self.quiet = false;
    }
}

impl<W: Write> Write for Tee<W> {
    /// Appends `buf` to the capture buffer and, unless quiet, forwards all of it to the inner
    /// writer. On success the full length of `buf` is reported as written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        {
            // Keep the lock only for the append, not across the inner write.
            let mut captured = self.buf.lock().unwrap();
            captured.extend_from_slice(buf);
        }
        if !self.quiet {
            self.inner.write_all(buf)?;
        }
        Ok(buf.len())
    }

    /// Flushes the inner writer; the in-memory capture buffer needs no flushing.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}