1use opentelemetry::{
5 propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
6 trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState},
7 Context,
8};
9use std::str::FromStr;
10use std::sync::OnceLock;
11
/// Version written into every outgoing `traceparent` header.
const SUPPORTED_VERSION: u8 = 0;
/// Highest `traceparent` version accepted on extraction; anything above it is rejected.
const MAX_VERSION: u8 = 254;
/// Carrier key holding the `version-traceid-parentid-flags` value.
const TRACEPARENT_HEADER: &str = "traceparent";
/// Carrier key holding the vendor trace-state list.
const TRACESTATE_HEADER: &str = "tracestate";

/// Owned copies of the two header names, built once on first use so that
/// `fields()` can hand out a `'static` slice without allocating per call.
static TRACE_CONTEXT_HEADER_FIELDS: OnceLock<[String; 2]> = OnceLock::new();

fn trace_context_header_fields() -> &'static [String; 2] {
    TRACE_CONTEXT_HEADER_FIELDS.get_or_init(|| {
        let traceparent = String::from(TRACEPARENT_HEADER);
        let tracestate = String::from(TRACESTATE_HEADER);
        [traceparent, tracestate]
    })
}
24
/// Propagates span context using the W3C Trace Context format, reading and
/// writing the `traceparent` and `tracestate` carrier entries.
///
/// Construct with [`TraceContextPropagator::new`] or `Default::default()`.
#[derive(Clone, Debug, Default)]
pub struct TraceContextPropagator {
    // Private unit field: prevents struct-literal construction outside this
    // module, so `new()`/`Default` remain the only constructors.
    _private: (),
}
55
56impl TraceContextPropagator {
57 pub fn new() -> Self {
59 TraceContextPropagator { _private: () }
60 }
61
62 fn extract_span_context(&self, extractor: &dyn Extractor) -> Result<SpanContext, ()> {
64 let header_value = extractor.get(TRACEPARENT_HEADER).unwrap_or("").trim();
65 let parts = header_value.split_terminator('-').collect::<Vec<&str>>();
66 if parts.len() < 4 {
68 return Err(());
69 }
70
71 let version = u8::from_str_radix(parts[0], 16).map_err(|_| ())?;
73 if version > MAX_VERSION || version == 0 && parts.len() != 4 {
74 return Err(());
75 }
76
77 if parts[1].chars().any(|c| c.is_ascii_uppercase()) {
79 return Err(());
80 }
81
82 let trace_id = TraceId::from_hex(parts[1]).map_err(|_| ())?;
84
85 if parts[2].chars().any(|c| c.is_ascii_uppercase()) {
87 return Err(());
88 }
89
90 let span_id = SpanId::from_hex(parts[2]).map_err(|_| ())?;
92
93 let opts = u8::from_str_radix(parts[3], 16).map_err(|_| ())?;
95
96 if version == 0 && opts > 2 {
98 return Err(());
99 }
100
101 let trace_flags = TraceFlags::new(opts) & TraceFlags::SAMPLED;
104
105 let trace_state = match extractor.get(TRACESTATE_HEADER) {
106 Some(trace_state_str) => {
107 TraceState::from_str(trace_state_str).unwrap_or_else(|_| TraceState::default())
108 }
109 None => TraceState::default(),
110 };
111
112 let span_context = SpanContext::new(trace_id, span_id, trace_flags, true, trace_state);
114
115 if !span_context.is_valid() {
117 return Err(());
118 }
119
120 Ok(span_context)
121 }
122}
123
124impl TextMapPropagator for TraceContextPropagator {
125 fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
128 let span = cx.span();
129 let span_context = span.span_context();
130 if span_context.is_valid() {
131 let header_value = format!(
132 "{:02x}-{}-{}-{:02x}",
133 SUPPORTED_VERSION,
134 span_context.trace_id(),
135 span_context.span_id(),
136 span_context.trace_flags() & TraceFlags::SAMPLED
137 );
138 injector.set(TRACEPARENT_HEADER, header_value);
139 injector.set(TRACESTATE_HEADER, span_context.trace_state().header());
140 }
141 }
142
143 fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
148 self.extract_span_context(extractor)
149 .map(|sc| cx.with_remote_span_context(sc))
150 .unwrap_or_else(|_| cx.clone())
151 }
152
153 fn fields(&self) -> FieldIter<'_> {
154 FieldIter::new(trace_context_header_fields())
155 }
156}
157
#[cfg(all(test, feature = "testing", feature = "trace"))]
mod tests {
    use super::*;
    use crate::testing::trace::TestSpan;
    use std::collections::HashMap;

    /// Trace ID encoded by every well-formed `traceparent` fixture below.
    const FIXTURE_TRACE_ID: u128 = 0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736;
    /// Span (parent) ID encoded by every well-formed `traceparent` fixture below.
    const FIXTURE_SPAN_ID: u64 = 0x00f0_67aa_0ba9_02b7;

    /// Builds the remote span context the fixtures describe; only the flags vary.
    fn fixture_context(flags: TraceFlags) -> SpanContext {
        SpanContext::new(
            TraceId::from_u128(FIXTURE_TRACE_ID),
            SpanId::from_u64(FIXTURE_SPAN_ID),
            flags,
            true,
            TraceState::from_str("foo=bar").unwrap(),
        )
    }

    /// Well-formed `(traceparent, tracestate, expected context)` triples.
    #[rustfmt::skip]
    fn extract_data() -> Vec<(&'static str, &'static str, SpanContext)> {
        vec![
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", fixture_context(TraceFlags::default())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-08", "foo=bar", fixture_context(TraceFlags::default())),
            ("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-XYZxsf09", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01-", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("01-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
        ]
    }

    /// Malformed `traceparent` values paired with why each must be rejected.
    #[rustfmt::skip]
    fn extract_data_invalid() -> Vec<(&'static str, &'static str)> {
        vec![
            ("0000-00000000000000000000000000000000-0000000000000000-01", "wrong version length"),
            ("00-ab00000000000000000000000000000000-cd00000000000000-01", "wrong trace ID length"),
            ("00-ab000000000000000000000000000000-cd0000000000000000-01", "wrong span ID length"),
            ("00-ab000000000000000000000000000000-cd00000000000000-0100", "wrong trace flag length"),
            ("qw-00000000000000000000000000000000-0000000000000000-01", "bogus version"),
            ("00-qw000000000000000000000000000000-cd00000000000000-01", "bogus trace ID"),
            ("00-ab000000000000000000000000000000-qw00000000000000-01", "bogus span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-qw", "bogus trace flag"),
            ("A0-00000000000000000000000000000000-0000000000000000-01", "upper case version"),
            ("00-AB000000000000000000000000000000-cd00000000000000-01", "upper case trace ID"),
            ("00-ab000000000000000000000000000000-CD00000000000000-01", "upper case span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-A1", "upper case trace flag"),
            ("00-00000000000000000000000000000000-0000000000000000-01", "zero trace ID and span ID"),
            ("00-ab000000000000000000000000000000-cd00000000000000-09", "trace-flag unused bits set"),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7", "missing options"),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-", "empty options"),
        ]
    }

    /// `(expected traceparent, expected tracestate, span context to inject)` triples.
    #[rustfmt::skip]
    fn inject_data() -> Vec<(&'static str, &'static str, SpanContext)> {
        vec![
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", fixture_context(TraceFlags::SAMPLED)),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", fixture_context(TraceFlags::default())),
            ("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", fixture_context(TraceFlags::new(0xff))),
            ("", "", SpanContext::empty_context()),
        ]
    }

    #[test]
    fn extract_w3c() {
        let propagator = TraceContextPropagator::new();

        for (traceparent, tracestate, expected) in extract_data() {
            let carrier = HashMap::from([
                (TRACEPARENT_HEADER.to_string(), traceparent.to_string()),
                (TRACESTATE_HEADER.to_string(), tracestate.to_string()),
            ]);

            let extracted = propagator.extract(&carrier);
            assert_eq!(extracted.span().span_context(), &expected);
        }
    }

    #[test]
    fn extract_w3c_tracestate() {
        let propagator = TraceContextPropagator::new();
        let state = String::from("foo=bar");
        let carrier = HashMap::from([
            (
                TRACEPARENT_HEADER.to_string(),
                String::from("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00"),
            ),
            (TRACESTATE_HEADER.to_string(), state.clone()),
        ]);

        let cx = propagator.extract(&carrier);
        let header = cx.span().span_context().trace_state().header();
        assert_eq!(header, state);
    }

    #[test]
    fn extract_w3c_reject_invalid() {
        let propagator = TraceContextPropagator::new();
        let empty = SpanContext::empty_context();

        for (header, reason) in extract_data_invalid() {
            let carrier = HashMap::from([(TRACEPARENT_HEADER.to_string(), header.to_string())]);
            let extracted = propagator.extract(&carrier);
            assert_eq!(extracted.span().span_context(), &empty, "{reason}");
        }
    }

    #[test]
    fn inject_w3c() {
        let propagator = TraceContextPropagator::new();

        for (want_parent, want_state, span_context) in inject_data() {
            let mut carrier: HashMap<String, String> = HashMap::new();
            let cx = Context::current_with_span(TestSpan(span_context));
            propagator.inject_context(&cx, &mut carrier);

            let got_parent = Extractor::get(&carrier, TRACEPARENT_HEADER).unwrap_or("");
            assert_eq!(got_parent, want_parent);

            let got_state = Extractor::get(&carrier, TRACESTATE_HEADER).unwrap_or("");
            assert_eq!(got_state, want_state);
        }
    }

    #[test]
    fn inject_w3c_tracestate() {
        let propagator = TraceContextPropagator::new();
        let state = "foo=bar";

        let mut carrier: HashMap<String, String> = HashMap::new();
        carrier.set(TRACESTATE_HEADER, state.to_string());

        // Expects the current context to hold no valid span, in which case
        // injection must leave the pre-seeded tracestate untouched.
        Context::map_current(|cx| propagator.inject_context(cx, &mut carrier));

        assert_eq!(Extractor::get(&carrier, TRACESTATE_HEADER), Some(state));
    }
}