snix_castore/
composition.rs

1//! The composition module allows composing different kinds of services based on a set of service
2//! configurations _at runtime_.
3//!
4//! Store configs are deserialized with serde. The registry provides a stateful mapping from the
5//! `type` tag of an internally tagged enum on the serde side to a Config struct which is
6//! deserialized and then returned as a `Box<dyn ServiceBuilder<Output = dyn BlobService>>`
7//! (the same for DirectoryService instead of BlobService etc).
8//!
9//! ### Example 1.: Implementing a new BlobService
10//!
//! You need a Config struct which implements `DeserializeOwned`,
//! `TryFrom<url::Url, Error = Box<dyn std::error::Error + Send + Sync>>` and
//! `ServiceBuilder<Output = dyn BlobService>`.
13//! Provide the user with a function to call with
14//! their registry. You register your new type as:
15//!
16//! ```
17//! use std::sync::Arc;
18//!
19//! use snix_castore::composition::*;
20//! use snix_castore::blobservice::BlobService;
21//!
22//! #[derive(serde::Deserialize)]
23//! struct MyBlobServiceConfig {
24//! }
25//!
26//! #[tonic::async_trait]
27//! impl ServiceBuilder for MyBlobServiceConfig {
28//!     type Output = dyn BlobService;
29//!     async fn build(&self, _: &str, _: &CompositionContext) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
30//!         todo!()
31//!     }
32//! }
33//!
34//! impl TryFrom<url::Url> for MyBlobServiceConfig {
35//!     type Error = Box<dyn std::error::Error + Send + Sync>;
36//!     fn try_from(url: url::Url) -> Result<Self, Self::Error> {
37//!         todo!()
38//!     }
39//! }
40//!
41//! pub fn add_my_service(reg: &mut Registry) {
42//!     reg.register::<Box<dyn ServiceBuilder<Output = dyn BlobService>>, MyBlobServiceConfig>("myblobservicetype");
43//! }
44//! ```
45//!
//! Now, when a user deserializes a store config with the type tag "myblobservicetype" into a
//! `DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = dyn BlobService>>>` (inside
//! `with_registry`), it will be done via `MyBlobServiceConfig`.
48//!
49//! ### Example 2.: Composing stores to get one store
50//!
51//! ```
52//! use std::sync::Arc;
53//! use snix_castore::composition::*;
54//! use snix_castore::blobservice::BlobService;
55//!
56//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
57//! # tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async move {
58//! let blob_services_configs_json = serde_json::json!({
59//!   "blobstore1": {
60//!     "type": "memory"
61//!   },
62//!   "blobstore2": {
63//!     "type": "memory"
64//!   },
65//!   "root": {
66//!     "type": "combined",
67//!     "near": "&blobstore1",
68//!     "far": "&blobstore2"
69//!   }
70//! });
71//!
72//! let blob_services_configs = with_registry(&REG, || serde_json::from_value(blob_services_configs_json))?;
73//! let mut blob_service_composition = Composition::new(&REG);
74//! blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
75//! let blob_service: Arc<dyn BlobService> = blob_service_composition.build("root").await?;
76//! # Ok(())
77//! # })
78//! # }
79//! ```
80//!
81//! ### Example 3.: Creating another registry extending the default registry with third-party types
82//!
83//! ```
84//! # pub fn add_my_service(reg: &mut snix_castore::composition::Registry) {}
85//! let mut my_registry = snix_castore::composition::Registry::default();
86//! snix_castore::composition::add_default_services(&mut my_registry);
87//! add_my_service(&mut my_registry);
88//! ```
89//!
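//! Continue with Example 2, using `my_registry` instead of `REG`. Note that `with_registry` and
//! `Composition::new` take a `&'static Registry`, so the registry has to be made `'static`
//! first, e.g. with `Box::leak(Box::new(my_registry))` (this is also how `REG` is created).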
91//!
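//! EXPERIMENTAL: If the xp-composition-url-refs feature is enabled,
//! references to other stores (strings resolved via `CompositionContext::resolve`, which
//! otherwise must have the form `&name`) can also be URL strings, which are created as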
94//! anonymous stores. Instantiations of the same URL will
95//! result in a new, distinct anonymous store each time, so creating
96//! two `memory://` stores with this method will not share the same view.
97//! This behavior might change in the future.
98
99use erased_serde::deserialize;
100use futures::FutureExt;
101use futures::future::{BoxFuture, err};
102use serde::de::DeserializeOwned;
103use serde_tagged::de::{BoxFnSeed, SeedFactory};
104use serde_tagged::util::TagString;
105use std::any::{Any, TypeId};
106use std::cell::Cell;
107use std::collections::BTreeMap;
108use std::collections::HashMap;
109use std::marker::PhantomData;
110use std::sync::{Arc, LazyLock};
111use tonic::async_trait;
112
113/// Resolves tag names to the corresponding Config type.
114// Registry implementation details:
115// This is really ugly. Really we would want to store this as a generic static field:
116//
117// ```
118// struct Registry<T>(BTreeMap<(&'static str), RegistryEntry<T>);
119// static REG<T>: Registry<T>;
120// ```
121//
122// so that one version of the static is generated for each Type that the registry is accessed for.
123// However, this is not possible, because generics are only a thing in functions, and even there
124// they will not interact with static items:
125// https://doc.rust-lang.org/reference/items/static-items.html#statics--generics
126//
127// So instead, we make this lookup at runtime by putting the TypeId into the key.
128// But now we can no longer store the `BoxFnSeed<T>` because we are lacking the generic parameter
129// T, so instead store it as `Box<dyn Any>` and downcast to `&BoxFnSeed<T>` when performing the
130// lookup.
131// I said it was ugly...
132#[derive(Default)]
133pub struct Registry(BTreeMap<(TypeId, &'static str), Box<dyn Any + Sync>>);
134pub type FromUrlSeed<T> =
135    Box<dyn Fn(url::Url) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + Sync>;
136pub struct RegistryEntry<T> {
137    serde_deserialize_seed: BoxFnSeed<DeserializeWithRegistry<T>>,
138    from_url_seed: FromUrlSeed<DeserializeWithRegistry<T>>,
139}
140
141struct RegistryWithFakeType<'r, T>(&'r Registry, PhantomData<T>);
142
143impl<'r, 'de: 'r, T: 'static> SeedFactory<'de, TagString<'de>> for RegistryWithFakeType<'r, T> {
144    type Value = DeserializeWithRegistry<T>;
145    type Seed = &'r BoxFnSeed<Self::Value>;
146
147    // Required method
148    fn seed<E>(self, tag: TagString<'de>) -> Result<Self::Seed, E>
149    where
150        E: serde::de::Error,
151    {
152        // using find() and not get() because of https://github.com/rust-lang/rust/issues/80389
153        let seed: &Box<dyn Any + Sync> = self
154            .0
155            .0
156            .iter()
157            .find(|(k, _)| *k == &(TypeId::of::<T>(), tag.as_ref()))
158            .ok_or_else(|| serde::de::Error::custom(format!("Unknown type: {tag}")))?
159            .1;
160
161        let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
162
163        Ok(&entry.serde_deserialize_seed)
164    }
165}
166
167/// Wrapper type which implements Deserialize using the registry
168///
169/// Wrap your type in this in order to deserialize it using a registry, e.g.
170/// `RegistryWithFakeType<Box<dyn MyTrait>>`, then the types registered for `Box<dyn MyTrait>`
171/// will be used.
172pub struct DeserializeWithRegistry<T>(pub T);
173
174impl Registry {
175    /// Registers a mapping from type tag to a concrete type into the registry.
176    ///
177    /// The type parameters are very important:
178    /// After calling `register::<Box<dyn FooTrait>, FooStruct>("footype")`, when a user
179    /// deserializes into an input with the type tag "myblobservicetype" into a
180    /// `Box<dyn FooTrait>`, it will first call the Deserialize imple of `FooStruct` and
181    /// then convert it into a `Box<dyn FooTrait>` using From::from.
182    pub fn register<
183        T: 'static,
184        C: DeserializeOwned
185            + TryFrom<url::Url, Error = Box<dyn std::error::Error + Send + Sync>>
186            + Into<T>,
187    >(
188        &mut self,
189        type_name: &'static str,
190    ) {
191        self.0.insert(
192            (TypeId::of::<T>(), type_name),
193            Box::new(RegistryEntry {
194                serde_deserialize_seed: BoxFnSeed::new(|x| {
195                    deserialize::<C>(x)
196                        .map(Into::into)
197                        .map(DeserializeWithRegistry)
198                }),
199                from_url_seed: Box::new(|url| {
200                    C::try_from(url)
201                        .map(Into::into)
202                        .map(DeserializeWithRegistry)
203                }),
204            }),
205        );
206    }
207}
208
209impl<'de, T: 'static> serde::Deserialize<'de> for DeserializeWithRegistry<T> {
210    fn deserialize<D>(de: D) -> std::result::Result<Self, D::Error>
211    where
212        D: serde::Deserializer<'de>,
213    {
214        serde_tagged::de::internal::deserialize(
215            de,
216            "type",
217            RegistryWithFakeType(ACTIVE_REG.get().unwrap(), PhantomData::<T>),
218        )
219    }
220}
221
222#[derive(Debug, thiserror::Error)]
223enum TryFromUrlError {
224    #[error("Unknown type: {0}")]
225    UnknownTag(String),
226}
227
228impl<T: 'static> TryFrom<url::Url> for DeserializeWithRegistry<T> {
229    type Error = Box<dyn std::error::Error + Send + Sync>;
230    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
231        let tag = url.scheme().split('+').next().unwrap();
232        // same as in the SeedFactory impl: using find() and not get() because of https://github.com/rust-lang/rust/issues/80389
233        let seed = ACTIVE_REG
234            .get()
235            .unwrap()
236            .0
237            .iter()
238            .find(|(k, _)| *k == &(TypeId::of::<T>(), tag))
239            .ok_or_else(|| Box::new(TryFromUrlError::UnknownTag(tag.into())))?
240            .1;
241        let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
242        (entry.from_url_seed)(url)
243    }
244}
245
246thread_local! {
247    /// The active Registry is global state, because there is no convenient and universal way to pass state
248    /// into the functions usually used for deserialization, e.g. `serde_json::from_str`, `toml::from_str`,
249    /// `serde_qs::from_str`.
250    static ACTIVE_REG: Cell<Option<&'static Registry>> = panic!("reg was accessed before initialization");
251}
252
253/// Run the provided closure with a registry context.
254/// Any serde deserialize calls within the closure will use the registry to resolve tag names to
255/// the corresponding Config type.
256pub fn with_registry<R>(reg: &'static Registry, f: impl FnOnce() -> R) -> R {
257    ACTIVE_REG.set(Some(reg));
258    let result = f();
259    ACTIVE_REG.set(None);
260    result
261}
262
263/// The provided registry of snix_castore, with all builtin BlobStore/DirectoryStore implementations
264pub static REG: LazyLock<&'static Registry> = LazyLock::new(|| {
265    let mut reg = Default::default();
266    add_default_services(&mut reg);
267    // explicitly leak to get an &'static, so that we gain `&Registry: Send` from `Registry: Sync`
268    Box::leak(Box::new(reg))
269});
270
271// ---------- End of generic registry code --------- //
272
273/// Register the builtin services of snix_castore (blob services and directory
274/// services) with the given registry.
275/// This can be used outside to create your own registry with the builtin types
276/// _and_ extra third party types.
277pub fn add_default_services(reg: &mut Registry) {
278    crate::blobservice::register_blob_services(reg);
279    crate::directoryservice::register_directory_services(reg);
280}
281
282pub struct CompositionContext<'a> {
283    // The stack used to detect recursive instantiations and prevent deadlocks
284    // The TypeId of the trait object is included to distinguish e.g. the
285    // BlobService "root" and the DirectoryService "root".
286    stack: Vec<(TypeId, String)>,
287    registry: &'static Registry,
288    composition: Option<&'a Composition>,
289}
290
291impl CompositionContext<'_> {
292    /// Get a composition context for one-off store creation.
293    pub fn blank(registry: &'static Registry) -> Self {
294        Self {
295            registry,
296            stack: Default::default(),
297            composition: None,
298        }
299    }
300
301    /// Resolves an instance ref (instance name prefixed with "&") or an
302    /// anonymous store URL to an instantiated service.
303    /// The latter is only allowed if xp-composition-url-refs is enabled.
304    pub async fn resolve<T: ?Sized + Send + Sync + 'static>(
305        &self,
306        s: &str,
307    ) -> Result<Arc<T>, CompositionError> {
308        // The string is expected to start with a `&`...
309        if let Some(instance_name) = s.strip_prefix("&") {
310            // disallow recursion
311            if self
312                .stack
313                .contains(&(TypeId::of::<T>(), instance_name.to_owned()))
314            {
315                return Err(CompositionError::Recursion(
316                    self.stack.iter().map(|(_, n)| n.clone()).collect(),
317                ));
318            }
319
320            self.build_internal(instance_name.to_owned()).await
321        } else {
322            // ... or it's an anonymous store with xp-composition-url-refs
323            #[cfg(feature = "xp-composition-url-refs")]
324            {
325                // This might be a URL, we are building an anonymous store
326                Ok(self
327                    .build_anonymous(s)
328                    .await
329                    .map_err(|e| CompositionError::Failed(s.to_string(), Arc::from(e)))?)
330            }
331
332            #[cfg(not(feature = "xp-composition-url-refs"))]
333            {
334                Err(CompositionError::InvalidReference(s.to_owned()))
335            }
336        }
337    }
338
339    #[cfg(feature = "xp-composition-url-refs")]
340    async fn build_anonymous<T: ?Sized + Send + Sync + 'static>(
341        &self,
342        url_str: &str,
343    ) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync>> {
344        let url = url::Url::parse(url_str)?;
345        let config: DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>> =
346            with_registry(self.registry, || url.try_into())?;
347        config.0.build("anonymous", self).await
348    }
349
350    fn build_internal<T: ?Sized + Send + Sync + 'static>(
351        &self,
352        instance_name: String,
353    ) -> BoxFuture<'_, Result<Arc<T>, CompositionError>> {
354        debug_assert!(
355            !instance_name.starts_with("&"),
356            "build_internal should never be called with &"
357        );
358
359        let mut stores = match self.composition {
360            Some(comp) => comp.stores.lock().unwrap(),
361            None => return Box::pin(err(CompositionError::NotFound(instance_name))),
362        };
363        let entry = match stores.get_mut(&(TypeId::of::<T>(), instance_name.to_owned())) {
364            Some(v) => v,
365            None => return Box::pin(err(CompositionError::NotFound(instance_name))),
366        };
367        // for lifetime reasons, we put a placeholder value in the hashmap while we figure out what
368        // the new value should be. the Mutex stays locked the entire time, so nobody will ever see
369        // this temporary value.
370        let prev_val = std::mem::replace(
371            entry,
372            Box::new(InstantiationState::<T>::Done(Err(
373                CompositionError::Poisoned(instance_name.to_owned()),
374            ))),
375        );
376        let (new_val, ret) = match *prev_val.downcast::<InstantiationState<T>>().unwrap() {
377            InstantiationState::Done(service) => (
378                InstantiationState::Done(service.clone()),
379                futures::future::ready(service).boxed(),
380            ),
381            // the construction of the store has not started yet.
382            InstantiationState::Config(config) => {
383                let (tx, rx) = tokio::sync::watch::channel(None);
384                (
385                    InstantiationState::InProgress(rx),
386                    (async move {
387                        let new_context = CompositionContext {
388                            composition: self.composition,
389                            registry: self.registry,
390                            stack: {
391                                let mut stack = self.stack.clone();
392                                stack.push((TypeId::of::<T>(), instance_name.to_owned()));
393                                stack
394                            },
395                        };
396
397                        let res = config
398                            .build(&instance_name, &new_context)
399                            .await
400                            .map_err(|e| match e.downcast() {
401                                Ok(e) => *e,
402                                Err(e) => CompositionError::Failed(instance_name, e.into()),
403                            });
404                        tx.send(Some(res.clone())).unwrap();
405                        res
406                    })
407                    .boxed(),
408                )
409            }
410            // there is already a task driving forward the construction of this store, wait for it
411            // to notify us via the provided channel
412            InstantiationState::InProgress(mut recv) => {
413                (InstantiationState::InProgress(recv.clone()), {
414                    (async move {
415                        loop {
416                            if let Some(v) =
417                                recv.borrow_and_update().as_ref().map(|res| res.clone())
418                            {
419                                break v;
420                            }
421                            recv.changed().await.unwrap();
422                        }
423                    })
424                    .boxed()
425                })
426            }
427        };
428        *entry = Box::new(new_val);
429        ret
430    }
431}
432
433#[async_trait]
434/// This is the trait usually implemented on a per-store-type Config struct and
435/// used to instantiate it.
436pub trait ServiceBuilder: Send + Sync {
437    type Output: ?Sized;
438    async fn build(
439        &self,
440        instance_name: &str,
441        context: &CompositionContext,
442    ) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>>;
443}
444
445impl<T: ?Sized, S: ServiceBuilder<Output = T> + 'static> From<S>
446    for Box<dyn ServiceBuilder<Output = T>>
447{
448    fn from(t: S) -> Self {
449        Box::new(t)
450    }
451}
452
453enum InstantiationState<T: ?Sized> {
454    Config(Box<dyn ServiceBuilder<Output = T>>),
455    InProgress(tokio::sync::watch::Receiver<Option<Result<Arc<T>, CompositionError>>>),
456    Done(Result<Arc<T>, CompositionError>),
457}
458
459pub struct Composition {
460    registry: &'static Registry,
461    stores: std::sync::Mutex<HashMap<(TypeId, String), Box<dyn Any + Send + Sync>>>,
462}
463
464#[derive(thiserror::Error, Clone, Debug)]
465pub enum CompositionError {
466    #[error("store not found: {0}")]
467    NotFound(String),
468    #[error("recursion not allowed {0:?}")]
469    Recursion(Vec<String>),
470    #[error("store construction panicked {0}")]
471    Poisoned(String),
472    #[error("invalid reference, must start with @")]
473    InvalidReference(String),
474    #[error("instantiation of service {0} failed: {1}")]
475    Failed(String, Arc<dyn std::error::Error + Send + Sync>),
476}
477
478impl<T: ?Sized + Send + Sync + 'static>
479    Extend<(
480        String,
481        DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
482    )> for Composition
483{
484    fn extend<I>(&mut self, configs: I)
485    where
486        I: IntoIterator<
487            Item = (
488                String,
489                DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
490            ),
491        >,
492    {
493        self.stores
494            .lock()
495            .unwrap()
496            .extend(configs.into_iter().map(|(k, v)| {
497                (
498                    (TypeId::of::<T>(), k),
499                    Box::new(InstantiationState::Config(v.0)) as Box<dyn Any + Send + Sync>,
500                )
501            }))
502    }
503}
504
505impl Composition {
506    /// The given registry will be used for creation of anonymous stores during composition
507    pub fn new(registry: &'static Registry) -> Self {
508        Self {
509            registry,
510            stores: Default::default(),
511        }
512    }
513
514    pub fn extend_with_configs<T: ?Sized + Send + Sync + 'static>(
515        &mut self,
516        // Keep the concrete `HashMap` type here since it allows for type
517        // inference of what type is previously being deserialized.
518        configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
519    ) {
520        self.extend(configs);
521    }
522
523    /// Looks up the instance name in the composition and returns an instantiated service.
524    pub async fn build<T: ?Sized + Send + Sync + 'static>(
525        &self,
526        instance_name: &str,
527    ) -> Result<Arc<T>, CompositionError> {
528        self.context()
529            .build_internal(instance_name.to_string())
530            .await
531    }
532
533    pub fn context(&self) -> CompositionContext<'_> {
534        CompositionContext {
535            registry: self.registry,
536            stack: vec![],
537            composition: Some(self),
538        }
539    }
540}
541
542#[cfg(test)]
543mod test {
544    use super::*;
545    use crate::blobservice::BlobService;
546    use std::sync::Arc;
547
548    /// Test that we return a reference to the same instance of MemoryBlobService (via ptr_eq)
549    /// when instantiating the same entrypoint twice. By instantiating concurrently, we also
550    /// test the channels notifying the second consumer when the store has been instantiated.
551    #[tokio::test]
552    async fn concurrent() {
553        let blob_services_configs_json = serde_json::json!({
554            "root": {
555                "type": "memory",
556            }
557        });
558
559        let blob_services_configs =
560            with_registry(&REG, || serde_json::from_value(blob_services_configs_json)).unwrap();
561        let mut blob_service_composition = Composition::new(&REG);
562        blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
563        let (blob_service1, blob_service2) = tokio::join!(
564            blob_service_composition.build::<dyn BlobService>("root"),
565            blob_service_composition.build::<dyn BlobService>("root")
566        );
567        assert!(Arc::ptr_eq(
568            &blob_service1.unwrap(),
569            &blob_service2.unwrap()
570        ));
571    }
572
573    /// Test that we throw the correct error when an instantiation would recurse (deadlock)
574    #[tokio::test]
575    async fn reject_recursion() {
576        let blob_services_configs_json = serde_json::json!({
577            "root": {
578                "type": "combined",
579                "near": "&other",
580                "far": "&other"
581            },
582            "other": {
583                "type": "combined",
584                "near": "&root",
585                "far": "&root"
586            }
587        });
588
589        let blob_services_configs =
590            with_registry(&REG, || serde_json::from_value(blob_services_configs_json)).unwrap();
591        let mut blob_service_composition = Composition::new(&REG);
592        blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
593        match blob_service_composition
594            .build::<dyn BlobService>("root")
595            .await
596        {
597            Err(CompositionError::Recursion(stack)) => {
598                assert_eq!(stack, vec!["root".to_string(), "other".to_string()])
599            }
600            other => panic!("should have returned an error, returned: {:?}", other.err()),
601        }
602    }
603}