1#![deny(clippy::pedantic)]
2#![cfg_attr(feature = "diagnostics", feature(proc_macro_diagnostic))]
3
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{
11 Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Lit, Token, Variant,
12 parse_macro_input,
13};
14
15fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
16 let is_default = variant
17 .attrs
18 .iter()
19 .any(|attr| attr.path().is_ident("encodable_default"));
20
21 if is_default {
22 assert_eq!(
23 variant.ident.to_string(),
24 "Default",
25 "Default variant should be called `Default`"
26 );
27 let two_fields = variant.fields.len() == 2;
28 let field_names = variant
29 .fields
30 .iter()
31 .filter_map(|field| field.ident.as_ref().map(ToString::to_string))
32 .sorted()
33 .collect::<Vec<_>>();
34 let correct_fields = field_names == vec!["bytes".to_string(), "variant".to_string()];
35
36 assert!(
37 two_fields && correct_fields,
38 "The default variant should have exactly two field: `variant: u64` and `bytes: Vec<u8>`"
39 );
40 }
41
42 is_default
43}
44
45#[proc_macro_derive(Encodable, attributes(encodable_default, encodable))]
48pub fn derive_encodable(input: TokenStream) -> TokenStream {
49 let DeriveInput {
50 ident,
51 data,
52 generics,
53 ..
54 } = parse_macro_input!(input);
55
56 let encode_inner = match data {
57 Data::Struct(DataStruct { fields, .. }) => derive_struct_encode(&fields),
58 Data::Enum(DataEnum { variants, .. }) => derive_enum_encode(&ident, &variants),
59 Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
60 };
61 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
62
63 let output = quote! {
64 impl #impl_generics ::fedimint_core::encoding::Encodable for #ident #ty_generics #where_clause {
65 #[allow(deprecated)]
66 fn consensus_encode<W: std::io::Write>(&self, mut writer: &mut W) -> std::result::Result<(), std::io::Error> {
67 #encode_inner
68 }
69 }
70 };
71
72 output.into()
73}
74
75fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
76 if is_tuple_struct(fields) {
77 let field_names = fields
79 .iter()
80 .enumerate()
81 .map(|(idx, _)| Index::from(idx))
82 .collect::<Vec<_>>();
83 quote! {
84 #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
85 Ok(())
86 }
87 } else {
88 let field_names = fields
90 .iter()
91 .map(|field| field.ident.clone().unwrap())
92 .collect::<Vec<_>>();
93 quote! {
94 #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
95 Ok(())
96 }
97 }
98}
99
100fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
103 attributes.iter().find_map(|attr| {
104 if attr.path().is_ident("encodable") {
105 attr.parse_args_with(|input: syn::parse::ParseStream| {
106 input.parse::<syn::Ident>()?.span(); input.parse::<Token![=]>()?; if let Lit::Int(lit_int) = input.parse::<Lit>()? {
109 lit_int.base10_parse()
110 } else {
111 Err(input.error("Expected an integer for 'index'"))
112 }
113 })
114 .ok()
115 } else {
116 None
117 }
118 })
119}
120
121fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
124 input_variants
125 .into_iter()
126 .map(|variant| {
127 let index = parse_index_attribute(&variant.attrs);
128 (index, variant)
129 })
130 .collect()
131}
132
133fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
134 let non_default_variants = variants
135 .into_iter()
136 .filter(|variant| !is_default_variant_enforce_valid(variant))
137 .cloned()
138 .collect::<Vec<_>>();
139
140 let attr_indices = extract_variants_with_indices(non_default_variants.clone());
141
142 let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
143 let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());
144
145 assert!(
146 all_have_index || none_have_index,
147 "Either all or none of the variants should have an index annotation"
148 );
149
150 if all_have_index {
151 attr_indices
152 .into_iter()
153 .map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
154 .collect()
155 } else {
156 non_default_variants
157 .into_iter()
158 .enumerate()
159 .map(|(idx, variant)| (idx as u64, variant))
160 .collect()
161 }
162}
163
164fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
165 if variants.is_empty() {
166 return quote! {
167 match *self {}
168 };
169 }
170
171 let non_default_match_arms =
172 non_default_variant_indices(variants)
173 .into_iter()
174 .map(|(variant_idx, variant)| {
175 let variant_ident = variant.ident.clone();
176
177 if is_tuple_struct(&variant.fields) {
178 let variant_fields = variant
179 .fields
180 .iter()
181 .enumerate()
182 .map(|(idx, _)| format_ident!("bound_{}", idx))
183 .collect::<Vec<_>>();
184 let variant_encode_block =
185 derive_enum_variant_encode_block(variant_idx, &variant_fields);
186 quote! {
187 #ident::#variant_ident(#(#variant_fields,)*) => {
188 #variant_encode_block
189 }
190 }
191 } else {
192 let variant_fields = variant
193 .fields
194 .iter()
195 .map(|field| field.ident.clone().unwrap())
196 .collect::<Vec<_>>();
197 let variant_encode_block =
198 derive_enum_variant_encode_block(variant_idx, &variant_fields);
199 quote! {
200 #ident::#variant_ident { #(#variant_fields,)*} => {
201 #variant_encode_block
202 }
203 }
204 }
205 });
206
207 let default_match_arm = variants
208 .iter()
209 .find(|variant| is_default_variant_enforce_valid(variant))
210 .map(|_variant| {
211 quote! {
212 #ident::Default { variant, bytes } => {
213 ::fedimint_core::encoding::Encodable::consensus_encode(variant, writer)?;
214 ::fedimint_core::encoding::Encodable::consensus_encode(bytes, writer)?;
215 }
216 }
217 });
218
219 let match_arms = non_default_match_arms.chain(default_match_arm);
220
221 quote! {
222 match self {
223 #(#match_arms)*
224 }
225 Ok(())
226 }
227}
228
229fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
230 quote! {
231 ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;
232
233 let mut bytes = Vec::<u8>::new();
234 #(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
235
236 ::fedimint_core::encoding::Encodable::consensus_encode(&bytes, writer)?;
237 }
238}
239
240#[proc_macro_derive(Decodable)]
241pub fn derive_decodable(input: TokenStream) -> TokenStream {
242 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
243
244 let decode_inner = match data {
245 Data::Struct(DataStruct { fields, .. }) => derive_struct_decode(&ident, &fields),
246 syn::Data::Enum(DataEnum { variants, .. }) => derive_enum_decode(&ident, &variants),
247 syn::Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
248 };
249
250 let output = quote! {
251 #[allow(deprecated)]
252 impl ::fedimint_core::encoding::Decodable for #ident {
253 fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(d: &mut D, modules: &::fedimint_core::module::registry::ModuleDecoderRegistry) -> std::result::Result<Self, ::fedimint_core::encoding::DecodeError> {
254 use ::fedimint_core:: anyhow::Context;
255 #decode_inner
256 }
257 }
258 };
259
260 output.into()
261}
262
263#[allow(unused_variables, unreachable_code)]
264fn error(ident: &Ident, message: &str) -> TokenStream2 {
265 #[cfg(feature = "diagnostics")]
266 ident.span().unstable().error(message).emit();
267 #[cfg(not(feature = "diagnostics"))]
268 panic!("{message}");
269
270 TokenStream2::new()
271}
272
273fn derive_struct_decode(ident: &Ident, fields: &Fields) -> TokenStream2 {
274 let decode_block =
275 derive_tuple_or_named_decode_block(ident, "e! { #ident }, "e! { d }, fields);
276
277 quote! {
278 Ok(#decode_block)
279 }
280}
281
282fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
283 if variants.is_empty() {
284 return quote! {
285 Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Enum without variants can't be instantiated")))
286 };
287 }
288
289 let non_default_match_arms = non_default_variant_indices(variants).into_iter()
290 .map(|(variant_idx, variant)| {
291 let variant_ident = variant.ident.clone();
292 let decode_block = derive_tuple_or_named_decode_block(
293 ident,
294 "e! { #ident::#variant_ident },
295 "e! { &mut cursor },
296 &variant.fields,
297 );
298
299 quote! {
301 #variant_idx => {
302 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
304 .context(concat!(
305 "Decoding bytes of ",
306 stringify!(#ident)
307 ))?;
308 let mut cursor = std::io::Cursor::new(&bytes);
309
310 let decoded = anyhow::Context::context(
311 (|| -> anyhow::Result<_> {
312 Ok(#decode_block)
313 })(), concat!("Decoding variant ", stringify!(#variant_ident), " (idx: ", #variant_idx, ")"))?;
314
315 let read_bytes = cursor.position();
316 let total_bytes = bytes.len() as u64;
317 if read_bytes != total_bytes {
318 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!(
319 "Partial read: got {total_bytes} bytes but only read {read_bytes} when decoding {}",
320 concat!(
321 stringify!(#ident),
322 "::",
323 stringify!(#variant)
324 )
325 )));
326 }
327
328 decoded
329 }
330 }
331 });
332
333 let default_match_arm = if variants.iter().any(is_default_variant_enforce_valid) {
334 quote! {
335 variant => {
336 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
337 .context(concat!(
338 "Decoding default variant of ",
339 stringify!(#ident)
340 ))?;
341
342 #ident::Default {
343 variant,
344 bytes
345 }
346 }
347 }
348 } else {
349 quote! {
350 variant => {
351 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Invalid enum variant {} while decoding {}", variant, stringify!(#ident))));
352 }
353 }
354 };
355
356 quote! {
357 let variant = <u64 as ::fedimint_core::encoding::Decodable>::consensus_decode_partial_from_finite_reader(d, modules)
358 .context(concat!(
359 "Decoding default variant of ",
360 stringify!(#ident)
361 ))?;
362
363 let decoded = match variant {
364 #(#non_default_match_arms)*
365 #default_match_arm
366 };
367 Ok(decoded)
368 }
369}
370
371fn is_tuple_struct(fields: &Fields) -> bool {
372 fields.iter().any(|field| field.ident.is_none())
373}
374
375fn derive_tuple_or_named_decode_block(
380 ident: &Ident,
381 constructor: &TokenStream2,
382 reader: &TokenStream2,
383 fields: &Fields,
384) -> TokenStream2 {
385 if is_tuple_struct(fields) {
386 derive_tuple_decode_block(ident, constructor, reader, fields)
387 } else {
388 derive_named_decode_block(ident, constructor, reader, fields)
389 }
390}
391
392fn derive_tuple_decode_block(
393 ident: &Ident,
394 constructor: &TokenStream2,
395 reader: &TokenStream2,
396 fields: &Fields,
397) -> TokenStream2 {
398 let field_names = fields
399 .iter()
400 .enumerate()
401 .map(|(idx, _)| format_ident!("field_{}", idx))
402 .collect::<Vec<_>>();
403 quote! {
404 {
405 #(
406 let #field_names = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
407 .context(concat!(
408 "Decoding tuple block ",
409 stringify!(#ident),
410 " field ",
411 stringify!(#field_names),
412 ))?;
413 )*
414 #constructor(#(#field_names,)*)
415 }
416 }
417}
418
419fn derive_named_decode_block(
420 ident: &Ident,
421 constructor: &TokenStream2,
422 reader: &TokenStream2,
423 fields: &Fields,
424) -> TokenStream2 {
425 let variant_fields = fields
426 .iter()
427 .map(|field| field.ident.clone().unwrap())
428 .collect::<Vec<_>>();
429 quote! {
430 {
431 #(
432 let #variant_fields = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
433 .context(concat!(
434 "Decoding named block field: ",
435 stringify!(#ident),
436 "{ ... ",
437 stringify!(#variant_fields),
438 " ... }",
439 ))?;
440 )*
441 #constructor{
442 #(#variant_fields,)*
443 }
444 }
445 }
446}