fedimint_derive/
lib.rs

1#![deny(clippy::pedantic)]
2#![cfg_attr(feature = "diagnostics", feature(proc_macro_diagnostic))]
3
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{
11    Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Lit, Token, Variant,
12    parse_macro_input,
13};
14
15fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
16    let is_default = variant
17        .attrs
18        .iter()
19        .any(|attr| attr.path().is_ident("encodable_default"));
20
21    if is_default {
22        assert_eq!(
23            variant.ident.to_string(),
24            "Default",
25            "Default variant should be called `Default`"
26        );
27        let two_fields = variant.fields.len() == 2;
28        let field_names = variant
29            .fields
30            .iter()
31            .filter_map(|field| field.ident.as_ref().map(ToString::to_string))
32            .sorted()
33            .collect::<Vec<_>>();
34        let correct_fields = field_names == vec!["bytes".to_string(), "variant".to_string()];
35
36        assert!(
37            two_fields && correct_fields,
38            "The default variant should have exactly two field: `variant: u64` and `bytes: Vec<u8>`"
39        );
40    }
41
42    is_default
43}
44
45// TODO: use encodable attr for everything: #[encodable(index = 42)],
46// #[encodable(default)], …
47#[proc_macro_derive(Encodable, attributes(encodable_default, encodable))]
48pub fn derive_encodable(input: TokenStream) -> TokenStream {
49    let DeriveInput {
50        ident,
51        data,
52        generics,
53        ..
54    } = parse_macro_input!(input);
55
56    let encode_inner = match data {
57        Data::Struct(DataStruct { fields, .. }) => derive_struct_encode(&fields),
58        Data::Enum(DataEnum { variants, .. }) => derive_enum_encode(&ident, &variants),
59        Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
60    };
61    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
62
63    let output = quote! {
64        impl #impl_generics ::fedimint_core::encoding::Encodable for #ident #ty_generics #where_clause {
65            #[allow(deprecated)]
66            fn consensus_encode<W: std::io::Write>(&self, mut writer: &mut W) -> std::result::Result<(), std::io::Error> {
67                #encode_inner
68            }
69        }
70    };
71
72    output.into()
73}
74
75fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
76    if is_tuple_struct(fields) {
77        // Tuple struct
78        let field_names = fields
79            .iter()
80            .enumerate()
81            .map(|(idx, _)| Index::from(idx))
82            .collect::<Vec<_>>();
83        quote! {
84            #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
85            Ok(())
86        }
87    } else {
88        // Named struct
89        let field_names = fields
90            .iter()
91            .map(|field| field.ident.clone().unwrap())
92            .collect::<Vec<_>>();
93        quote! {
94            #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
95            Ok(())
96        }
97    }
98}
99
100/// Extracts the u64 index from an attribute if it matches `#[encodable(index =
101/// <u64>)]`.
102fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
103    attributes.iter().find_map(|attr| {
104        if attr.path().is_ident("encodable") {
105            attr.parse_args_with(|input: syn::parse::ParseStream| {
106                input.parse::<syn::Ident>()?.span(); // consume the ident 'index'
107                input.parse::<Token![=]>()?; // consume the '='
108                if let Lit::Int(lit_int) = input.parse::<Lit>()? {
109                    lit_int.base10_parse()
110                } else {
111                    Err(input.error("Expected an integer for 'index'"))
112                }
113            })
114            .ok()
115        } else {
116            None
117        }
118    })
119}
120
121/// Processes all variants in a `Punctuated` list extracting any specified
122/// index.
123fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
124    input_variants
125        .into_iter()
126        .map(|variant| {
127            let index = parse_index_attribute(&variant.attrs);
128            (index, variant)
129        })
130        .collect()
131}
132
133fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
134    let non_default_variants = variants
135        .into_iter()
136        .filter(|variant| !is_default_variant_enforce_valid(variant))
137        .cloned()
138        .collect::<Vec<_>>();
139
140    let attr_indices = extract_variants_with_indices(non_default_variants.clone());
141
142    let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
143    let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());
144
145    assert!(
146        all_have_index || none_have_index,
147        "Either all or none of the variants should have an index annotation"
148    );
149
150    if all_have_index {
151        attr_indices
152            .into_iter()
153            .map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
154            .collect()
155    } else {
156        non_default_variants
157            .into_iter()
158            .enumerate()
159            .map(|(idx, variant)| (idx as u64, variant))
160            .collect()
161    }
162}
163
164fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
165    if variants.is_empty() {
166        return quote! {
167            match *self {}
168        };
169    }
170
171    let non_default_match_arms =
172        non_default_variant_indices(variants)
173            .into_iter()
174            .map(|(variant_idx, variant)| {
175                let variant_ident = variant.ident.clone();
176
177                if is_tuple_struct(&variant.fields) {
178                    let variant_fields = variant
179                        .fields
180                        .iter()
181                        .enumerate()
182                        .map(|(idx, _)| format_ident!("bound_{}", idx))
183                        .collect::<Vec<_>>();
184                    let variant_encode_block =
185                        derive_enum_variant_encode_block(variant_idx, &variant_fields);
186                    quote! {
187                        #ident::#variant_ident(#(#variant_fields,)*) => {
188                            #variant_encode_block
189                        }
190                    }
191                } else {
192                    let variant_fields = variant
193                        .fields
194                        .iter()
195                        .map(|field| field.ident.clone().unwrap())
196                        .collect::<Vec<_>>();
197                    let variant_encode_block =
198                        derive_enum_variant_encode_block(variant_idx, &variant_fields);
199                    quote! {
200                        #ident::#variant_ident { #(#variant_fields,)*} => {
201                            #variant_encode_block
202                        }
203                    }
204                }
205            });
206
207    let default_match_arm = variants
208        .iter()
209        .find(|variant| is_default_variant_enforce_valid(variant))
210        .map(|_variant| {
211            quote! {
212                #ident::Default { variant, bytes } => {
213                    ::fedimint_core::encoding::Encodable::consensus_encode(variant, writer)?;
214                    ::fedimint_core::encoding::Encodable::consensus_encode(bytes, writer)?;
215                }
216            }
217        });
218
219    let match_arms = non_default_match_arms.chain(default_match_arm);
220
221    quote! {
222        match self {
223            #(#match_arms)*
224        }
225        Ok(())
226    }
227}
228
229fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
230    quote! {
231        ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;
232
233        let mut bytes = Vec::<u8>::new();
234        #(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
235
236        ::fedimint_core::encoding::Encodable::consensus_encode(&bytes, writer)?;
237    }
238}
239
240#[proc_macro_derive(Decodable)]
241pub fn derive_decodable(input: TokenStream) -> TokenStream {
242    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
243
244    let decode_inner = match data {
245        Data::Struct(DataStruct { fields, .. }) => derive_struct_decode(&ident, &fields),
246        syn::Data::Enum(DataEnum { variants, .. }) => derive_enum_decode(&ident, &variants),
247        syn::Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
248    };
249
250    let output = quote! {
251        #[allow(deprecated)]
252        impl ::fedimint_core::encoding::Decodable for #ident {
253            fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(d: &mut D, modules: &::fedimint_core::module::registry::ModuleDecoderRegistry) -> std::result::Result<Self, ::fedimint_core::encoding::DecodeError> {
254                use ::fedimint_core:: anyhow::Context;
255                #decode_inner
256            }
257        }
258    };
259
260    output.into()
261}
262
263#[allow(unused_variables, unreachable_code)]
264fn error(ident: &Ident, message: &str) -> TokenStream2 {
265    #[cfg(feature = "diagnostics")]
266    ident.span().unstable().error(message).emit();
267    #[cfg(not(feature = "diagnostics"))]
268    panic!("{message}");
269
270    TokenStream2::new()
271}
272
273fn derive_struct_decode(ident: &Ident, fields: &Fields) -> TokenStream2 {
274    let decode_block =
275        derive_tuple_or_named_decode_block(ident, &quote! { #ident }, &quote! { d }, fields);
276
277    quote! {
278        Ok(#decode_block)
279    }
280}
281
282fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
283    if variants.is_empty() {
284        return quote! {
285            Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Enum without variants can't be instantiated")))
286        };
287    }
288
289    let non_default_match_arms = non_default_variant_indices(variants).into_iter()
290        .map(|(variant_idx, variant)| {
291            let variant_ident = variant.ident.clone();
292            let decode_block = derive_tuple_or_named_decode_block(
293                ident,
294                &quote! { #ident::#variant_ident },
295                &quote! { &mut cursor },
296                &variant.fields,
297            );
298
299            // FIXME: make sure we read all bytes
300            quote! {
301                #variant_idx => {
302                    // FIXME: feels like there's a way more elegant way to do this with limited readers
303                    let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
304                        .context(concat!(
305                            "Decoding bytes of ",
306                            stringify!(#ident)
307                        ))?;
308                    let mut cursor = std::io::Cursor::new(&bytes);
309
310                    let decoded = anyhow::Context::context(
311                        (|| -> anyhow::Result<_> {
312                          Ok(#decode_block)
313                        })(), concat!("Decoding variant ", stringify!(#variant_ident), " (idx: ", #variant_idx, ")"))?;
314
315                    let read_bytes = cursor.position();
316                    let total_bytes = bytes.len() as u64;
317                    if read_bytes != total_bytes {
318                        return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!(
319                            "Partial read: got {total_bytes} bytes but only read {read_bytes} when decoding {}",
320                            concat!(
321                                stringify!(#ident),
322                                "::",
323                                stringify!(#variant)
324                            )
325                        )));
326                    }
327
328                    decoded
329                }
330            }
331        });
332
333    let default_match_arm = if variants.iter().any(is_default_variant_enforce_valid) {
334        quote! {
335            variant => {
336                let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
337                    .context(concat!(
338                        "Decoding default variant of ",
339                        stringify!(#ident)
340                    ))?;
341
342                #ident::Default {
343                    variant,
344                    bytes
345                }
346            }
347        }
348    } else {
349        quote! {
350            variant => {
351                return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Invalid enum variant {} while decoding {}", variant, stringify!(#ident))));
352            }
353        }
354    };
355
356    quote! {
357        let variant = <u64 as ::fedimint_core::encoding::Decodable>::consensus_decode_partial_from_finite_reader(d, modules)
358            .context(concat!(
359                "Decoding default variant of ",
360                stringify!(#ident)
361            ))?;
362
363        let decoded = match variant {
364            #(#non_default_match_arms)*
365            #default_match_arm
366        };
367        Ok(decoded)
368    }
369}
370
371fn is_tuple_struct(fields: &Fields) -> bool {
372    fields.iter().any(|field| field.ident.is_none())
373}
374
375// TODO: how not to use token stream for constructor, but still support both:
376//   * Enum::Variant
377//   * Struct
378// as idents
379fn derive_tuple_or_named_decode_block(
380    ident: &Ident,
381    constructor: &TokenStream2,
382    reader: &TokenStream2,
383    fields: &Fields,
384) -> TokenStream2 {
385    if is_tuple_struct(fields) {
386        derive_tuple_decode_block(ident, constructor, reader, fields)
387    } else {
388        derive_named_decode_block(ident, constructor, reader, fields)
389    }
390}
391
392fn derive_tuple_decode_block(
393    ident: &Ident,
394    constructor: &TokenStream2,
395    reader: &TokenStream2,
396    fields: &Fields,
397) -> TokenStream2 {
398    let field_names = fields
399        .iter()
400        .enumerate()
401        .map(|(idx, _)| format_ident!("field_{}", idx))
402        .collect::<Vec<_>>();
403    quote! {
404        {
405            #(
406                let #field_names = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
407                    .context(concat!(
408                        "Decoding tuple block ",
409                        stringify!(#ident),
410                        " field ",
411                        stringify!(#field_names),
412                    ))?;
413            )*
414            #constructor(#(#field_names,)*)
415        }
416    }
417}
418
419fn derive_named_decode_block(
420    ident: &Ident,
421    constructor: &TokenStream2,
422    reader: &TokenStream2,
423    fields: &Fields,
424) -> TokenStream2 {
425    let variant_fields = fields
426        .iter()
427        .map(|field| field.ident.clone().unwrap())
428        .collect::<Vec<_>>();
429    quote! {
430        {
431            #(
432                let #variant_fields = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
433                    .context(concat!(
434                        "Decoding named block field: ",
435                        stringify!(#ident),
436                        "{ ... ",
437                        stringify!(#variant_fields),
438                        " ... }",
439                    ))?;
440            )*
441            #constructor{
442                #(#variant_fields,)*
443            }
444        }
445    }
446}