1use crate::trace::{TraceError, TraceResult};
2use crate::{SpanId, TraceFlags, TraceId};
3use std::collections::VecDeque;
4use std::hash::Hash;
5use std::str::FromStr;
6use thiserror::Error;
7
/// An ordered list of vendor-specific key/value pairs, as carried by the W3C
/// `tracestate` header.
///
/// The inner `Option` is `None` for the empty trace state (see
/// `TraceState::NONE`). Entries are held in a `VecDeque`; `insert` pushes new
/// or updated entries onto its front.
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
pub struct TraceState(Option<VecDeque<(String, String)>>);
17
18impl TraceState {
19 pub const NONE: TraceState = TraceState(None);
21
22 fn valid_key(key: &str) -> bool {
26 if key.len() > 256 {
27 return false;
28 }
29
30 let allowed_special = |b: u8| (b == b'_' || b == b'-' || b == b'*' || b == b'/');
31 let mut vendor_start = None;
32 for (i, &b) in key.as_bytes().iter().enumerate() {
33 if !(b.is_ascii_lowercase() || b.is_ascii_digit() || allowed_special(b) || b == b'@') {
34 return false;
35 }
36
37 if i == 0 && (!b.is_ascii_lowercase() && !b.is_ascii_digit()) {
38 return false;
39 } else if b == b'@' {
40 if vendor_start.is_some() || i + 14 < key.len() {
41 return false;
42 }
43 vendor_start = Some(i);
44 } else if let Some(start) = vendor_start {
45 if i == start + 1 && !(b.is_ascii_lowercase() || b.is_ascii_digit()) {
46 return false;
47 }
48 }
49 }
50
51 true
52 }
53
54 fn valid_value(value: &str) -> bool {
58 if value.len() > 256 {
59 return false;
60 }
61
62 !(value.contains(',') || value.contains('='))
63 }
64
65 pub fn from_key_value<T, K, V>(trace_state: T) -> TraceResult<Self>
79 where
80 T: IntoIterator<Item = (K, V)>,
81 K: ToString,
82 V: ToString,
83 {
84 let ordered_data = trace_state
85 .into_iter()
86 .map(|(key, value)| {
87 let (key, value) = (key.to_string(), value.to_string());
88 if !TraceState::valid_key(key.as_str()) {
89 return Err(TraceStateError::Key(key));
90 }
91 if !TraceState::valid_value(value.as_str()) {
92 return Err(TraceStateError::Value(value));
93 }
94
95 Ok((key, value))
96 })
97 .collect::<Result<VecDeque<_>, TraceStateError>>()?;
98
99 if ordered_data.is_empty() {
100 Ok(TraceState(None))
101 } else {
102 Ok(TraceState(Some(ordered_data)))
103 }
104 }
105
106 pub fn get(&self, key: &str) -> Option<&str> {
108 self.0.as_ref().and_then(|kvs| {
109 kvs.iter().find_map(|item| {
110 if item.0.as_str() == key {
111 Some(item.1.as_str())
112 } else {
113 None
114 }
115 })
116 })
117 }
118
119 pub fn insert<K, V>(&self, key: K, value: V) -> TraceResult<TraceState>
126 where
127 K: Into<String>,
128 V: Into<String>,
129 {
130 let (key, value) = (key.into(), value.into());
131 if !TraceState::valid_key(key.as_str()) {
132 return Err(TraceStateError::Key(key).into());
133 }
134 if !TraceState::valid_value(value.as_str()) {
135 return Err(TraceStateError::Value(value).into());
136 }
137
138 let mut trace_state = self.delete_from_deque(key.clone());
139 let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));
140
141 kvs.push_front((key, value));
142
143 Ok(trace_state)
144 }
145
146 pub fn delete<K: Into<String>>(&self, key: K) -> TraceResult<TraceState> {
154 let key = key.into();
155 if !TraceState::valid_key(key.as_str()) {
156 return Err(TraceStateError::Key(key).into());
157 }
158
159 Ok(self.delete_from_deque(key))
160 }
161
162 fn delete_from_deque(&self, key: String) -> TraceState {
164 let mut owned = self.clone();
165 if let Some(kvs) = owned.0.as_mut() {
166 if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
167 kvs.remove(index);
168 }
169 }
170 owned
171 }
172
173 pub fn header(&self) -> String {
176 self.header_delimited("=", ",")
177 }
178
179 pub fn header_delimited(&self, entry_delimiter: &str, list_delimiter: &str) -> String {
181 self.0
182 .as_ref()
183 .map(|kvs| {
184 kvs.iter()
185 .map(|(key, value)| format!("{}{}{}", key, entry_delimiter, value))
186 .collect::<Vec<String>>()
187 .join(list_delimiter)
188 })
189 .unwrap_or_default()
190 }
191}
192
193impl FromStr for TraceState {
194 type Err = TraceError;
195
196 fn from_str(s: &str) -> Result<Self, Self::Err> {
197 let list_members: Vec<&str> = s.split_terminator(',').collect();
198 let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());
199
200 for list_member in list_members {
201 match list_member.find('=') {
202 None => return Err(TraceStateError::List(list_member.to_string()).into()),
203 Some(separator_index) => {
204 let (key, value) = list_member.split_at(separator_index);
205 key_value_pairs
206 .push((key.to_string(), value.trim_start_matches('=').to_string()));
207 }
208 }
209 }
210
211 TraceState::from_key_value(key_value_pairs)
212 }
213}
214
/// Error produced when `TraceState` input fails validation.
#[derive(Error, Debug)]
#[non_exhaustive]
enum TraceStateError {
    /// A key failed `TraceState` key validation; holds the rejected key.
    #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
    Key(String),

    /// A value failed `TraceState` value validation; holds the rejected value.
    #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
    Value(String),

    /// A list member parsed by `from_str` has no `=` separating key and value;
    /// holds the offending member.
    #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
    List(String),
}
237
238impl From<TraceStateError> for TraceError {
239 fn from(err: TraceStateError) -> Self {
240 TraceError::Other(Box::new(err))
241 }
242}
243
/// Identifying information about a span: its trace id, span id, trace flags,
/// whether it is marked as remote, and its vendor trace state.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
pub struct SpanContext {
    /// Id of the trace this span belongs to.
    trace_id: TraceId,
    /// Id of this span.
    span_id: SpanId,
    /// Trace flags; `is_sampled` reads the sampled bit from them.
    trace_flags: TraceFlags,
    /// Whether the context was marked as remote when it was constructed.
    is_remote: bool,
    /// Vendor-specific key/value pairs carried with the context.
    trace_state: TraceState,
}
261
262impl SpanContext {
263 pub const NONE: SpanContext = SpanContext {
265 trace_id: TraceId::INVALID,
266 span_id: SpanId::INVALID,
267 trace_flags: TraceFlags::NOT_SAMPLED,
268 is_remote: false,
269 trace_state: TraceState::NONE,
270 };
271
272 pub fn empty_context() -> Self {
274 SpanContext::NONE
275 }
276
277 pub fn new(
279 trace_id: TraceId,
280 span_id: SpanId,
281 trace_flags: TraceFlags,
282 is_remote: bool,
283 trace_state: TraceState,
284 ) -> Self {
285 SpanContext {
286 trace_id,
287 span_id,
288 trace_flags,
289 is_remote,
290 trace_state,
291 }
292 }
293
294 pub fn trace_id(&self) -> TraceId {
296 self.trace_id
297 }
298
299 pub fn span_id(&self) -> SpanId {
301 self.span_id
302 }
303
304 pub fn trace_flags(&self) -> TraceFlags {
309 self.trace_flags
310 }
311
312 pub fn is_valid(&self) -> bool {
315 self.trace_id != TraceId::INVALID && self.span_id != SpanId::INVALID
316 }
317
318 pub fn is_remote(&self) -> bool {
320 self.is_remote
321 }
322
323 pub fn is_sampled(&self) -> bool {
327 self.trace_flags.is_sampled()
328 }
329
330 pub fn trace_state(&self) -> &TraceState {
332 &self.trace_state
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{trace::TraceContextExt, Context};

    /// Cases of `(trace state, expected header, key present in the state)`.
    #[rustfmt::skip]
    fn trace_state_test_data() -> Vec<(TraceState, &'static str, &'static str)> {
        vec![
            (TraceState::from_key_value(vec![("foo", "bar")]).unwrap(), "foo=bar", "foo"),
            (TraceState::from_key_value(vec![("foo", ""), ("apple", "banana")]).unwrap(), "foo=,apple=banana", "apple"),
            (TraceState::from_key_value(vec![("foo", "bar"), ("apple", "banana")]).unwrap(), "foo=bar,apple=banana", "apple"),
        ]
    }

    #[test]
    fn test_trace_state() {
        for (trace_state, expected_header, key) in trace_state_test_data() {
            assert_eq!(trace_state.header(), expected_header);

            let new_value = format!("{}-{}", trace_state.get(key).unwrap(), "test");
            let updated = format!("{}={}", key, new_value);

            let updated_trace_state = trace_state
                .insert(key, new_value)
                .expect("inserting a valid key/value pair should succeed");

            // An updated entry is moved to the front of the list.
            assert_eq!(updated_trace_state.header().find(&updated), Some(0));

            let deleted_trace_state = updated_trace_state
                .delete(key)
                .expect("deleting a valid key should succeed");

            assert!(deleted_trace_state.get(key).is_none());
        }
    }

    #[test]
    fn test_trace_state_key() {
        let test_data: Vec<(&'static str, bool)> = vec![
            ("123", true),
            ("bar", true),
            ("foo@bar", true),
            ("foo@0123456789abcdef", false),
            ("foo@012345678", true),
            ("FOO@BAR", false),
            ("你好", false),
        ];

        for (key, expected) in test_data {
            assert_eq!(TraceState::valid_key(key), expected, "test key: {:?}", key);
        }
    }

    #[test]
    fn test_trace_state_insert() {
        let trace_state = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
        let inserted_trace_state = trace_state.insert("testkey", "testvalue").unwrap();

        // `insert` returns a new state and leaves the original untouched.
        assert!(trace_state.get("testkey").is_none());
        assert_eq!(inserted_trace_state.get("testkey").unwrap(), "testvalue");
    }

    #[test]
    fn test_trace_state_from_str() {
        let trace_state: TraceState = "foo=bar,apple=banana".parse().unwrap();
        assert_eq!(trace_state.get("foo"), Some("bar"));
        assert_eq!(trace_state.get("apple"), Some("banana"));
        assert_eq!(trace_state.header(), "foo=bar,apple=banana");
        assert_eq!(trace_state.header_delimited(":", ";"), "foo:bar;apple:banana");

        assert_eq!("".parse::<TraceState>().unwrap(), TraceState::NONE);
        assert!("foo".parse::<TraceState>().is_err());
        assert!("FOO=bar".parse::<TraceState>().is_err());
    }

    #[test]
    fn test_context_span_debug() {
        let cx = Context::current();
        assert_eq!(
            format!("{:?}", cx),
            "Context { span: \"None\", entries: 0 }"
        );
        let cx = Context::current().with_remote_span_context(SpanContext::NONE);
        assert_eq!(
            format!("{:?}", cx),
            "Context { \
               span: SpanContext { \
                 trace_id: 00000000000000000000000000000000, \
                 span_id: 0000000000000000, \
                 trace_flags: TraceFlags(0), \
                 is_remote: false, \
                 trace_state: TraceState(None) \
               }, \
               entries: 1 \
             }"
        );
    }
}