fedimint_core/task/
inner.rs
1use std::future::Future;
2use std::pin::Pin;
3use std::time::{Duration, SystemTime};
4
5use fedimint_core::time::now;
6use fedimint_logging::LOG_TASK;
7use slotmap::SlotMap;
8use tokio::sync::watch;
9use tracing::{debug, error, info, warn};
10
11use super::{TaskGroup, TaskShutdownToken};
12use crate::runtime::{JoinError, JoinHandle};
13use crate::util::FmtCompact as _;
14
15#[derive(Debug)]
16pub struct TaskGroupInner {
17 on_shutdown_tx: watch::Sender<bool>,
18 on_shutdown_rx: watch::Receiver<bool>,
21 pub(crate) active_tasks_join_handles:
22 std::sync::Mutex<slotmap::SlotMap<slotmap::DefaultKey, (String, JoinHandle<()>)>>,
23 subgroups: std::sync::Mutex<Vec<TaskGroup>>,
26}
27
28impl Default for TaskGroupInner {
29 fn default() -> Self {
30 let (on_shutdown_tx, on_shutdown_rx) = watch::channel(false);
31 Self {
32 on_shutdown_tx,
33 on_shutdown_rx,
34 active_tasks_join_handles: std::sync::Mutex::new(SlotMap::default()),
35 subgroups: std::sync::Mutex::new(vec![]),
36 }
37 }
38}
39
40impl TaskGroupInner {
41 pub fn shutdown(&self) {
42 #[allow(clippy::disallowed_methods)]
45 {
46 self.on_shutdown_tx
47 .send(true)
48 .expect("We must have on_shutdown_rx around so this never fails");
49 }
50
51 let subgroups = self.subgroups.lock().expect("locking failed").clone();
52 for subgroup in subgroups {
53 subgroup.inner.shutdown();
54 }
55 }
56
57 #[inline]
58 pub fn is_shutting_down(&self) -> bool {
59 *self.on_shutdown_tx.borrow()
60 }
61
62 #[inline]
63 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
64 TaskShutdownToken::new(self.on_shutdown_rx.clone())
65 }
66
67 #[inline]
68 pub fn add_subgroup(&self, tg: TaskGroup) {
69 self.subgroups.lock().expect("locking failed").push(tg);
70 }
71
72 #[inline]
73 pub async fn join_all(&self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
74 let subgroups = self.subgroups.lock().expect("locking failed").clone();
75 for subgroup in subgroups {
76 info!(target: LOG_TASK, "Waiting for subgroup to finish");
77 subgroup.join_all_inner(deadline, errors).await;
78 info!(target: LOG_TASK, "Subgroup finished");
79 }
80
81 let tasks: Vec<_> = self
83 .active_tasks_join_handles
84 .lock()
85 .expect("Lock failed")
86 .drain()
87 .collect();
88 for (_, (name, join)) in tasks {
89 debug!(target: LOG_TASK, task=%name, "Waiting for task to finish");
90
91 let timeout = deadline.map(|deadline| {
92 deadline
93 .duration_since(now())
94 .unwrap_or(Duration::from_millis(10))
95 });
96
97 #[cfg(not(target_family = "wasm"))]
98 let join_future: Pin<Box<dyn Future<Output = _> + Send>> =
99 if let Some(timeout) = timeout {
100 Box::pin(crate::runtime::timeout(timeout, join))
101 } else {
102 Box::pin(async { Ok(join.await) })
103 };
104
105 #[cfg(target_family = "wasm")]
106 let join_future: Pin<Box<dyn Future<Output = _>>> = if let Some(timeout) = timeout {
107 Box::pin(crate::runtime::timeout(timeout, join))
108 } else {
109 Box::pin(async { Ok(join.await) })
110 };
111
112 match join_future.await {
113 Ok(Ok(())) => {
114 debug!(target: LOG_TASK, task=%name, "Task finished");
115 }
116 Ok(Err(err)) => {
117 error!(target: LOG_TASK, task=%name, err=%err.fmt_compact(), "Task panicked");
118 errors.push(err);
119 }
120 Err(_) => {
121 warn!(
122 target: LOG_TASK, task=%name,
123 "Timeout waiting for task to shut down"
124 );
125 }
126 }
127 }
128 }
129}