1#![deny(clippy::pedantic)]
2#![cfg_attr(feature = "diagnostics", feature(proc_macro_diagnostic))]
3
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{
11 Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Lit, Token, Variant,
12 parse_macro_input,
13};
14
15fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
16 let is_default = variant
17 .attrs
18 .iter()
19 .any(|attr| attr.path().is_ident("encodable_default"));
20
21 if is_default {
22 assert_eq!(
23 variant.ident.to_string(),
24 "Default",
25 "Default variant should be called `Default`"
26 );
27 let two_fields = variant.fields.len() == 2;
28 let field_names = variant
29 .fields
30 .iter()
31 .filter_map(|field| field.ident.as_ref().map(ToString::to_string))
32 .sorted()
33 .collect::<Vec<_>>();
34 let correct_fields = field_names == vec!["bytes".to_string(), "variant".to_string()];
35
36 assert!(
37 two_fields && correct_fields,
38 "The default variant should have exactly two field: `variant: u64` and `bytes: Vec<u8>`"
39 );
40 }
41
42 is_default
43}
44
45#[proc_macro_derive(Encodable, attributes(encodable_default, encodable))]
48pub fn derive_encodable(input: TokenStream) -> TokenStream {
49 let DeriveInput {
50 ident,
51 data,
52 generics,
53 ..
54 } = parse_macro_input!(input);
55
56 let encode_inner = match data {
57 Data::Struct(DataStruct { fields, .. }) => derive_struct_encode(&fields),
58 Data::Enum(DataEnum { variants, .. }) => derive_enum_encode(&ident, &variants),
59 Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
60 };
61 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
62
63 let output = quote! {
64 impl #impl_generics ::fedimint_core::encoding::Encodable for #ident #ty_generics #where_clause {
65 #[allow(deprecated)]
66 fn consensus_encode<W: std::io::Write>(&self, mut writer: &mut W) -> std::result::Result<(), std::io::Error> {
67 #encode_inner
68 }
69 }
70 };
71
72 output.into()
73}
74
75fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
76 if is_tuple_struct(fields) {
77 let field_names = fields
79 .iter()
80 .enumerate()
81 .map(|(idx, _)| Index::from(idx))
82 .collect::<Vec<_>>();
83 quote! {
84 #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
85 Ok(())
86 }
87 } else {
88 let field_names = fields
90 .iter()
91 .map(|field| field.ident.clone().unwrap())
92 .collect::<Vec<_>>();
93 quote! {
94 #(::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
95 Ok(())
96 }
97 }
98}
99
100fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
103 attributes.iter().find_map(|attr| {
104 if attr.path().is_ident("encodable") {
105 attr.parse_args_with(|input: syn::parse::ParseStream| {
106 input.parse::<syn::Ident>()?.span(); input.parse::<Token![=]>()?; if let Lit::Int(lit_int) = input.parse::<Lit>()? {
109 lit_int.base10_parse()
110 } else {
111 Err(input.error("Expected an integer for 'index'"))
112 }
113 })
114 .ok()
115 } else {
116 None
117 }
118 })
119}
120
121fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
124 input_variants
125 .into_iter()
126 .map(|variant| {
127 let index = parse_index_attribute(&variant.attrs);
128 (index, variant)
129 })
130 .collect()
131}
132
133fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
134 let non_default_variants = variants
135 .into_iter()
136 .filter(|variant| !is_default_variant_enforce_valid(variant))
137 .cloned()
138 .collect::<Vec<_>>();
139
140 let attr_indices = extract_variants_with_indices(non_default_variants.clone());
141
142 let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
143 let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());
144
145 assert!(
146 all_have_index || none_have_index,
147 "Either all or none of the variants should have an index annotation"
148 );
149
150 if all_have_index {
151 attr_indices
152 .into_iter()
153 .map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
154 .collect()
155 } else {
156 non_default_variants
157 .into_iter()
158 .enumerate()
159 .map(|(idx, variant)| (idx as u64, variant))
160 .collect()
161 }
162}
163
164fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
165 if variants.is_empty() {
166 return quote! {
167 match *self {}
168 };
169 }
170
171 let non_default_match_arms =
172 non_default_variant_indices(variants)
173 .into_iter()
174 .map(|(variant_idx, variant)| {
175 let variant_ident = variant.ident.clone();
176
177 if is_tuple_struct(&variant.fields) {
178 let variant_fields = variant
179 .fields
180 .iter()
181 .enumerate()
182 .map(|(idx, _)| format_ident!("bound_{}", idx))
183 .collect::<Vec<_>>();
184 let variant_encode_block =
185 derive_enum_variant_encode_block(variant_idx, &variant_fields);
186 quote! {
187 #ident::#variant_ident(#(#variant_fields,)*) => {
188 #variant_encode_block
189 }
190 }
191 } else {
192 let variant_fields = variant
193 .fields
194 .iter()
195 .map(|field| field.ident.clone().unwrap())
196 .collect::<Vec<_>>();
197 let variant_encode_block =
198 derive_enum_variant_encode_block(variant_idx, &variant_fields);
199 quote! {
200 #ident::#variant_ident { #(#variant_fields,)*} => {
201 #variant_encode_block
202 }
203 }
204 }
205 });
206
207 let default_match_arm = variants
208 .iter()
209 .find(|variant| is_default_variant_enforce_valid(variant))
210 .map(|_variant| {
211 quote! {
212 #ident::Default { variant, bytes } => {
213 ::fedimint_core::encoding::Encodable::consensus_encode(variant, writer)?;
214 ::fedimint_core::encoding::Encodable::consensus_encode(bytes, writer)?;
215 }
216 }
217 });
218
219 let match_arms = non_default_match_arms.chain(default_match_arm);
220
221 quote! {
222 match self {
223 #(#match_arms)*
224 }
225 Ok(())
226 }
227}
228
229fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
230 quote! {
231 ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;
232
233 let mut bytes = Vec::<u8>::new();
234 #(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
235
236 ::fedimint_core::encoding::Encodable::consensus_encode(&bytes, writer)?;
237 }
238}
239
240#[proc_macro_derive(Decodable)]
241pub fn derive_decodable(input: TokenStream) -> TokenStream {
242 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
243
244 let decode_inner = match data {
245 Data::Struct(DataStruct { fields, .. }) => derive_struct_decode(&ident, &fields),
246 syn::Data::Enum(DataEnum { variants, .. }) => derive_enum_decode(&ident, &variants),
247 syn::Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
248 };
249
250 let output = quote! {
251 #[allow(deprecated)]
252 impl ::fedimint_core::encoding::Decodable for #ident {
253 fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(d: &mut D, modules: &::fedimint_core::module::registry::ModuleDecoderRegistry) -> std::result::Result<Self, ::fedimint_core::encoding::DecodeError> {
254 use ::fedimint_core:: anyhow::Context;
255 #decode_inner
256 }
257 }
258 };
259
260 output.into()
261}
262
263#[allow(unused_variables, unreachable_code)]
264fn error(ident: &Ident, message: &str) -> TokenStream2 {
265 #[cfg(feature = "diagnostics")]
266 ident.span().unstable().error(message).emit();
267 #[cfg(not(feature = "diagnostics"))]
268 panic!("{message}");
269
270 TokenStream2::new()
271}
272
273fn derive_struct_decode(ident: &Ident, fields: &Fields) -> TokenStream2 {
274 let decode_block =
275 derive_tuple_or_named_decode_block(ident, "e! { #ident }, "e! { d }, fields);
276
277 quote! {
278 Ok(#decode_block)
279 }
280}
281
282fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
283 if variants.is_empty() {
284 return quote! {
285 Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Enum without variants can't be instantiated")))
286 };
287 }
288
289 let non_default_match_arms = non_default_variant_indices(variants).into_iter()
290 .map(|(variant_idx, variant)| {
291 let variant_ident = variant.ident.clone();
292 let decode_block = derive_tuple_or_named_decode_block(
293 ident,
294 "e! { #ident::#variant_ident },
295 "e! { &mut cursor },
296 &variant.fields,
297 );
298
299 quote! {
301 #variant_idx => {
302 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
304 .context(concat!(
305 "Decoding bytes of ",
306 stringify!(#ident)
307 ))?;
308 let mut cursor = std::io::Cursor::new(&bytes);
309
310 let decoded = #decode_block;
311
312 let read_bytes = cursor.position();
313 let total_bytes = bytes.len() as u64;
314 if read_bytes != total_bytes {
315 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!(
316 "Partial read: got {total_bytes} bytes but only read {read_bytes} when decoding {}",
317 concat!(
318 stringify!(#ident),
319 "::",
320 stringify!(#variant)
321 )
322 )));
323 }
324
325 decoded
326 }
327 }
328 });
329
330 let default_match_arm = if variants.iter().any(is_default_variant_enforce_valid) {
331 quote! {
332 variant => {
333 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(d, modules)
334 .context(concat!(
335 "Decoding default variant of ",
336 stringify!(#ident)
337 ))?;
338
339 #ident::Default {
340 variant,
341 bytes
342 }
343 }
344 }
345 } else {
346 quote! {
347 variant => {
348 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Invalid enum variant {} while decoding {}", variant, stringify!(#ident))));
349 }
350 }
351 };
352
353 quote! {
354 let variant = <u64 as ::fedimint_core::encoding::Decodable>::consensus_decode_partial_from_finite_reader(d, modules)
355 .context(concat!(
356 "Decoding default variant of ",
357 stringify!(#ident)
358 ))?;
359
360 let decoded = match variant {
361 #(#non_default_match_arms)*
362 #default_match_arm
363 };
364 Ok(decoded)
365 }
366}
367
368fn is_tuple_struct(fields: &Fields) -> bool {
369 fields.iter().any(|field| field.ident.is_none())
370}
371
372fn derive_tuple_or_named_decode_block(
377 ident: &Ident,
378 constructor: &TokenStream2,
379 reader: &TokenStream2,
380 fields: &Fields,
381) -> TokenStream2 {
382 if is_tuple_struct(fields) {
383 derive_tuple_decode_block(ident, constructor, reader, fields)
384 } else {
385 derive_named_decode_block(ident, constructor, reader, fields)
386 }
387}
388
389fn derive_tuple_decode_block(
390 ident: &Ident,
391 constructor: &TokenStream2,
392 reader: &TokenStream2,
393 fields: &Fields,
394) -> TokenStream2 {
395 let field_names = fields
396 .iter()
397 .enumerate()
398 .map(|(idx, _)| format_ident!("field_{}", idx))
399 .collect::<Vec<_>>();
400 quote! {
401 {
402 #(
403 let #field_names = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
404 .context(concat!(
405 "Decoding tuple block ",
406 stringify!(#ident),
407 " field ",
408 stringify!(#field_names),
409 ))?;
410 )*
411 #constructor(#(#field_names,)*)
412 }
413 }
414}
415
416fn derive_named_decode_block(
417 ident: &Ident,
418 constructor: &TokenStream2,
419 reader: &TokenStream2,
420 fields: &Fields,
421) -> TokenStream2 {
422 let variant_fields = fields
423 .iter()
424 .map(|field| field.ident.clone().unwrap())
425 .collect::<Vec<_>>();
426 quote! {
427 {
428 #(
429 let #variant_fields = ::fedimint_core::encoding::Decodable::consensus_decode_partial_from_finite_reader(#reader, modules)
430 .context(concat!(
431 "Decoding named block field: ",
432 stringify!(#ident),
433 "{ ... ",
434 stringify!(#variant_fields),
435 " ... }",
436 ))?;
437 )*
438 #constructor{
439 #(#variant_fields,)*
440 }
441 }
442 }
443}