fedimint_core/
task.rs

1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5/// Just-in-time initialization
6pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{debug, error, info, trace};
23
24use crate::runtime;
25// TODO: stop using `task::*`, and use `runtime::*` in the code
26// lots of churn though
27pub use crate::runtime::*;
/// A group of tasks working together
///
/// Using this struct it is possible to spawn one or more
/// main tasks that collaborate and can cooperatively and gracefully
/// shut down, either due to an external request, or due to the failure
/// of one of them.
///
/// Each task should periodically check its [`TaskHandle`] or rely
/// on a condition like channel disconnection to detect when it is time
/// to finish.
///
/// Clones of a `TaskGroup` share the same inner state through an [`Arc`].
#[derive(Clone, Default, Debug)]
pub struct TaskGroup {
    /// State shared by every clone of this group and by the
    /// [`TaskHandle`]s created from it.
    inner: Arc<TaskGroupInner>,
}
42
43impl TaskGroup {
    /// Create a new task group with default-initialized state.
    ///
    /// Equivalent to [`TaskGroup::default`].
    pub fn new() -> Self {
        Self::default()
    }
47
48    pub fn make_handle(&self) -> TaskHandle {
49        TaskHandle {
50            inner: self.inner.clone(),
51        }
52    }
53
54    /// Create a sub-group
55    ///
56    /// Task subgroup works like an independent [`TaskGroup`], but the parent
57    /// `TaskGroup` will propagate the shut down signal to a sub-group.
58    ///
59    /// In contrast to using the parent group directly, a subgroup allows
60    /// calling [`Self::join_all`] and detecting any panics on just a
61    /// subset of tasks.
62    ///
63    /// The code create a subgroup is responsible for calling
64    /// [`Self::join_all`]. If it won't, the parent subgroup **will not**
65    /// detect any panics in the tasks spawned by the subgroup.
66    pub fn make_subgroup(&self) -> Self {
67        let new_tg = Self::new();
68        self.inner.add_subgroup(new_tg.clone());
69        new_tg
70    }
71
    /// Tell all tasks in the group to shut down. This only initiates the
    /// shutdown process, it does not wait for the tasks to shut down.
    ///
    /// Use [`Self::shutdown_join_all`] to also wait for the tasks to finish.
    pub fn shutdown(&self) {
        self.inner.shutdown();
    }
77
78    /// Tell all tasks in the group to shut down and wait for them to finish.
79    pub async fn shutdown_join_all(
80        self,
81        join_timeout: impl Into<Option<Duration>>,
82    ) -> Result<(), anyhow::Error> {
83        self.shutdown();
84        self.join_all(join_timeout.into()).await
85    }
86
87    /// Add a task to the group that waits for CTRL+C or SIGTERM, then
88    /// tells the rest of the task group to shut down.
89    #[cfg(not(target_family = "wasm"))]
90    pub fn install_kill_handler(&self) {
91        /// Wait for CTRL+C or SIGTERM.
92        async fn wait_for_shutdown_signal() {
93            use tokio::signal;
94
95            let ctrl_c = async {
96                signal::ctrl_c()
97                    .await
98                    .expect("failed to install Ctrl+C handler");
99            };
100
101            #[cfg(unix)]
102            let terminate = async {
103                signal::unix::signal(signal::unix::SignalKind::terminate())
104                    .expect("failed to install signal handler")
105                    .recv()
106                    .await;
107            };
108
109            #[cfg(not(unix))]
110            let terminate = std::future::pending::<()>();
111
112            tokio::select! {
113                () = ctrl_c => {},
114                () = terminate => {},
115            }
116        }
117
118        runtime::spawn("kill handlers", {
119            let task_group = self.clone();
120            async move {
121                wait_for_shutdown_signal().await;
122                info!(
123                    target: LOG_TASK,
124                    "signal received, starting graceful shutdown"
125                );
126                task_group.shutdown();
127            }
128        });
129    }
130
    /// Spawn a task in this group.
    ///
    /// `f` receives a [`TaskHandle`] for this group and returns the future
    /// to run. Task start and finish are logged at `debug` level; see
    /// [`Self::spawn_silent`] for a quieter variant.
    ///
    /// Returns a [`oneshot::Receiver`] that is sent the future's output once
    /// the task finishes.
    pub fn spawn<Fut, R>(
        &self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + MaybeSend + 'static,
        R: MaybeSend + 'static,
    {
        // `false` = not quiet: start/finish are logged at `debug` level.
        self.spawn_inner(name, f, false)
    }
142
    /// This is a version of [`Self::spawn`] that uses a less noisy logging
    /// level
    ///
    /// Meant for tasks that are spawned often enough to not be as interesting.
    /// Task start and finish are logged at `trace` instead of `debug` level.
    ///
    /// Returns a [`oneshot::Receiver`] that is sent the future's output once
    /// the task finishes.
    pub fn spawn_silent<Fut, R>(
        &self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + MaybeSend + 'static,
        R: MaybeSend + 'static,
    {
        // `true` = quiet: start/finish are logged at `trace` level.
        self.spawn_inner(name, f, true)
    }
157
    /// Shared implementation of [`Self::spawn`] and [`Self::spawn_silent`].
    ///
    /// Spawns the future produced by `f` on the runtime and stores the task
    /// name together with the value returned by `runtime::spawn` in
    /// `active_tasks_join_handles` under a fresh key. When the task's future
    /// is dropped, a deferred block removes that entry again.
    ///
    /// `quiet` selects `trace` instead of `debug` for the start/finish logs.
    ///
    /// The returned receiver is sent the task's result; if the receiver has
    /// been dropped, the result is discarded.
    ///
    /// # Panics
    ///
    /// Panics if the `active_tasks_join_handles` lock is poisoned.
    fn spawn_inner<Fut, R>(
        &self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
        quiet: bool,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + MaybeSend + 'static,
        R: MaybeSend + 'static,
    {
        let name = name.into();
        // If this guard is dropped before `completed` is set to `true`, its
        // `Drop` impl requests shutdown of the group.
        let mut guard = TaskPanicGuard {
            name: name.clone(),
            inner: self.inner.clone(),
            completed: false,
        };
        let handle = self.make_handle();

        // Channel used to hand the task's result back to the caller.
        let (tx, rx) = oneshot::channel();
        self.inner
            .active_tasks_join_handles
            .lock()
            .expect("Locking failed")
            .insert_with_key(move |task_key| {
                (
                    name.clone(),
                    crate::runtime::spawn(&name, {
                        let name = name.clone();
                        async move {
                            defer! {
                                // Panic or normal completion, it means the task
                                // is complete, and does not need to be shut down
                                // via its join handle. This prevents a buildup of
                                // task handles.
                                if handle
                                    .inner
                                    .active_tasks_join_handles
                                    .lock()
                                    .expect("Locking failed")
                                    .remove(task_key)
                                    .is_none() {
                                        trace!(target: LOG_TASK, %name, "Task already canceled");
                                    }
                            }
                            // Unfortunately log levels need to be static
                            if quiet {
                                trace!(target: LOG_TASK, %name, "Starting task");
                            } else {
                                debug!(target: LOG_TASK, %name, "Starting task");
                            }
                            let r = f(handle.clone()).await;
                            // Reached only if the future returned normally.
                            guard.completed = true;

                            if quiet {
                                trace!(target: LOG_TASK, %name, "Finished task");
                            } else {
                                debug!(target: LOG_TASK, %name, "Finished task");
                            }
                            // if receiver is not interested, just drop the message
                            let _ = tx.send(r);

                            // NOTE: Since this is an `async move` block, the guard
                            // will not get moved into it unless it is used inside
                            // the body. Weird.
                            drop(guard);
                        }
                    }),
                )
            });

        rx
    }
229
230    /// Spawn a task that will get cancelled automatically on `TaskGroup`
231    /// shutdown.
232    pub fn spawn_cancellable<R>(
233        &self,
234        name: impl Into<String>,
235        future: impl Future<Output = R> + MaybeSend + 'static,
236    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
237    where
238        R: MaybeSend + 'static,
239    {
240        self.spawn(name, |handle| async move {
241            let value = handle.cancel_on_shutdown(future).await;
242            if value.is_err() {
243                // name will part of span
244                debug!(target: LOG_TASK, "task cancelled on shutdown");
245            }
246            value
247        })
248    }
249
250    pub fn spawn_cancellable_silent<R>(
251        &self,
252        name: impl Into<String>,
253        future: impl Future<Output = R> + MaybeSend + 'static,
254    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
255    where
256        R: MaybeSend + 'static,
257    {
258        self.spawn_silent(name, |handle| async move {
259            let value = handle.cancel_on_shutdown(future).await;
260            if value.is_err() {
261                // name will part of span
262                debug!(target: LOG_TASK, "task cancelled on shutdown");
263            }
264            value
265        })
266    }
267
268    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
269        let deadline = timeout.map(|timeout| now() + timeout);
270        let mut errors = vec![];
271
272        self.join_all_inner(deadline, &mut errors).await;
273
274        if errors.is_empty() {
275            Ok(())
276        } else {
277            let num_errors = errors.len();
278            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
279        }
280    }
281
    /// Join the group's tasks, reporting problems through `errors`.
    ///
    /// Delegates to the inner state's `join_all`, forwarding `deadline` and
    /// the `errors` collector unchanged.
    // NOTE(review): nothing in this body recurses directly; the
    // `async_recursion` attributes may be a leftover — confirm before
    // removing them.
    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
        self.inner.join_all(deadline, errors).await;
    }
287}
288
/// Drop guard tracking whether a spawned task ran to completion.
///
/// When dropped while `completed` is still `false`, its `Drop` impl requests
/// shutdown of the task group.
struct TaskPanicGuard {
    /// Task name, used in log messages
    name: String,
    /// Task group state on which shutdown is requested
    inner: Arc<TaskGroupInner>,
    /// Did the future complete successfully (no panic)
    completed: bool,
}
295
296impl Drop for TaskPanicGuard {
297    fn drop(&mut self) {
298        trace!(
299            target: LOG_TASK,
300            name = %self.name,
301            "Task drop"
302        );
303        if !self.completed {
304            info!(
305                target: LOG_TASK,
306                name = %self.name,
307                "Task shut down uncleanly"
308            );
309            self.inner.shutdown();
310        }
311    }
312}
313
/// Handle to a [`TaskGroup`], passed to the tasks spawned in it.
///
/// Lets a task query or await the group's shutdown state. Shares the group's
/// inner state through an [`Arc`].
#[derive(Clone, Debug)]
pub struct TaskHandle {
    /// Inner state shared with the [`TaskGroup`] this handle was created from.
    inner: Arc<TaskGroupInner>,
}
318
/// Error returned by [`TaskHandle::cancel_on_shutdown`] when the task group
/// shut down before the wrapped future completed.
#[derive(thiserror::Error, Debug, Clone)]
#[error("Task group is shutting down")]
#[non_exhaustive]
pub struct ShuttingDownError {}
323
324impl TaskHandle {
325    /// Is task group shutting down?
326    ///
327    /// Every task in a task group should detect and stop if `true`.
328    pub fn is_shutting_down(&self) -> bool {
329        self.inner.is_shutting_down()
330    }
331
332    /// Make a [`oneshot::Receiver`] that will fire on shutdown
333    ///
334    /// Tasks can use `select` on the return value to handle shutdown
335    /// signal during otherwise blocking operation.
336    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
337        self.inner.make_shutdown_rx()
338    }
339
340    /// Run the future or cancel it if the [`TaskGroup`] shuts down.
341    pub async fn cancel_on_shutdown<F: Future>(
342        &self,
343        fut: F,
344    ) -> Result<F::Output, ShuttingDownError> {
345        let rx = self.make_shutdown_rx();
346        match future::select(pin!(rx), pin!(fut)).await {
347            Either::Left(((), _)) => Err(ShuttingDownError {}),
348            Either::Right((value, _)) => Ok(value),
349        }
350    }
351}
352
/// Future returned by [`TaskHandle::make_shutdown_rx`].
///
/// Wraps a pinned, boxed `Send` future with `()` output and implements
/// [`Future`] by polling it.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
354
355impl TaskShutdownToken {
356    fn new(mut rx: watch::Receiver<bool>) -> Self {
357        Self(Box::pin(async move {
358            let _ = rx.wait_for(|v| *v).await;
359        }))
360    }
361}
362
363impl Future for TaskShutdownToken {
364    type Output = ();
365
366    fn poll(
367        mut self: Pin<&mut Self>,
368        cx: &mut std::task::Context<'_>,
369    ) -> std::task::Poll<Self::Output> {
370        self.0.as_mut().poll(cx)
371    }
372}
373
/// Async trait attribute that is `Send`-aware per target
///
/// Expands to `#[async_trait::async_trait]` on non-wasm targets and to
/// `#[async_trait::async_trait(?Send)]` on wasm targets, applied to the
/// wrapped item.
///
/// # Example
///
/// ```rust
/// use fedimint_core::{apply, async_trait_maybe_send};
/// #[apply(async_trait_maybe_send!)]
/// trait Foo {
///     // methods
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl Foo for () {
///     // methods
/// }
/// ```
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($tt:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($tt)*
    };
}
398
/// Adds a `Send` bound to a type on non-wasm targets
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend` (only auto traits
/// may be added to a trait object), so this macro appends `+ Send` to the
/// given type instead. This is the non-wasm definition.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::{apply, maybe_add_send};
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)* + Send
    };
}
416
/// Wasm definition of `maybe_add_send`: expands to the given type unchanged
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend` (only auto traits
/// may be added to a trait object), hence a macro is used. On wasm no `Send`
/// bound is added.
///
/// # Example
///
/// ```rust
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
431
/// Appends `+ Send + Sync` to the given type (non-wasm definition).
///
/// See `maybe_add_send`
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)* + Send + Sync
    };
}
440
/// Expands to the given type unchanged (wasm definition).
///
/// See `maybe_add_send`
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
449
/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types cannot be
/// sent between workers directly.
///
/// This is the wasm definition: a marker trait without a supertrait.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}
456
/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types cannot be
/// sent between workers directly.
///
/// This is the non-wasm definition: it has `Send` as a supertrait.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}
463
// Blanket impl: on non-wasm targets every `Send` type is `MaybeSend`.
#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}
466
// Blanket impl: on wasm every type is `MaybeSend`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
469
/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
///
/// This is the wasm definition: a marker trait without a supertrait.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}
473
/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
///
/// This is the non-wasm definition: it has `Sync` as a supertrait.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}
477
// Blanket impl: on non-wasm targets every `Sync` type is `MaybeSync`.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}
480
// Blanket impl: on wasm every type is `MaybeSync`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
483
/// Sleep for `duration`, logging the reason beforehand.
///
/// Used in tests when sleep functionality is desired so it can be logged.
/// `comment` must describe the reason for sleeping.
pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
    // Logged as `<seconds>.<milliseconds, zero-padded to 3 digits>`.
    info!(
        target: LOG_TEST,
        "Sleeping for {}.{:03} seconds because: {}",
        duration.as_secs(),
        duration.subsec_millis(),
        comment.as_ref()
    );
    sleep(duration).await;
}
496
/// An error used as a "cancelled" marker in [`Cancellable`].
///
/// Carries no data; its display message is "Operation cancelled".
#[derive(Error, Debug)]
#[error("Operation cancelled")]
pub struct Cancelled;
501
/// Result of an operation that can potentially get cancelled, returning no
/// result (e.g. program shutdown).
///
/// `Ok(T)` holds the operation's value; `Err(Cancelled)` marks cancellation.
pub type Cancellable<T> = std::result::Result<T, Cancelled>;
505
506#[cfg(test)]
507mod tests;