fedimint_client_module/module/init/recovery.rs

use std::collections::BTreeMap;
use std::time::Duration;
use std::{cmp, ops};

use bitcoin::secp256k1::PublicKey;
use fedimint_api_client::api::{
    DynGlobalApi, VERSION_THAT_INTRODUCED_GET_SESSION_STATUS,
    VERSION_THAT_INTRODUCED_GET_SESSION_STATUS_V2,
};
use fedimint_core::db::DatabaseTransaction;
use fedimint_core::encoding::{Decodable, Encodable};
use fedimint_core::epoch::ConsensusItem;
use fedimint_core::module::registry::ModuleDecoderRegistry;
use fedimint_core::module::{ApiVersion, ModuleCommon};
use fedimint_core::session_outcome::{AcceptedItem, SessionStatus};
use fedimint_core::task::{MaybeSend, MaybeSync, ShuttingDownError, TaskGroup};
use fedimint_core::transaction::Transaction;
use fedimint_core::util::FmtCompactAnyhow as _;
use fedimint_core::{OutPoint, PeerId, apply, async_trait_maybe_send};
use fedimint_logging::LOG_CLIENT_RECOVERY;
use futures::{Stream, StreamExt as _};
use rand::{Rng as _, thread_rng};
use serde::{Deserialize, Serialize};
use tracing::{debug, trace, warn};

use super::{ClientModuleInit, ClientModuleRecoverArgs};
use crate::module::recovery::RecoveryProgress;
use crate::module::{ClientContext, ClientModule};

#[allow(clippy::struct_field_names)]
#[derive(Debug, Clone, Eq, PartialEq, Encodable, Decodable, Serialize, Deserialize)]
/// Common state tracked during recovery from history
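///
/// Recovery proceeds through sessions `start_session..end_session` (end
/// exclusive), with `next_session` tracking the next session to process;
/// progress is reported as `next_session - start_session` out of
/// `end_session - start_session`.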
pub struct RecoveryFromHistoryCommon {
    start_session: u64,
    next_session: u64,
    end_session: u64,
}

impl RecoveryFromHistoryCommon {
    pub fn new(start_session: u64, next_session: u64, end_session: u64) -> Self {
        Self {
            start_session,
            next_session,
            end_session,
        }
    }
}

/// Module-specific logic for [`ClientModuleRecoverArgs::recover_from_history`]
///
/// See [`ClientModuleRecoverArgs::recover_from_history`] for more information.
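///
/// A minimal sketch of an implementor (the `MyRecovery`, `MyInit`, and
/// `MyBackup` names and the `session_count` field are illustrative, not part
/// of this API):
///
/// ```ignore
/// #[derive(Debug, Clone, Eq, PartialEq)]
/// struct MyRecovery {
///     // module-specific state tracked while replaying history
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl RecoveryFromHistory for MyRecovery {
///     type Init = MyInit;
///
///     async fn new(
///         init: &MyInit,
///         args: &ClientModuleRecoverArgs<MyInit>,
///         snapshot: Option<&MyBackup>,
///     ) -> anyhow::Result<(Self, u64)> {
///         // start scanning from the session recorded in the backup, or 0
///         Ok((Self {}, snapshot.map_or(0, |s| s.session_count)))
///     }
///
///     // ... plus `load_dbtx`/`store_dbtx`/`delete_dbtx`,
///     // `load_finalized`/`store_finalized`, `finalize_dbtx`, and
///     // `handle_input`/`handle_output` overrides as needed ...
/// }
/// ```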
#[apply(async_trait_maybe_send!)]
pub trait RecoveryFromHistory: std::fmt::Debug + MaybeSend + MaybeSync + Clone {
    /// [`ClientModuleInit`] of this recovery logic.
    type Init: ClientModuleInit;

    /// New empty state to start recovery from, along with the session number
    /// to start from
    async fn new(
        init: &Self::Init,
        args: &ClientModuleRecoverArgs<Self::Init>,
        snapshot: Option<&<<Self::Init as ClientModuleInit>::Module as ClientModule>::Backup>,
    ) -> anyhow::Result<(Self, u64)>;

    /// Try to load the existing state previously stored with
    /// [`RecoveryFromHistory::store_dbtx`].
    ///
    /// Storing and restoring progress is used to continue recovery if it was
    /// previously terminated before completion.
    async fn load_dbtx(
        init: &Self::Init,
        dbtx: &mut DatabaseTransaction<'_>,
        args: &ClientModuleRecoverArgs<Self::Init>,
    ) -> anyhow::Result<Option<(Self, RecoveryFromHistoryCommon)>>;

    /// Store the current recovery state in the database
    ///
    /// See [`Self::load_dbtx`].
    async fn store_dbtx(
        &self,
        dbtx: &mut DatabaseTransaction<'_>,
        common: &RecoveryFromHistoryCommon,
    );

    /// Delete the recovery state from the database
    ///
    /// See [`Self::load_dbtx`].
    async fn delete_dbtx(&self, dbtx: &mut DatabaseTransaction<'_>);

    /// Read the finalization status
    ///
    /// See [`Self::load_dbtx`].
    async fn load_finalized(dbtx: &mut DatabaseTransaction<'_>) -> Option<bool>;

    /// Store the finalization status
    ///
    /// See [`Self::load_finalized`].
    async fn store_finalized(dbtx: &mut DatabaseTransaction<'_>, state: bool);

    /// Handle session outcome, adjusting the current state
    ///
    /// It is expected that most implementations don't need to override this
    /// function, and instead override more granular ones (e.g.
    /// [`Self::handle_input`] and/or [`Self::handle_output`]).
    ///
    /// The default implementation will loop through the items in
    /// `session_items` and forward them one by one to the respective functions
    /// (see [`Self::handle_transaction`]).
    async fn handle_session(
        &mut self,
        client_ctx: &ClientContext<<Self::Init as ClientModuleInit>::Module>,
        session_idx: u64,
        session_items: &Vec<AcceptedItem>,
    ) -> anyhow::Result<()> {
        for accepted_item in session_items {
            if let ConsensusItem::Transaction(ref transaction) = accepted_item.item {
                self.handle_transaction(client_ctx, transaction, session_idx)
                    .await?;
            }
        }
        Ok(())
    }

    /// Handle a transaction, adjusting the current state
    ///
    /// It is expected that most implementations don't need to override this
    /// function, and instead override more granular ones (e.g.
    /// [`Self::handle_input`] and/or [`Self::handle_output`]).
    ///
    /// The default implementation will loop through the inputs and outputs
    /// of the transaction, filter and downcast the ones matching the current
    /// module, and forward them one by one to the respective functions
    /// (e.g. [`Self::handle_input`], [`Self::handle_output`]).
    async fn handle_transaction(
        &mut self,
        client_ctx: &ClientContext<<Self::Init as ClientModuleInit>::Module>,
        transaction: &Transaction,
        session_idx: u64,
    ) -> anyhow::Result<()> {
        trace!(
            target: LOG_CLIENT_RECOVERY,
            tx_hash = %transaction.tx_hash(),
            input_num = transaction.inputs.len(),
            output_num = transaction.outputs.len(),
            "processing transaction"
        );

        for (idx, input) in transaction.inputs.iter().enumerate() {
            trace!(
                target: LOG_CLIENT_RECOVERY,
                tx_hash = %transaction.tx_hash(),
                idx,
                module_id = input.module_instance_id(),
                "found transaction input"
            );

            if let Some(own_input) = client_ctx.input_from_dyn(input) {
                self.handle_input(client_ctx, idx, own_input, session_idx)
                    .await?;
            }
        }

        for (out_idx, output) in transaction.outputs.iter().enumerate() {
            trace!(
                target: LOG_CLIENT_RECOVERY,
                tx_hash = %transaction.tx_hash(),
                idx = out_idx,
                module_id = output.module_instance_id(),
                "found transaction output"
            );

            if let Some(own_output) = client_ctx.output_from_dyn(output) {
                let out_point = OutPoint {
                    txid: transaction.tx_hash(),
                    out_idx: out_idx as u64,
                };

                self.handle_output(client_ctx, out_point, own_output, session_idx)
                    .await?;
            }
        }

        Ok(())
    }

    /// Handle transaction input, adjusting the current state
    ///
    /// Default implementation does nothing.
    async fn handle_input(
        &mut self,
        _client_ctx: &ClientContext<<Self::Init as ClientModuleInit>::Module>,
        _idx: usize,
        _input: &<<<Self::Init as ClientModuleInit>::Module as ClientModule>::Common as ModuleCommon>::Input,
        _session_idx: u64,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Handle transaction output, adjusting the current state
    ///
    /// Default implementation does nothing.
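    ///
    /// A sketch of a typical override (the `is_mine` helper and
    /// `pending_outputs` field are illustrative):
    ///
    /// ```ignore
    /// async fn handle_output(
    ///     &mut self,
    ///     _client_ctx: &ClientContext<MyClientModule>,
    ///     out_point: OutPoint,
    ///     output: &MyOutput,
    ///     _session_idx: u64,
    /// ) -> anyhow::Result<()> {
    ///     // Remember only the outputs our keys can claim
    ///     if self.is_mine(output) {
    ///         self.pending_outputs.insert(out_point);
    ///     }
    ///     Ok(())
    /// }
    /// ```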
    async fn handle_output(
        &mut self,
        _client_ctx: &ClientContext<<Self::Init as ClientModuleInit>::Module>,
        _out_point: OutPoint,
        _output: &<<<Self::Init as ClientModuleInit>::Module as ClientModule>::Common as ModuleCommon>::Output,
        _session_idx: u64,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Called before [`Self::finalize_dbtx`], to allow final state changes
    /// outside of the retriable database transaction.
    async fn pre_finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }

    /// Finalize the recovery, converting the tracked state into final
    /// changes in the database.
    ///
    /// This is the only place during recovery where the module gets a chance
    /// to create state machines, etc.
    ///
    /// Notably, this function runs in a database-autocommit wrapper, so it
    /// might be called again on database commit failure.
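    ///
    /// A minimal sketch (the `RecoveredOutputKey` db key and the `recovered`
    /// field are illustrative):
    ///
    /// ```ignore
    /// async fn finalize_dbtx(&self, dbtx: &mut DatabaseTransaction<'_>) -> anyhow::Result<()> {
    ///     // Persist everything tracked during the history scan in one dbtx
    ///     for (out_point, data) in &self.recovered {
    ///         dbtx.insert_new_entry(&RecoveredOutputKey(*out_point), data).await;
    ///     }
    ///     Ok(())
    /// }
    /// ```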
    async fn finalize_dbtx(&self, dbtx: &mut DatabaseTransaction<'_>) -> anyhow::Result<()>;
}

impl<Init> ClientModuleRecoverArgs<Init>
where
    Init: ClientModuleInit,
{
    /// Run recovery of a module from Federation consensus history
    ///
    /// It is expected that most modules will implement their recovery
    /// by following Federation consensus history to restore their
    /// state. This function implements such a recovery by being generic
    /// over the [`RecoveryFromHistory`] trait, which provides the
    /// module-specific parts of the recovery logic.
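    ///
    /// A sketch of how a module's [`ClientModuleInit::recover`] might
    /// delegate here (the `MyRecovery` type is illustrative):
    ///
    /// ```ignore
    /// async fn recover(
    ///     &self,
    ///     args: &ClientModuleRecoverArgs<Self>,
    ///     snapshot: Option<&<Self::Module as ClientModule>::Backup>,
    /// ) -> anyhow::Result<()> {
    ///     args.recover_from_history::<MyRecovery>(self, snapshot).await
    /// }
    /// ```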
    pub async fn recover_from_history<Recovery>(
        &self,
        init: &Init,
        snapshot: Option<&<<Init as ClientModuleInit>::Module as ClientModule>::Backup>,
    ) -> anyhow::Result<()>
    where
        Recovery: RecoveryFromHistory<Init = Init> + std::fmt::Debug,
    {
        /// Fetch sessions in a given range, yielding them as a stream
        ///
        /// Since WASM's `spawn` does not support join handles, errors are
        /// reported through the stream items themselves.
        fn fetch_block_stream<'a>(
            api: DynGlobalApi,
            core_api_version: ApiVersion,
            decoders: ModuleDecoderRegistry,
            epoch_range: ops::Range<u64>,
            broadcast_public_keys: Option<BTreeMap<PeerId, PublicKey>>,
            task_group: TaskGroup,
        ) -> impl futures::Stream<Item = Result<(u64, Vec<AcceptedItem>), ShuttingDownError>> + 'a
        {
            // How many requests for blocks to run in parallel (streaming).
            let parallelism_level =
                if core_api_version < VERSION_THAT_INTRODUCED_GET_SESSION_STATUS_V2 {
                    64
                } else {
                    128
                };

            futures::stream::iter(epoch_range.clone())
                .map(move |session_idx| {
                    let api = api.clone();
                    // When decoding we're only interested in items we can understand, so we don't
                    // want to fail on a missing decoder of some unrelated module.
                    let decoders = decoders.clone().with_fallback();
                    let task_group = task_group.clone();
                    let broadcast_public_keys = broadcast_public_keys.clone();

                    Box::pin(async move {
                        // NOTE: Each block is fetched in a spawned task. This avoids the footgun
                        // of streams not making any progress when the stream itself is not being
                        // polled, and can possibly increase the fetching performance.
                        task_group.spawn_cancellable("recovery fetch block", async move {

                            let mut retry_sleep = Duration::from_millis(10);
                            let block = loop {
                                trace!(target: LOG_CLIENT_RECOVERY, session_idx, "Awaiting signed block");

                                let items_res = if core_api_version < VERSION_THAT_INTRODUCED_GET_SESSION_STATUS {
                                    api.await_block(session_idx, &decoders).await.map(|s| s.items)
                                } else {
                                    api.get_session_status(session_idx, &decoders, core_api_version, broadcast_public_keys.as_ref()).await.map(|s| match s {
                                        SessionStatus::Initial => panic!("Federation missing session that existed when we started recovery"),
                                        SessionStatus::Pending(items) => items,
                                        SessionStatus::Complete(s) => s.items,
                                    })
                                };

                                match items_res {
                                    Ok(block) => {
                                        trace!(target: LOG_CLIENT_RECOVERY, session_idx, "Got signed session");
                                        break block
                                    },
                                    Err(err) => {
                                        const MAX_SLEEP: Duration = Duration::from_secs(120);

                                        warn!(target: LOG_CLIENT_RECOVERY, err = %err.fmt_compact_anyhow(), session_idx, "Error trying to fetch signed block");
                                        // We don't want `parallelism_level` tasks hammering the
                                        // Federation with requests, so the max sleep is significant
                                        if retry_sleep <= MAX_SLEEP {
                                            retry_sleep = retry_sleep
                                                + thread_rng().gen_range(Duration::ZERO..=retry_sleep);
                                        }
                                        fedimint_core::runtime::sleep(cmp::min(retry_sleep, MAX_SLEEP))
                                            .await;
                                    }
                                }
                            };

                            (session_idx, block)
                        }).await
                    })
                })
                .buffered(parallelism_level)
        }

        /// Make enough progress to justify saving a state snapshot
        async fn make_progress<Init, Recovery: RecoveryFromHistory<Init = Init>>(
            client_ctx: &ClientContext<<Init as ClientModuleInit>::Module>,
            common_state: &mut RecoveryFromHistoryCommon,
            state: &mut Recovery,
            block_stream: &mut (
                     impl Stream<Item = Result<(u64, Vec<AcceptedItem>), ShuttingDownError>> + Unpin
                 ),
        ) -> anyhow::Result<()>
        where
            Init: ClientModuleInit,
        {
            /// The number of blocks after which we unconditionally save
            /// progress to the database (i.e. return from this function)
            ///
            /// We are also bound by time inside the loop, below
            const PROGRESS_SNAPSHOT_BLOCKS: u64 = 5000;

            let start = fedimint_core::time::now();

            let block_range = common_state.next_session
                ..cmp::min(
                    common_state
                        .next_session
                        .wrapping_add(PROGRESS_SNAPSHOT_BLOCKS),
                    common_state.end_session,
                );

            for _ in block_range {
                let Some(res) = block_stream.next().await else {
                    break;
                };

                let (session_idx, accepted_items) = res?;

                assert_eq!(common_state.next_session, session_idx);
                state
                    .handle_session(client_ctx, session_idx, &accepted_items)
                    .await?;

                common_state.next_session += 1;

                if Duration::from_secs(10)
                    < fedimint_core::time::now()
                        .duration_since(start)
                        .unwrap_or_default()
                {
                    break;
                }
            }

            Ok(())
        }

        let db = self.db();
        let client_ctx = self.context();

        if Recovery::load_finalized(&mut db.begin_transaction_nc().await)
            .await
            .unwrap_or_default()
        {
            // In rare circumstances, the finalization could complete, yet the
            // completion of the `recover` function itself not yet be persisted
            // in the database. So it's possible that `recover` would be called
            // again on an already finalized state. Because of this we store a
            // finalization marker in the same dbtx as the finalization itself,
            // detect this case here, and exit early.
            //
            // Example sequence of how this happens (if `finalize_dbtx` didn't
            // exist):
            //
            // 0. module recovery is complete and progress saved to the db
            // 1. `dbtx` with finalization commits, progress deleted, completing recovery on
            //    the client module side
            // 2. client crashes/gets terminated (tricky corner case)
            // 3. client starts again
            // 4. client never observed/persisted that the module finished recovery, so
            //    calls module recovery again
            // 5. module doesn't see progress, starts recovery again, eventually completes
            //    again and moves to finalization
            // 6. module runs finalization again and probably fails because it's actually
            //    not idempotent and doesn't expect the already existing state.
            warn!(
                target: LOG_CLIENT_RECOVERY,
                "Previously finalized, exiting"
            );
            return Ok(());
        }
        let current_session_count = client_ctx.global_api().session_count().await?;
        debug!(target: LOG_CLIENT_RECOVERY, session_count = current_session_count, "Current session count");

        let (mut state, mut common_state) =
            // TODO: if load fails (e.g. module didn't migrate an existing recovery state and failed to decode it),
            // we could just ... start from scratch? at least being able to force this behavior might be useful
            if let Some((state, common_state)) = Recovery::load_dbtx(init, &mut db.begin_transaction_nc().await, self).await? {
                (state, common_state)
            } else {
                let (state, start_session) = Recovery::new(init, self, snapshot).await?;

                debug!(target: LOG_CLIENT_RECOVERY, start_session, "Recovery start session");
                (
                    state,
                    RecoveryFromHistoryCommon {
                        start_session,
                        next_session: start_session,
                        end_session: current_session_count + 1,
                    },
                )
            };

        let block_stream_session_range = common_state.next_session..common_state.end_session;
        debug!(target: LOG_CLIENT_RECOVERY, range = ?block_stream_session_range, "Starting block streaming");

        let mut block_stream = fetch_block_stream(
            self.api().clone(),
            *self.core_api_version(),
            client_ctx.decoders(),
            block_stream_session_range,
            client_ctx
                .get_config()
                .await
                .global
                .broadcast_public_keys
                .clone(),
            self.task_group().clone(),
        );

        while common_state.next_session < common_state.end_session {
            make_progress(
                &client_ctx,
                &mut common_state,
                &mut state,
                &mut block_stream,
            )
            .await?;

            let mut dbtx = db.begin_transaction().await;
            state.store_dbtx(&mut dbtx.to_ref_nc(), &common_state).await;
            dbtx.commit_tx().await;

            self.update_recovery_progress(RecoveryProgress {
                complete: (common_state.next_session - common_state.start_session)
                    .try_into()
                    .unwrap_or(u32::MAX),
                total: (common_state.end_session - common_state.start_session)
                    .try_into()
                    .unwrap_or(u32::MAX),
            });
        }

        state.pre_finalize().await?;

        let mut dbtx = db.begin_transaction().await;
        state.store_dbtx(&mut dbtx.to_ref_nc(), &common_state).await;
        dbtx.commit_tx().await;

        debug!(
            target: LOG_CLIENT_RECOVERY,
            ?state,
            "Finalizing restore"
        );

        db.autocommit(
            |dbtx, _| {
                let state = state.clone();
                Box::pin(async move {
                    state.delete_dbtx(dbtx).await;
                    state.finalize_dbtx(dbtx).await?;
                    Recovery::store_finalized(dbtx, true).await;

                    Ok::<_, anyhow::Error>(())
                })
            },
            None,
        )
        .await?;

        Ok(())
    }
}
504}