1use std::sync::LazyLock;
2use tracing::level_filters::LevelFilter;
3use tracing_indicatif::{
4    IndicatifLayer, IndicatifWriter, filter::IndicatifFilter, style::ProgressStyle,
5    util::FilteredFormatFields, writer,
6};
7use tracing_subscriber::{
8    EnvFilter, Layer, Registry,
9    layer::{Identity, SubscriberExt},
10    util::SubscriberInitExt as _,
11};
12
13#[cfg(feature = "otlp")]
14use opentelemetry_sdk::{
15    Resource, propagation::TraceContextPropagator, resource::SdkProvidedResourceDetector,
16};
17#[cfg(feature = "tracy")]
18use tracing_tracy::TracyLayer;
19
20pub mod propagate;
21
22pub static PB_PROGRESS_STYLE: LazyLock<ProgressStyle> = LazyLock::new(|| {
23    ProgressStyle::with_template(
24        "{span_child_prefix} {wide_msg} {bar:10} ({elapsed}) {pos:>7}/{len:7}",
25    )
26    .expect("invalid progress template")
27});
28pub static PB_TRANSFER_STYLE: LazyLock<ProgressStyle> = LazyLock::new(|| {
29    ProgressStyle::with_template(
30        "{span_child_prefix} {wide_msg} {binary_bytes:>7}/{binary_total_bytes:7}@{decimal_bytes_per_sec} ({elapsed}) {bar:10} "
31    )
32    .expect("invalid progress template")
33});
34pub static PB_SPINNER_STYLE: LazyLock<ProgressStyle> = LazyLock::new(|| {
35    ProgressStyle::with_template(
36        "{span_child_prefix}{spinner} {wide_msg} ({elapsed}) {pos:>7}/{len:7}",
37    )
38    .expect("invalid progress template")
39});
40
41#[derive(thiserror::Error, Debug)]
42pub enum Error {
43    #[error(transparent)]
44    Init(#[from] tracing_subscriber::util::TryInitError),
45
46    #[cfg(feature = "otlp")]
47    #[error(transparent)]
48    OTEL(#[from] opentelemetry_sdk::error::OTelSdkError),
49}
50
/// Handle returned once the global tracing subscriber has been installed by
/// `TracingBuilder::build` / `TracingBuilder::build_with_additional`.
///
/// It hands out the writers obtained from the indicatif layer. With the
/// `otlp` feature, it also keeps clones of the OpenTelemetry providers so
/// they can be flushed and shut down.
#[derive(Clone)]
pub struct TracingHandle {
    /// Stdout writer obtained from the indicatif layer.
    stdout_writer: IndicatifWriter<writer::Stdout>,
    /// Stderr writer obtained from the indicatif layer.
    stderr_writer: IndicatifWriter<writer::Stderr>,

    /// Meter provider; `Some` only when OTLP export was enabled on the builder.
    #[cfg(feature = "otlp")]
    meter_provider: Option<opentelemetry_sdk::metrics::SdkMeterProvider>,

    /// Tracer provider; `Some` only when OTLP export was enabled on the builder.
    #[cfg(feature = "otlp")]
    tracer_provider: Option<opentelemetry_sdk::trace::SdkTracerProvider>,
}
62
63impl TracingHandle {
64    pub fn get_stdout_writer(&self) -> IndicatifWriter<writer::Stdout> {
69        self.stdout_writer.clone()
71    }
72
73    pub fn get_stderr_writer(&self) -> IndicatifWriter<writer::Stderr> {
78        self.stderr_writer.clone()
80    }
81
82    pub async fn flush(&self) -> Result<(), Error> {
87        #[cfg(feature = "otlp")]
88        {
89            if let Some(tracer_provider) = &self.tracer_provider {
90                tracer_provider.force_flush()?;
91            }
92            if let Some(meter_provider) = &self.meter_provider {
93                meter_provider.force_flush()?;
94            }
95        }
96        Ok(())
97    }
98
99    pub async fn shutdown(&self) -> Result<(), Error> {
104        self.flush().await?;
105        #[cfg(feature = "otlp")]
106        {
107            if let Some(tracer_provider) = &self.tracer_provider {
108                tracer_provider.shutdown()?;
109            }
110            if let Some(meter_provider) = &self.meter_provider {
111                meter_provider.shutdown()?;
112            }
113        }
114
115        Ok(())
116    }
117}
118
/// Builder for the process-wide tracing subscriber.
///
/// Everything is off by default. Opt in with the `enable_*` methods, then
/// finish with `build()` or `build_with_additional()`.
#[derive(Default)]
#[must_use = "Don't forget to call build() to enable tracing."]
pub struct TracingBuilder {
    /// OTLP service name; `None` means OTLP export stays disabled.
    #[cfg(feature = "otlp")]
    service_name: Option<&'static str>,

    /// Whether spans are rendered as progress bars.
    progess_bar: bool,
}
127
impl TracingBuilder {
    /// Enables OpenTelemetry export of spans and metrics, tagged with
    /// `service_name`.
    ///
    /// The exporters are only created when the builder is built.
    #[cfg(feature = "otlp")]
    pub fn enable_otlp(mut self, service_name: &'static str) -> TracingBuilder {
        self.service_name = Some(service_name);
        self
    }

    /// Enables rendering of spans as progress bars through the indicatif layer.
    pub fn enable_progressbar(mut self) -> TracingBuilder {
        self.progess_bar = true;
        self
    }

    /// Installs the global tracing subscriber with no extra layer.
    ///
    /// Shorthand for `build_with_additional(Identity::new())`.
    ///
    /// # Errors
    /// Returns [`Error::Init`] if installing the global subscriber fails
    /// (e.g. one is already set).
    ///
    /// # Panics
    /// Same conditions as [`TracingBuilder::build_with_additional`].
    pub fn build(self) -> Result<TracingHandle, Error> {
        self.build_with_additional(Identity::new())
    }

    /// Installs the global tracing subscriber, stacking `additional_layer`
    /// directly on the registry underneath the built-in layers.
    ///
    /// The built-in layers are:
    /// - a compact `fmt` layer writing to stderr through the indicatif
    ///   writer;
    /// - the indicatif progress-bar layer, if progress bars were enabled;
    /// - a Tracy layer with the `tracy` feature;
    /// - an OpenTelemetry layer with the `otlp` feature, if OTLP was enabled.
    ///
    /// The built-in layers share one `EnvFilter`, which defaults to `INFO`.
    /// That filter is not applied to `additional_layer`.
    ///
    /// # Errors
    /// Returns [`Error::Init`] if installing the global subscriber fails.
    ///
    /// # Panics
    /// Panics in the following cases:
    /// - `RUST_LOG` cannot be parsed;
    /// - OTLP is enabled and the tracer or meter provider cannot be
    ///   configured.
    pub fn build_with_additional<L>(self, additional_layer: L) -> Result<TracingHandle, Error>
    where
        L: Layer<Registry> + Send + Sync + 'static,
    {
        // Progress-bar layer; spans without their own style use the spinner.
        // Its writers are handed out through the returned `TracingHandle`.
        let indicatif_layer = IndicatifLayer::new().with_progress_style(PB_SPINNER_STYLE.clone());
        let stdout_writer = indicatif_layer.get_stdout_writer();
        let stderr_writer = indicatif_layer.get_stderr_writer();

        // Log lines go to stderr via the indicatif writer. The
        // `indicatif.pb_show` field is hidden from the formatted output.
        let layered = tracing_subscriber::fmt::Layer::new()
            .fmt_fields(FilteredFormatFields::new(
                tracing_subscriber::fmt::format::DefaultFields::new(),
                |field| field.name() != "indicatif.pb_show",
            ))
            .with_writer(indicatif_layer.get_stderr_writer())
            .compact()
            // `bool::then` yields `None` when progress bars are disabled; a
            // `None` layer does nothing.
            .and_then((self.progess_bar).then(|| {
                indicatif_layer.with_filter(
                    // NOTE(review): `false` presumably means spans get no bar
                    // by default and must opt in via `indicatif.pb_show` —
                    // confirm against the tracing-indicatif docs.
                    IndicatifFilter::new(false),
                )
            }));
        #[cfg(feature = "tracy")]
        let layered = layered.and_then(TracyLayer::default());

        // Clones of the OTLP providers, kept so the returned handle can flush
        // and shut them down.
        #[cfg(feature = "otlp")]
        let mut g_tracer_provider = None;
        #[cfg(feature = "otlp")]
        let mut g_meter_provider = None;

        #[cfg(feature = "otlp")]
        let layered = layered.and_then({
            if let Some(service_name) = self.service_name.map(String::from) {
                use opentelemetry::trace::TracerProvider;

                // Global side effect: this happens before `try_init` below,
                // so it stays in place even if installing the subscriber fails.
                opentelemetry::global::set_text_map_propagator(TraceContextPropagator::new());

                let tracer_provider = gen_tracer_provider(service_name.clone())
                    .expect("Unable to configure trace provider");

                let meter_provider =
                    gen_meter_provider(service_name).expect("Unable to configure meter provider");

                // Only the meter provider is registered globally. The tracer
                // provider is used directly by the tracing-opentelemetry
                // layer below.
                opentelemetry::global::set_meter_provider(meter_provider.clone());

                g_tracer_provider = Some(tracer_provider.clone());
                g_meter_provider = Some(meter_provider.clone());

                Some(tracing_opentelemetry::layer().with_tracer(tracer_provider.tracer("snix")))
            } else {
                None
            }
        });

        // Per-layer filter over all built-in layers above. It reads
        // `RUST_LOG` and falls back to INFO when the variable is unset.
        let layered = layered.with_filter(
            EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env()
                .expect("invalid RUST_LOG"),
        );

        // `additional_layer` sits directly on the registry, which is why `L`
        // is bounded by `Layer<Registry>`.
        tracing_subscriber::registry()
            .with(additional_layer)
            .with(layered)
            .try_init()?;

        Ok(TracingHandle {
            stdout_writer,
            stderr_writer,

            #[cfg(feature = "otlp")]
            meter_provider: g_meter_provider,
            #[cfg(feature = "otlp")]
            tracer_provider: g_tracer_provider,
        })
    }
}
250
/// Builds the OpenTelemetry [`Resource`] describing this process.
///
/// `service.name` comes from the environment when set there, via
/// `OTEL_SERVICE_NAME` or a `service.name` entry in
/// `OTEL_RESOURCE_ATTRIBUTES`. Otherwise `service_name` is used.
///
/// `Resource::builder()` already runs `SdkProvidedResourceDetector` (per the
/// opentelemetry_sdk docs). That detector always sets `service.name` and
/// falls back to the placeholder `"unknown_service"`. Merging the detector
/// again *after* `with_service_name` would therefore replace our name with
/// that placeholder, so the detector is only consulted here.
#[cfg(feature = "otlp")]
fn gen_resources(service_name: String) -> Resource {
    // Ask the SDK detector what the environment says about the service name,
    // ignoring its "unknown_service" placeholder.
    let name_from_env = Resource::builder_empty()
        .with_detector(Box::new(SdkProvidedResourceDetector))
        .build()
        .get(&opentelemetry::Key::from_static_str("service.name"))
        .filter(|name| name.as_str() != "unknown_service");

    // Attributes added later win when merged, so this explicit name overrides
    // whatever the default detectors of `Resource::builder()` put in.
    Resource::builder()
        .with_service_name(name_from_env.unwrap_or_else(|| service_name.into()))
        .build()
}
262
/// Builds a tracer provider that batches spans and exports them via an OTLP
/// tonic exporter. The provider is tagged with the resources from
/// `gen_resources`.
///
/// # Errors
/// Returns an error if the OTLP span exporter cannot be built.
#[cfg(feature = "otlp")]
fn gen_tracer_provider(
    service_name: String,
) -> Result<opentelemetry_sdk::trace::SdkTracerProvider, opentelemetry::trace::TraceError> {
    use opentelemetry_otlp::{ExportConfig, SpanExporter, WithExportConfig};
    use opentelemetry_sdk::trace::SdkTracerProvider as Provider;

    let span_exporter = SpanExporter::builder()
        .with_tonic()
        .with_export_config(ExportConfig::default())
        .build()?;

    Ok(Provider::builder()
        .with_batch_exporter(span_exporter)
        .with_resource(gen_resources(service_name))
        .build())
}
301
/// Interval at which the periodic metrics reader in `gen_meter_provider`
/// exports collected metrics (ten seconds).
///
/// The leading underscore presumably silences the dead-code lint when the
/// `otlp` feature (its only user) is disabled.
const _OTEL_METRIC_EXPORT_INTERVAL: std::time::Duration =
    std::time::Duration::from_millis(10_000);
309
/// Builds a meter provider whose periodic reader pushes metrics through an
/// OTLP tonic exporter every `_OTEL_METRIC_EXPORT_INTERVAL`. The provider is
/// tagged with the resources from `gen_resources`.
///
/// # Errors
/// Returns an error if the OTLP metric exporter cannot be built.
#[cfg(feature = "otlp")]
fn gen_meter_provider(
    service_name: String,
) -> Result<opentelemetry_sdk::metrics::SdkMeterProvider, opentelemetry_sdk::metrics::MetricError> {
    use opentelemetry_otlp::{MetricExporter, WithExportConfig};
    use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider};

    // Timeout handed to the OTLP metric exporter.
    let export_timeout = std::time::Duration::from_secs(10);
    let metric_exporter = MetricExporter::builder()
        .with_tonic()
        .with_timeout(export_timeout)
        .build()?;

    let periodic_reader = PeriodicReader::builder(metric_exporter)
        .with_interval(_OTEL_METRIC_EXPORT_INTERVAL)
        .build();

    let provider = SdkMeterProvider::builder()
        .with_reader(periodic_reader)
        .with_resource(gen_resources(service_name))
        .build();
    Ok(provider)
}