fedimint_core/task/
inner.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::time::{Duration, SystemTime};
4
5use fedimint_core::time::now;
6use fedimint_logging::LOG_TASK;
7use slotmap::SlotMap;
8use tokio::sync::watch;
9use tracing::{debug, error, info, warn};
10
11use super::{TaskGroup, TaskShutdownToken};
12use crate::runtime::{JoinError, JoinHandle};
13use crate::util::FmtCompact as _;
14
15#[derive(Debug)]
16pub struct TaskGroupInner {
17    on_shutdown_tx: watch::Sender<bool>,
18    // It is necessary to keep at least one `Receiver` around,
19    // otherwise shutdown writes are lost.
20    on_shutdown_rx: watch::Receiver<bool>,
21    pub(crate) active_tasks_join_handles:
22        std::sync::Mutex<slotmap::SlotMap<slotmap::DefaultKey, (String, JoinHandle<()>)>>,
23    // using blocking Mutex to avoid `async` in `shutdown` and `add_subgroup`
24    // it's OK as we don't ever need to yield
25    subgroups: std::sync::Mutex<Vec<TaskGroup>>,
26}
27
28impl Default for TaskGroupInner {
29    fn default() -> Self {
30        let (on_shutdown_tx, on_shutdown_rx) = watch::channel(false);
31        Self {
32            on_shutdown_tx,
33            on_shutdown_rx,
34            active_tasks_join_handles: std::sync::Mutex::new(SlotMap::default()),
35            subgroups: std::sync::Mutex::new(vec![]),
36        }
37    }
38}
39
40impl TaskGroupInner {
41    pub fn shutdown(&self) {
42        // Note: set the flag before starting to call shutdown handlers
43        // to avoid confusion.
44        #[allow(clippy::disallowed_methods)]
45        {
46            self.on_shutdown_tx
47                .send(true)
48                .expect("We must have on_shutdown_rx around so this never fails");
49        }
50
51        let subgroups = self.subgroups.lock().expect("locking failed").clone();
52        for subgroup in subgroups {
53            subgroup.inner.shutdown();
54        }
55    }
56
57    #[inline]
58    pub fn is_shutting_down(&self) -> bool {
59        *self.on_shutdown_tx.borrow()
60    }
61
62    #[inline]
63    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
64        TaskShutdownToken::new(self.on_shutdown_rx.clone())
65    }
66
67    #[inline]
68    pub fn add_subgroup(&self, tg: TaskGroup) {
69        self.subgroups.lock().expect("locking failed").push(tg);
70    }
71
72    #[inline]
73    pub async fn join_all(&self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
74        let subgroups = self.subgroups.lock().expect("locking failed").clone();
75        for subgroup in subgroups {
76            info!(target: LOG_TASK, "Waiting for subgroup to finish");
77            subgroup.join_all_inner(deadline, errors).await;
78            info!(target: LOG_TASK, "Subgroup finished");
79        }
80
81        // drop lock early
82        let tasks: Vec<_> = self
83            .active_tasks_join_handles
84            .lock()
85            .expect("Lock failed")
86            .drain()
87            .collect();
88        for (_, (name, join)) in tasks {
89            debug!(target: LOG_TASK, task=%name, "Waiting for task to finish");
90
91            let timeout = deadline.map(|deadline| {
92                deadline
93                    .duration_since(now())
94                    .unwrap_or(Duration::from_millis(10))
95            });
96
97            #[cfg(not(target_family = "wasm"))]
98            let join_future: Pin<Box<dyn Future<Output = _> + Send>> =
99                if let Some(timeout) = timeout {
100                    Box::pin(crate::runtime::timeout(timeout, join))
101                } else {
102                    Box::pin(async { Ok(join.await) })
103                };
104
105            #[cfg(target_family = "wasm")]
106            let join_future: Pin<Box<dyn Future<Output = _>>> = if let Some(timeout) = timeout {
107                Box::pin(crate::runtime::timeout(timeout, join))
108            } else {
109                Box::pin(async { Ok(join.await) })
110            };
111
112            match join_future.await {
113                Ok(Ok(())) => {
114                    debug!(target: LOG_TASK, task=%name, "Task finished");
115                }
116                Ok(Err(err)) => {
117                    error!(target: LOG_TASK, task=%name, err=%err.fmt_compact(), "Task panicked");
118                    errors.push(err);
119                }
120                Err(_) => {
121                    warn!(
122                        target: LOG_TASK, task=%name,
123                        "Timeout waiting for task to shut down"
124                    );
125                }
126            }
127        }
128    }
129}