snafu_derive/
shared.rs

1use std::collections::BTreeSet;
2
3pub(crate) use self::context_module::ContextModule;
4pub(crate) use self::context_selector::ContextSelector;
5pub(crate) use self::display::{Display, DisplayMatchArm};
6pub(crate) use self::error::{Error, ErrorProvideMatchArm, ErrorSourceMatchArm};
7pub(crate) use self::error_compat::{ErrorCompat, ErrorCompatBacktraceMatchArm};
8
9pub(crate) struct StaticIdent(&'static str);
10
11impl quote::ToTokens for StaticIdent {
12    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
13        proc_macro2::Ident::new(self.0, proc_macro2::Span::call_site()).to_tokens(tokens)
14    }
15}
16
17struct AllFieldNames<'a>(&'a crate::FieldContainer);
18
19impl<'a> AllFieldNames<'a> {
20    fn field_names(&self) -> BTreeSet<&'a proc_macro2::Ident> {
21        let user_fields = self.0.selector_kind.user_fields();
22        let backtrace_field = self.0.backtrace_field.as_ref();
23        let implicit_fields = &self.0.implicit_fields;
24        let message_field = self.0.selector_kind.message_field();
25        let source_field = self.0.selector_kind.source_field();
26
27        user_fields
28            .iter()
29            .chain(backtrace_field)
30            .chain(implicit_fields)
31            .chain(message_field)
32            .map(crate::Field::name)
33            .chain(source_field.map(crate::SourceField::name))
34            .collect()
35    }
36}
37
38pub mod context_module {
39    use crate::ModuleName;
40    use heck::ToSnakeCase;
41    use proc_macro2::TokenStream;
42    use quote::{quote, ToTokens};
43    use syn::Ident;
44
45    #[derive(Copy, Clone)]
46    pub(crate) struct ContextModule<'a, T> {
47        pub container_name: &'a Ident,
48        pub module_name: &'a ModuleName,
49        pub visibility: Option<&'a dyn ToTokens>,
50        pub body: &'a T,
51    }
52
53    impl<'a, T> ToTokens for ContextModule<'a, T>
54    where
55        T: ToTokens,
56    {
57        fn to_tokens(&self, stream: &mut TokenStream) {
58            let module_name = match self.module_name {
59                ModuleName::Default => {
60                    let name_str = self.container_name.to_string().to_snake_case();
61                    syn::Ident::new(&name_str, self.container_name.span())
62                }
63                ModuleName::Custom(name) => name.clone(),
64            };
65
66            let visibility = self.visibility;
67            let body = self.body;
68
69            let module_tokens = quote! {
70                #visibility mod #module_name {
71                    use super::*;
72
73                    #body
74                }
75            };
76
77            stream.extend(module_tokens);
78        }
79    }
80}
81
82pub mod context_selector {
83    use crate::{ContextSelectorKind, Field, SuffixKind};
84    use proc_macro2::TokenStream;
85    use quote::{format_ident, quote, IdentFragment, ToTokens};
86
87    const DEFAULT_SUFFIX: &str = "Snafu";
88
89    #[derive(Copy, Clone)]
90    pub(crate) struct ContextSelector<'a> {
91        pub backtrace_field: Option<&'a Field>,
92        pub implicit_fields: &'a [Field],
93        pub crate_root: &'a dyn ToTokens,
94        pub error_constructor_name: &'a dyn ToTokens,
95        pub original_generics_without_defaults: &'a [TokenStream],
96        pub parameterized_error_name: &'a dyn ToTokens,
97        pub selector_doc_string: &'a str,
98        pub selector_kind: &'a ContextSelectorKind,
99        pub selector_name: &'a proc_macro2::Ident,
100        pub user_fields: &'a [Field],
101        pub visibility: Option<&'a dyn ToTokens>,
102        pub where_clauses: &'a [TokenStream],
103        pub default_suffix: &'a SuffixKind,
104    }
105
106    impl ToTokens for ContextSelector<'_> {
107        fn to_tokens(&self, stream: &mut TokenStream) {
108            use self::ContextSelectorKind::*;
109
110            let context_selector = match self.selector_kind {
111                Context { source_field, .. } => {
112                    let context_selector_type = self.generate_type();
113                    let context_selector_impl = match source_field {
114                        Some(_) => None,
115                        None => Some(self.generate_leaf()),
116                    };
117                    let context_selector_into_error_impl =
118                        self.generate_into_error(source_field.as_ref());
119
120                    quote! {
121                        #context_selector_type
122                        #context_selector_impl
123                        #context_selector_into_error_impl
124                    }
125                }
126                Whatever {
127                    source_field,
128                    message_field,
129                } => self.generate_whatever(source_field.as_ref(), message_field),
130                NoContext { source_field } => self.generate_from_source(source_field),
131            };
132
133            stream.extend(context_selector)
134        }
135    }
136
137    impl ContextSelector<'_> {
138        fn user_field_generics(&self) -> Vec<proc_macro2::Ident> {
139            (0..self.user_fields.len())
140                .map(|i| format_ident!("__T{}", i))
141                .collect()
142        }
143
144        fn user_field_names(&self) -> Vec<&syn::Ident> {
145            self.user_fields
146                .iter()
147                .map(|Field { name, .. }| name)
148                .collect()
149        }
150
151        fn parameterized_selector_name(&self) -> TokenStream {
152            let selector_name = self.selector_name.to_string();
153            let selector_name = selector_name.trim_end_matches("Error");
154            let suffix: &dyn IdentFragment = match self.selector_kind {
155                ContextSelectorKind::Context { suffix, .. } => {
156                    match suffix.resolve_with_default(self.default_suffix) {
157                        SuffixKind::Some(s) => s,
158                        SuffixKind::None => &"",
159                        SuffixKind::Default => &DEFAULT_SUFFIX,
160                    }
161                }
162                _ => &DEFAULT_SUFFIX,
163            };
164            let selector_name = format_ident!(
165                "{}{}",
166                selector_name,
167                suffix,
168                span = self.selector_name.span()
169            );
170            let user_generics = self.user_field_generics();
171
172            quote! { #selector_name<#(#user_generics,)*> }
173        }
174
175        fn extended_where_clauses(&self) -> Vec<TokenStream> {
176            let user_fields = self.user_fields;
177            let user_field_generics = self.user_field_generics();
178            let where_clauses = self.where_clauses;
179
180            let target_types = user_fields
181                .iter()
182                .map(|Field { ty, .. }| quote! { ::core::convert::Into<#ty>});
183
184            user_field_generics
185                .into_iter()
186                .zip(target_types)
187                .map(|(gen, bound)| quote! { #gen: #bound })
188                .chain(where_clauses.iter().cloned())
189                .collect()
190        }
191
192        fn transfer_user_fields(&self) -> Vec<TokenStream> {
193            self.user_field_names()
194                .into_iter()
195                .map(|name| {
196                    quote! { #name: ::core::convert::Into::into(self.#name) }
197                })
198                .collect()
199        }
200
201        fn construct_implicit_fields(&self) -> TokenStream {
202            let crate_root = self.crate_root;
203            let expression = quote! {
204                #crate_root::GenerateImplicitData::generate()
205            };
206
207            self.construct_implicit_fields_with_expression(expression)
208        }
209
210        fn construct_implicit_fields_with_source(&self) -> TokenStream {
211            let crate_root = self.crate_root;
212            let expression = quote! { {
213                use #crate_root::AsErrorSource;
214                let error = error.as_error_source();
215                #crate_root::GenerateImplicitData::generate_with_source(error)
216            } };
217
218            self.construct_implicit_fields_with_expression(expression)
219        }
220
221        fn construct_implicit_fields_with_expression(
222            &self,
223            expression: TokenStream,
224        ) -> TokenStream {
225            self.implicit_fields
226                .iter()
227                .chain(self.backtrace_field)
228                .map(|field| {
229                    let name = &field.name;
230                    quote! { #name: #expression, }
231                })
232                .collect()
233        }
234
235        fn generate_type(self) -> TokenStream {
236            let visibility = self.visibility;
237            let parameterized_selector_name = self.parameterized_selector_name();
238            let user_field_generics = self.user_field_generics();
239            let user_field_names = self.user_field_names();
240            let selector_doc_string = self.selector_doc_string;
241
242            let body = if user_field_names.is_empty() {
243                quote! { ; }
244            } else {
245                quote! {
246                    {
247                        #(
248                            #[allow(missing_docs)]
249                            #visibility #user_field_names: #user_field_generics
250                        ),*
251                    }
252                }
253            };
254
255            quote! {
256                #[derive(Debug, Copy, Clone)]
257                #[doc = #selector_doc_string]
258                #visibility struct #parameterized_selector_name #body
259            }
260        }
261
262        fn generate_leaf(self) -> TokenStream {
263            let error_constructor_name = self.error_constructor_name;
264            let original_generics_without_defaults = self.original_generics_without_defaults;
265            let parameterized_error_name = self.parameterized_error_name;
266            let parameterized_selector_name = self.parameterized_selector_name();
267            let user_field_generics = self.user_field_generics();
268            let visibility = self.visibility;
269            let extended_where_clauses = self.extended_where_clauses();
270            let transfer_user_fields = self.transfer_user_fields();
271            let construct_implicit_fields = self.construct_implicit_fields();
272
273            let track_caller = track_caller();
274
275            quote! {
276                impl<#(#user_field_generics,)*> #parameterized_selector_name {
277                    #[doc = "Consume the selector and return the associated error"]
278                    #[must_use]
279                    #track_caller
280                    #visibility fn build<#(#original_generics_without_defaults,)*>(self) -> #parameterized_error_name
281                    where
282                        #(#extended_where_clauses),*
283                    {
284                        #error_constructor_name {
285                            #construct_implicit_fields
286                            #(#transfer_user_fields,)*
287                        }
288                    }
289
290                    #[doc = "Consume the selector and return a `Result` with the associated error"]
291                    #track_caller
292                    #visibility fn fail<#(#original_generics_without_defaults,)* __T>(self) -> ::core::result::Result<__T, #parameterized_error_name>
293                    where
294                        #(#extended_where_clauses),*
295                    {
296                        ::core::result::Result::Err(self.build())
297                    }
298                }
299            }
300        }
301
302        fn generate_into_error(self, source_field: Option<&crate::SourceField>) -> TokenStream {
303            let crate_root = self.crate_root;
304            let error_constructor_name = self.error_constructor_name;
305            let original_generics_without_defaults = self.original_generics_without_defaults;
306            let parameterized_error_name = self.parameterized_error_name;
307            let parameterized_selector_name = self.parameterized_selector_name();
308            let user_field_generics = self.user_field_generics();
309            let extended_where_clauses = self.extended_where_clauses();
310            let transfer_user_fields = self.transfer_user_fields();
311            let construct_implicit_fields = if source_field.is_some() {
312                self.construct_implicit_fields_with_source()
313            } else {
314                self.construct_implicit_fields()
315            };
316
317            let (source_ty, transform_source, transfer_source_field) = match source_field {
318                Some(source_field) => {
319                    let SourceInfo {
320                        source_field_type,
321                        transform_source,
322                        transfer_source_field,
323                    } = build_source_info(source_field);
324                    (
325                        quote! { #source_field_type },
326                        Some(transform_source),
327                        Some(transfer_source_field),
328                    )
329                }
330                None => (quote! { #crate_root::NoneError }, None, None),
331            };
332
333            let track_caller = track_caller();
334
335            quote! {
336                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> #crate_root::IntoError<#parameterized_error_name> for #parameterized_selector_name
337                where
338                    #parameterized_error_name: #crate_root::Error + #crate_root::ErrorCompat,
339                    #(#extended_where_clauses),*
340                {
341                    type Source = #source_ty;
342
343                    #track_caller
344                    fn into_error(self, error: Self::Source) -> #parameterized_error_name {
345                        #transform_source;
346                        #error_constructor_name {
347                            #construct_implicit_fields
348                            #transfer_source_field
349                            #(#transfer_user_fields),*
350                        }
351                    }
352                }
353            }
354        }
355
356        fn generate_whatever(
357            self,
358            source_field: Option<&crate::SourceField>,
359            message_field: &crate::Field,
360        ) -> TokenStream {
361            let crate_root = self.crate_root;
362            let parameterized_error_name = self.parameterized_error_name;
363            let error_constructor_name = self.error_constructor_name;
364            let construct_implicit_fields = self.construct_implicit_fields();
365            let construct_implicit_fields_with_source =
366                self.construct_implicit_fields_with_source();
367
368            // testme: transform
369
370            let (source_ty, transfer_source_field, empty_source_field) = match source_field {
371                Some(f) => {
372                    let source_field_type = f.transformation.source_ty();
373                    let source_field_name = &f.name;
374                    let source_transformation = f.transformation.transformation();
375
376                    (
377                        quote! { #source_field_type },
378                        Some(quote! { #source_field_name: (#source_transformation)(error), }),
379                        Some(quote! { #source_field_name: core::option::Option::None, }),
380                    )
381                }
382                None => (quote! { #crate_root::NoneError }, None, None),
383            };
384
385            let message_field_name = &message_field.name;
386
387            let track_caller = track_caller();
388
389            quote! {
390                impl #crate_root::FromString for #parameterized_error_name {
391                    type Source = #source_ty;
392
393                    #track_caller
394                    fn without_source(message: String) -> Self {
395                        #error_constructor_name {
396                            #construct_implicit_fields
397                            #empty_source_field
398                            #message_field_name: message,
399                        }
400                    }
401
402                    #track_caller
403                    fn with_source(error: Self::Source, message: String) -> Self {
404                        #error_constructor_name {
405                            #construct_implicit_fields_with_source
406                            #transfer_source_field
407                            #message_field_name: message,
408                        }
409                    }
410                }
411            }
412        }
413
414        fn generate_from_source(self, source_field: &crate::SourceField) -> TokenStream {
415            let parameterized_error_name = self.parameterized_error_name;
416            let error_constructor_name = self.error_constructor_name;
417            let construct_implicit_fields_with_source =
418                self.construct_implicit_fields_with_source();
419            let original_generics_without_defaults = self.original_generics_without_defaults;
420            let user_field_generics = self.user_field_generics();
421            let where_clauses = self.where_clauses;
422
423            let SourceInfo {
424                source_field_type,
425                transform_source,
426                transfer_source_field,
427            } = build_source_info(source_field);
428
429            let track_caller = track_caller();
430
431            quote! {
432                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> ::core::convert::From<#source_field_type> for #parameterized_error_name
433                where
434                    #(#where_clauses),*
435                {
436                    #track_caller
437                    fn from(error: #source_field_type) -> Self {
438                        #transform_source;
439                        #error_constructor_name {
440                            #construct_implicit_fields_with_source
441                            #transfer_source_field
442                        }
443                    }
444                }
445            }
446        }
447    }
448
449    struct SourceInfo<'a> {
450        source_field_type: &'a syn::Type,
451        transform_source: TokenStream,
452        transfer_source_field: TokenStream,
453    }
454
455    // Assumes that the error is in a variable called "error"
456    fn build_source_info(source_field: &crate::SourceField) -> SourceInfo<'_> {
457        let source_field_name = source_field.name();
458        let source_field_type = source_field.transformation.source_ty();
459        let target_field_type = source_field.transformation.target_ty();
460        let source_transformation = source_field.transformation.transformation();
461
462        let transform_source =
463            quote! { let error: #target_field_type = (#source_transformation)(error) };
464        let transfer_source_field = quote! { #source_field_name: error, };
465
466        SourceInfo {
467            source_field_type,
468            transform_source,
469            transfer_source_field,
470        }
471    }
472
473    fn track_caller() -> proc_macro2::TokenStream {
474        if cfg!(feature = "rust_1_46") {
475            quote::quote! { #[track_caller] }
476        } else {
477            quote::quote! {}
478        }
479    }
480}
481
482pub mod display {
483    use super::StaticIdent;
484    use proc_macro2::TokenStream;
485    use quote::{quote, ToTokens};
486    use std::collections::BTreeSet;
487
488    const FORMATTER_ARG: StaticIdent = StaticIdent("__snafu_display_formatter");
489
490    pub(crate) struct Display<'a> {
491        pub(crate) arms: &'a [TokenStream],
492        pub(crate) original_generics: &'a [TokenStream],
493        pub(crate) parameterized_error_name: &'a dyn ToTokens,
494        pub(crate) where_clauses: &'a [TokenStream],
495    }
496
497    impl ToTokens for Display<'_> {
498        fn to_tokens(&self, stream: &mut TokenStream) {
499            let Self {
500                arms,
501                original_generics,
502                parameterized_error_name,
503                where_clauses,
504            } = *self;
505
506            let display_impl = quote! {
507                #[allow(single_use_lifetimes)]
508                impl<#(#original_generics),*> ::core::fmt::Display for #parameterized_error_name
509                where
510                    #(#where_clauses),*
511                {
512                    fn fmt(&self, #FORMATTER_ARG: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
513                        #[allow(unused_variables)]
514                        match *self {
515                            #(#arms),*
516                        }
517                    }
518                }
519            };
520
521            stream.extend(display_impl);
522        }
523    }
524
525    pub(crate) struct DisplayMatchArm<'a> {
526        pub(crate) field_container: &'a crate::FieldContainer,
527        pub(crate) default_name: &'a dyn ToTokens,
528        pub(crate) display_format: Option<&'a crate::Display>,
529        pub(crate) doc_comment: Option<&'a crate::DocComment>,
530        pub(crate) pattern_ident: &'a dyn ToTokens,
531        pub(crate) selector_kind: &'a crate::ContextSelectorKind,
532    }
533
534    impl ToTokens for DisplayMatchArm<'_> {
535        fn to_tokens(&self, stream: &mut TokenStream) {
536            let Self {
537                field_container,
538                default_name,
539                display_format,
540                doc_comment,
541                pattern_ident,
542                selector_kind,
543            } = *self;
544
545            let source_field = selector_kind.source_field();
546
547            let mut shorthand_names = &BTreeSet::new();
548            let mut assigned_names = &BTreeSet::new();
549
550            let format = match (display_format, doc_comment, source_field) {
551                (Some(v), _, _) => {
552                    let exprs = &v.exprs;
553                    shorthand_names = &v.shorthand_names;
554                    assigned_names = &v.assigned_names;
555                    quote! { #(#exprs),* }
556                }
557                (_, Some(d), _) => {
558                    let content = &d.content;
559                    shorthand_names = &d.shorthand_names;
560                    quote! { #content }
561                }
562                (_, _, Some(f)) => {
563                    let field_name = &f.name;
564                    quote! { concat!(stringify!(#default_name), ": {}"), #field_name }
565                }
566                _ => quote! { stringify!(#default_name)},
567            };
568
569            let field_names = super::AllFieldNames(field_container).field_names();
570
571            let shorthand_names = shorthand_names.iter().collect::<BTreeSet<_>>();
572            let assigned_names = assigned_names.iter().collect::<BTreeSet<_>>();
573
574            let shorthand_fields = &shorthand_names & &field_names;
575            let shorthand_fields = &shorthand_fields - &assigned_names;
576
577            let shorthand_assignments = quote! { #( #shorthand_fields = #shorthand_fields ),* };
578
579            let match_arm = quote! {
580                #pattern_ident { #(ref #field_names),* } => {
581                    write!(#FORMATTER_ARG, #format, #shorthand_assignments)
582                }
583            };
584
585            stream.extend(match_arm);
586        }
587    }
588}
589
590pub mod error {
591    use super::StaticIdent;
592    use crate::{FieldContainer, Provide, SourceField};
593    use proc_macro2::TokenStream;
594    use quote::{format_ident, quote, ToTokens};
595
596    pub(crate) const PROVIDE_ARG: StaticIdent = StaticIdent("__snafu_provide_demand");
597
598    pub(crate) struct Error<'a> {
599        pub(crate) crate_root: &'a dyn ToTokens,
600        pub(crate) description_arms: &'a [TokenStream],
601        pub(crate) original_generics: &'a [TokenStream],
602        pub(crate) parameterized_error_name: &'a dyn ToTokens,
603        pub(crate) provide_arms: &'a [TokenStream],
604        pub(crate) source_arms: &'a [TokenStream],
605        pub(crate) where_clauses: &'a [TokenStream],
606    }
607
608    impl ToTokens for Error<'_> {
609        fn to_tokens(&self, stream: &mut TokenStream) {
610            let Self {
611                crate_root,
612                description_arms,
613                original_generics,
614                parameterized_error_name,
615                provide_arms,
616                source_arms,
617                where_clauses,
618            } = *self;
619
620            let description_fn = quote! {
621                fn description(&self) -> &str {
622                    match *self {
623                        #(#description_arms)*
624                    }
625                }
626            };
627
628            let source_body = quote! {
629                use #crate_root::AsErrorSource;
630                match *self {
631                    #(#source_arms)*
632                }
633            };
634
635            let cause_fn = quote! {
636                fn cause(&self) -> ::core::option::Option<&dyn #crate_root::Error> {
637                    #source_body
638                }
639            };
640
641            let source_fn = quote! {
642                fn source(&self) -> ::core::option::Option<&(dyn #crate_root::Error + 'static)> {
643                    #source_body
644                }
645            };
646
647            let std_backtrace_fn = if cfg!(feature = "unstable-backtraces-impl-std") {
648                Some(quote! {
649                    fn backtrace(&self) -> ::core::option::Option<&::std::backtrace::Backtrace> {
650                        #crate_root::ErrorCompat::backtrace(self)
651                    }
652                })
653            } else {
654                None
655            };
656
657            let provide_fn = if cfg!(feature = "unstable-provider-api") {
658                Some(quote! {
659                    fn provide<'a>(&'a self, #PROVIDE_ARG: &mut core::any::Demand<'a>) {
660                        match *self {
661                            #(#provide_arms,)*
662                        };
663                    }
664                })
665            } else {
666                None
667            };
668
669            let error = quote! {
670                #[allow(single_use_lifetimes)]
671                impl<#(#original_generics),*> #crate_root::Error for #parameterized_error_name
672                where
673                    Self: ::core::fmt::Debug + ::core::fmt::Display,
674                    #(#where_clauses),*
675                {
676                    #description_fn
677                    #cause_fn
678                    #source_fn
679                    #std_backtrace_fn
680                    #provide_fn
681                }
682            };
683
684            stream.extend(error);
685        }
686    }
687
688    pub(crate) struct ErrorSourceMatchArm<'a> {
689        pub(crate) field_container: &'a FieldContainer,
690        pub(crate) pattern_ident: &'a dyn ToTokens,
691    }
692
693    impl ToTokens for ErrorSourceMatchArm<'_> {
694        fn to_tokens(&self, stream: &mut TokenStream) {
695            let Self {
696                field_container: FieldContainer { selector_kind, .. },
697                pattern_ident,
698            } = *self;
699
700            let source_field = selector_kind.source_field();
701
702            let arm = match source_field {
703                Some(source_field) => {
704                    let SourceField {
705                        name: field_name, ..
706                    } = source_field;
707
708                    let convert_to_error_source = if selector_kind.is_whatever() {
709                        quote! {
710                            #field_name.as_ref().map(|e| e.as_error_source())
711                        }
712                    } else {
713                        quote! {
714                            ::core::option::Option::Some(#field_name.as_error_source())
715                        }
716                    };
717
718                    quote! {
719                        #pattern_ident { ref #field_name, .. } => {
720                            #convert_to_error_source
721                        }
722                    }
723                }
724                None => {
725                    quote! {
726                        #pattern_ident { .. } => { ::core::option::Option::None }
727                    }
728                }
729            };
730
731            stream.extend(arm);
732        }
733    }
734
735    pub(crate) struct ProvidePlus<'a> {
736        provide: &'a Provide,
737        cached_name: proc_macro2::Ident,
738    }
739
740    pub(crate) struct ErrorProvideMatchArm<'a> {
741        pub(crate) crate_root: &'a dyn ToTokens,
742        pub(crate) field_container: &'a FieldContainer,
743        pub(crate) pattern_ident: &'a dyn ToTokens,
744    }
745
746    impl<'a> ToTokens for ErrorProvideMatchArm<'a> {
747        fn to_tokens(&self, stream: &mut TokenStream) {
748            let Self {
749                crate_root,
750                field_container,
751                pattern_ident,
752            } = *self;
753
754            let user_fields = field_container.user_fields();
755            let provides = enhance_provider_list(field_container.provides());
756            let field_names = super::AllFieldNames(field_container).field_names();
757
758            let (hi_explicit_calls, lo_explicit_calls) = build_explicit_provide_calls(&provides);
759
760            let cached_expressions = quote_cached_expressions(&provides);
761
762            let provide_refs = user_fields
763                .iter()
764                .chain(&field_container.implicit_fields)
765                .chain(field_container.selector_kind.message_field())
766                .flat_map(|f| {
767                    if f.provide {
768                        Some((&f.ty, f.name()))
769                    } else {
770                        None
771                    }
772                });
773
774            let provided_source = field_container
775                .selector_kind
776                .source_field()
777                .filter(|f| f.provide);
778
779            let source_provide_ref =
780                provided_source.map(|f| (f.transformation.source_ty(), f.name()));
781
782            let provide_refs = provide_refs.chain(source_provide_ref);
783
784            let source_chain = provided_source.map(|f| {
785                let name = f.name();
786                quote! {
787                    #name.provide(#PROVIDE_ARG);
788                }
789            });
790
791            let user_chained = quote_chained(&provides);
792
793            let shorthand_calls = provide_refs.map(|(ty, name)| {
794                quote! { #PROVIDE_ARG.provide_ref::<#ty>(#name) }
795            });
796
797            let provided_backtrace = field_container
798                .backtrace_field
799                .as_ref()
800                .filter(|f| f.provide);
801
802            let provide_backtrace = provided_backtrace.map(|f| {
803                let name = f.name();
804                quote! {
805                    if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#crate_root::Backtrace>() {
806                        if let ::core::option::Option::Some(bt) = #crate_root::AsBacktrace::as_backtrace(#name) {
807                            #PROVIDE_ARG.provide_ref::<#crate_root::Backtrace>(bt);
808                        }
809                    }
810                }
811            });
812
813            let arm = quote! {
814                #pattern_ident { #(ref #field_names,)* .. } => {
815                    #(#cached_expressions;)*
816                    #(#hi_explicit_calls;)*
817                    #source_chain;
818                    #(#user_chained;)*
819                    #provide_backtrace;
820                    #(#shorthand_calls;)*
821                    #(#lo_explicit_calls;)*
822                }
823            };
824
825            stream.extend(arm);
826        }
827    }
828
829    pub(crate) fn enhance_provider_list<'a>(provides: &'a [Provide]) -> Vec<ProvidePlus<'a>> {
830        provides
831            .iter()
832            .enumerate()
833            .map(|(i, provide)| {
834                let cached_name = format_ident!("__snafu_cached_expr_{}", i);
835                ProvidePlus {
836                    provide,
837                    cached_name,
838                }
839            })
840            .collect()
841    }
842
843    pub(crate) fn quote_cached_expressions<'a>(
844        provides: &'a [ProvidePlus<'a>],
845    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
846        provides.iter().filter(|pp| pp.provide.is_chain).map(|pp| {
847            let cached_name = &pp.cached_name;
848            let expr = &pp.provide.expr;
849
850            quote! {
851                let #cached_name = #expr;
852            }
853        })
854    }
855
856    pub(crate) fn quote_chained<'a>(
857        provides: &'a [ProvidePlus<'a>],
858    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
859        provides.iter().filter(|pp| pp.provide.is_chain).map(|pp| {
860            let arm = if pp.provide.is_opt {
861                quote! { ::core::option::Option::Some(chained_item) }
862            } else {
863                quote! { chained_item }
864            };
865            let cached_name = &pp.cached_name;
866
867            quote! {
868                if let #arm = #cached_name {
869                    ::core::any::Provider::provide(chained_item, #PROVIDE_ARG);
870                }
871            }
872        })
873    }
874
875    fn quote_provides<'a, I>(provides: I) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a
876    where
877        I: IntoIterator<Item = &'a ProvidePlus<'a>>,
878        I::IntoIter: 'a,
879    {
880        provides.into_iter().map(|pp| {
881            let ProvidePlus {
882                provide:
883                    Provide {
884                        is_chain,
885                        is_opt,
886                        is_priority: _,
887                        is_ref,
888                        ty,
889                        expr,
890                    },
891                cached_name,
892            } = pp;
893
894            let effective_expr = if *is_chain {
895                quote! { #cached_name }
896            } else {
897                quote! { #expr }
898            };
899
900            match (is_opt, is_ref) {
901                (true, true) => {
902                    quote! {
903                        if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#ty>() {
904                            if let ::core::option::Option::Some(v) = #effective_expr {
905                                #PROVIDE_ARG.provide_ref::<#ty>(v);
906                            }
907                        }
908                    }
909                }
910                (true, false) => {
911                    quote! {
912                        if #PROVIDE_ARG.would_be_satisfied_by_value_of::<#ty>() {
913                            if let ::core::option::Option::Some(v) = #effective_expr {
914                                #PROVIDE_ARG.provide_value::<#ty>(v);
915                            }
916                        }
917                    }
918                }
919                (false, true) => {
920                    quote! { #PROVIDE_ARG.provide_ref_with::<#ty>(|| #effective_expr) }
921                }
922                (false, false) => {
923                    quote! { #PROVIDE_ARG.provide_value_with::<#ty>(|| #effective_expr) }
924                }
925            }
926        })
927    }
928
929    pub(crate) fn build_explicit_provide_calls<'a>(
930        provides: &'a [ProvidePlus<'a>],
931    ) -> (
932        impl Iterator<Item = TokenStream> + 'a,
933        impl Iterator<Item = TokenStream> + 'a,
934    ) {
935        let (high_priority, low_priority): (Vec<_>, Vec<_>) =
936            provides.iter().partition(|pp| pp.provide.is_priority);
937
938        let hi_explicit_calls = quote_provides(high_priority);
939        let lo_explicit_calls = quote_provides(low_priority);
940
941        (hi_explicit_calls, lo_explicit_calls)
942    }
943}
944
945pub mod error_compat {
946    use crate::{Field, FieldContainer, SourceField};
947    use proc_macro2::TokenStream;
948    use quote::{quote, ToTokens};
949
950    pub(crate) struct ErrorCompat<'a> {
951        pub(crate) crate_root: &'a dyn ToTokens,
952        pub(crate) parameterized_error_name: &'a dyn ToTokens,
953        pub(crate) backtrace_arms: &'a [TokenStream],
954        pub(crate) original_generics: &'a [TokenStream],
955        pub(crate) where_clauses: &'a [TokenStream],
956    }
957
958    impl ToTokens for ErrorCompat<'_> {
959        fn to_tokens(&self, stream: &mut TokenStream) {
960            let Self {
961                crate_root,
962                parameterized_error_name,
963                backtrace_arms,
964                original_generics,
965                where_clauses,
966            } = *self;
967
968            let backtrace_fn = quote! {
969                fn backtrace(&self) -> ::core::option::Option<&#crate_root::Backtrace> {
970                    match *self {
971                        #(#backtrace_arms),*
972                    }
973                }
974            };
975
976            let error_compat_impl = quote! {
977                #[allow(single_use_lifetimes)]
978                impl<#(#original_generics),*> #crate_root::ErrorCompat for #parameterized_error_name
979                where
980                    #(#where_clauses),*
981                {
982                    #backtrace_fn
983                }
984            };
985
986            stream.extend(error_compat_impl);
987        }
988    }
989
990    pub(crate) struct ErrorCompatBacktraceMatchArm<'a> {
991        pub(crate) crate_root: &'a dyn ToTokens,
992        pub(crate) field_container: &'a FieldContainer,
993        pub(crate) pattern_ident: &'a dyn ToTokens,
994    }
995
996    impl ToTokens for ErrorCompatBacktraceMatchArm<'_> {
997        fn to_tokens(&self, stream: &mut TokenStream) {
998            let Self {
999                crate_root,
1000                field_container:
1001                    FieldContainer {
1002                        backtrace_field,
1003                        selector_kind,
1004                        ..
1005                    },
1006                pattern_ident,
1007            } = *self;
1008
1009            let match_arm = match (selector_kind.source_field(), backtrace_field) {
1010                (Some(source_field), _) if source_field.backtrace_delegate => {
1011                    let SourceField {
1012                        name: field_name, ..
1013                    } = source_field;
1014                    quote! {
1015                        #pattern_ident { ref #field_name, .. } => { #crate_root::ErrorCompat::backtrace(#field_name) }
1016                    }
1017                }
1018                (_, Some(backtrace_field)) => {
1019                    let Field {
1020                        name: field_name, ..
1021                    } = backtrace_field;
1022                    quote! {
1023                        #pattern_ident { ref #field_name, .. } => { #crate_root::AsBacktrace::as_backtrace(#field_name) }
1024                    }
1025                }
1026                _ => {
1027                    quote! {
1028                        #pattern_ident { .. } => { ::core::option::Option::None }
1029                    }
1030                }
1031            };
1032
1033            stream.extend(match_arm);
1034        }
1035    }
1036}