tracing_test_macro/
lib.rs

1//! # tracing_test_macro
2//!
3//! This crate provides a procedural macro that can be added to test functions in order to ensure
4//! that all tracing logs are written to a global buffer.
5//!
6//! You should not use this crate directly. Instead, use the macro through [tracing-test].
7//!
8//! [tracing-test]: https://docs.rs/tracing-test
9extern crate proc_macro;
10
11use std::sync::{Mutex, OnceLock};
12
13use proc_macro::TokenStream;
14use quote::{quote, ToTokens};
15use syn::{parse, ItemFn, Stmt};
16
/// Shared registry of every scope name handed out so far.
///
/// Each traced test gets a span named after its test function. Several tests
/// may share the same function name, so a numeric suffix is appended whenever
/// a name is already taken. This list remembers all names that have already
/// been assigned, which is how such conflicts are detected.
///
/// The registry is created empty on first access and the same instance is
/// returned on every later call.
fn registered_scopes() -> &'static Mutex<Vec<String>> {
    static SCOPES: OnceLock<Mutex<Vec<String>>> = OnceLock::new();
    SCOPES.get_or_init(|| Mutex::new(Vec::new()))
}
28
29/// Check whether this test function name is already taken as scope. If yes, a
30/// counter is appended to make it unique. In the end, a unique scope is returned.
31fn get_free_scope(mut test_fn_name: String) -> String {
32    let mut vec = registered_scopes().lock().unwrap();
33    let mut counter = 1;
34    let len = test_fn_name.len();
35    while vec.contains(&test_fn_name) {
36        counter += 1;
37        test_fn_name.replace_range(len.., &counter.to_string());
38    }
39    vec.push(test_fn_name.clone());
40    test_fn_name
41}
42
43/// A procedural macro that ensures that a global logger is registered for the
44/// annotated test.
45///
46/// Additionally, the macro injects a local function called `logs_contain`,
47/// which can be used to assert that a certain string was logged within this
48/// test.
49///
50/// Check out the docs of the `tracing-test` crate for more usage information.
51#[proc_macro_attribute]
52pub fn traced_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
53    // Parse annotated function
54    let mut function: ItemFn = parse(item).expect("Could not parse ItemFn");
55
56    // Determine scope
57    let scope = get_free_scope(function.sig.ident.to_string());
58
59    // Determine features
60    //
61    // Note: This cannot be called in the block below, otherwise it would be
62    //       evaluated in the context of the calling crate, not of the macro
63    //       crate!
64    let no_env_filter = cfg!(feature = "no-env-filter");
65
66    // Prepare code that should be injected at the start of the function
67    let init = parse::<Stmt>(
68        quote! {
69            tracing_test::internal::INITIALIZED.call_once(|| {
70                let env_filter = if #no_env_filter {
71                    "trace".to_string()
72                } else {
73                    let crate_name = module_path!()
74                        .split(":")
75                        .next()
76                        .expect("Could not find crate name in module path")
77                        .to_string();
78                    format!("{}=trace", crate_name)
79                };
80                let mock_writer = tracing_test::internal::MockWriter::new(&tracing_test::internal::global_buf());
81                let subscriber = tracing_test::internal::get_subscriber(mock_writer, &env_filter);
82                tracing::dispatcher::set_global_default(subscriber)
83                    .expect("Could not set global tracing subscriber");
84            });
85        }
86        .into(),
87    )
88    .expect("Could not parse quoted statement init");
89    let span = parse::<Stmt>(
90        quote! {
91            let span = tracing::info_span!(#scope);
92        }
93        .into(),
94    )
95    .expect("Could not parse quoted statement span");
96    let enter = parse::<Stmt>(
97        quote! {
98            let _enter = span.enter();
99        }
100        .into(),
101    )
102    .expect("Could not parse quoted statement enter");
103    let logs_contain_fn = parse::<Stmt>(
104        quote! {
105            fn logs_contain(val: &str) -> bool {
106                tracing_test::internal::logs_with_scope_contain(#scope, val)
107            }
108
109        }
110        .into(),
111    )
112    .expect("Could not parse quoted statement logs_contain_fn");
113    let logs_assert_fn = parse::<Stmt>(
114        quote! {
115            /// Run a function against the log lines. If the function returns
116            /// an `Err`, panic. This can be used to run arbitrary assertion
117            /// logic against the logs.
118            fn logs_assert(f: impl Fn(&[&str]) -> std::result::Result<(), String>) {
119                match tracing_test::internal::logs_assert(#scope, f) {
120                    Ok(()) => {},
121                    Err(msg) => panic!("The logs_assert function returned an error: {}", msg),
122                };
123            }
124        }
125        .into(),
126    )
127    .expect("Could not parse quoted statement logs_assert_fn");
128
129    // Inject code into function
130    function.block.stmts.insert(0, init);
131    function.block.stmts.insert(1, span);
132    function.block.stmts.insert(2, enter);
133    function.block.stmts.insert(3, logs_contain_fn);
134    function.block.stmts.insert(4, logs_assert_fn);
135
136    // Generate token stream
137    TokenStream::from(function.to_token_stream())
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_free_scope() {
        let initial = get_free_scope("test_fn_name".to_string());
        assert_eq!(initial, "test_fn_name");

        let second = get_free_scope("test_fn_name".to_string());
        assert_eq!(second, "test_fn_name2");
        let third = get_free_scope("test_fn_name".to_string());
        assert_eq!(third, "test_fn_name3");

        // Register a name that collides with the next suffixed candidate.
        let fourth = get_free_scope("test_fn_name4".to_string());
        assert_eq!(fourth, "test_fn_name4");

        // Requesting the base name again must skip the taken "test_fn_name4".
        let fifth = get_free_scope("test_fn_name".to_string());
        assert_eq!(fifth, "test_fn_name5");
    }

    #[test]
    fn test_get_free_scope_distinct_names_do_not_conflict() {
        // Uses base names unrelated to the other test, since the registry is
        // shared and tests may run in parallel.
        assert_eq!(get_free_scope("alpha_scope".to_string()), "alpha_scope");
        assert_eq!(get_free_scope("beta_scope".to_string()), "beta_scope");
        assert_eq!(get_free_scope("alpha_scope".to_string()), "alpha_scope2");
    }
}