fedimint_core/
task.rs

1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5/// Just-in-time initialization
6pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{debug, error, info, trace};
23
24use crate::runtime;
25// TODO: stop using `task::*`, and use `runtime::*` in the code
26// lots of churn though
27pub use crate::runtime::*;
/// A group of tasks working together
///
/// Using this struct it is possible to spawn one or more
/// main tasks that collaborate and can cooperatively and gracefully
/// shut down, either due to an external request, or the failure of
/// one of them.
///
/// Each task should periodically check its [`TaskHandle`] or rely
/// on a condition like channel disconnection to detect when it is time
/// to finish.
///
/// Clones share the same inner state through an [`Arc`].
#[derive(Clone, Default, Debug)]
pub struct TaskGroup {
    inner: Arc<TaskGroupInner>,
}
42
43impl TaskGroup {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn make_handle(&self) -> TaskHandle {
49        TaskHandle {
50            inner: self.inner.clone(),
51        }
52    }
53
54    /// Create a sub-group
55    ///
56    /// Task subgroup works like an independent [`TaskGroup`], but the parent
57    /// `TaskGroup` will propagate the shut down signal to a sub-group.
58    ///
59    /// In contrast to using the parent group directly, a subgroup allows
60    /// calling [`Self::join_all`] and detecting any panics on just a
61    /// subset of tasks.
62    ///
63    /// The code create a subgroup is responsible for calling
64    /// [`Self::join_all`]. If it won't, the parent subgroup **will not**
65    /// detect any panics in the tasks spawned by the subgroup.
66    pub fn make_subgroup(&self) -> Self {
67        let new_tg = Self::new();
68        self.inner.add_subgroup(new_tg.clone());
69        new_tg
70    }
71
72    /// Is task group shutting down?
73    pub fn is_shutting_down(&self) -> bool {
74        self.inner.is_shutting_down()
75    }
76
77    /// Tell all tasks in the group to shut down. This only initiates the
78    /// shutdown process, it does not wait for the tasks to shut down.
79    pub fn shutdown(&self) {
80        self.inner.shutdown();
81    }
82
83    /// Tell all tasks in the group to shut down and wait for them to finish.
84    pub async fn shutdown_join_all(
85        self,
86        join_timeout: impl Into<Option<Duration>>,
87    ) -> Result<(), anyhow::Error> {
88        self.shutdown();
89        self.join_all(join_timeout.into()).await
90    }
91
92    /// Add a task to the group that waits for CTRL+C or SIGTERM, then
93    /// tells the rest of the task group to shut down.
94    #[cfg(not(target_family = "wasm"))]
95    pub fn install_kill_handler(&self) {
96        /// Wait for CTRL+C or SIGTERM.
97        async fn wait_for_shutdown_signal() {
98            use tokio::signal;
99
100            let ctrl_c = async {
101                signal::ctrl_c()
102                    .await
103                    .expect("failed to install Ctrl+C handler");
104            };
105
106            #[cfg(unix)]
107            let terminate = async {
108                signal::unix::signal(signal::unix::SignalKind::terminate())
109                    .expect("failed to install signal handler")
110                    .recv()
111                    .await;
112            };
113
114            #[cfg(not(unix))]
115            let terminate = std::future::pending::<()>();
116
117            tokio::select! {
118                () = ctrl_c => {},
119                () = terminate => {},
120            }
121        }
122
123        runtime::spawn("kill handlers", {
124            let task_group = self.clone();
125            async move {
126                wait_for_shutdown_signal().await;
127                info!(
128                    target: LOG_TASK,
129                    "signal received, starting graceful shutdown"
130                );
131                task_group.shutdown();
132            }
133        });
134    }
135
136    pub fn spawn<Fut, R>(
137        &self,
138        name: impl Into<String>,
139        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
140    ) -> oneshot::Receiver<R>
141    where
142        Fut: Future<Output = R> + MaybeSend + 'static,
143        R: MaybeSend + 'static,
144    {
145        self.spawn_inner(name, f, false)
146    }
147
148    /// This is a version of [`Self::spawn`] that uses less noisy logging level
149    ///
150    /// Meant for tasks that are spawned often enough to not be as interesting.
151    pub fn spawn_silent<Fut, R>(
152        &self,
153        name: impl Into<String>,
154        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
155    ) -> oneshot::Receiver<R>
156    where
157        Fut: Future<Output = R> + MaybeSend + 'static,
158        R: MaybeSend + 'static,
159    {
160        self.spawn_inner(name, f, true)
161    }
162
163    fn spawn_inner<Fut, R>(
164        &self,
165        name: impl Into<String>,
166        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
167        quiet: bool,
168    ) -> oneshot::Receiver<R>
169    where
170        Fut: Future<Output = R> + MaybeSend + 'static,
171        R: MaybeSend + 'static,
172    {
173        let name = name.into();
174        let mut guard = TaskPanicGuard {
175            name: name.clone(),
176            inner: self.inner.clone(),
177            completed: false,
178        };
179        let handle = self.make_handle();
180
181        let (tx, rx) = oneshot::channel();
182        self.inner
183            .active_tasks_join_handles
184            .lock()
185            .expect("Locking failed")
186            .insert_with_key(move |task_key| {
187                (
188                    name.clone(),
189                    crate::runtime::spawn(&name, {
190                        let name = name.clone();
191                        async move {
192                            defer! {
193                                // Panic or normal completion, it means the task
194                                // is complete, and does not need to be shutdown
195                                // via join handle. This prevents buildup of task
196                                // handles.
197                                if handle
198                                    .inner
199                                    .active_tasks_join_handles
200                                    .lock()
201                                    .expect("Locking failed")
202                                    .remove(task_key)
203                                    .is_none() {
204                                        trace!(target: LOG_TASK, %name, "Task already canceled");
205                                    }
206                            }
207                            // Unfortunately log levels need to be static
208                            if quiet {
209                                trace!(target: LOG_TASK, %name, "Starting task");
210                            } else {
211                                debug!(target: LOG_TASK, %name, "Starting task");
212                            }
213                            let r = f(handle.clone()).await;
214                            guard.completed = true;
215
216                            if quiet {
217                                trace!(target: LOG_TASK, %name, "Finished task");
218                            } else {
219                                debug!(target: LOG_TASK, %name, "Finished task");
220                            }
221                            // if receiver is not interested, just drop the message
222                            let _ = tx.send(r);
223
224                            // NOTE: Since this is a `async move` the guard will not get moved
225                            // if it's not moved inside the body. Weird.
226                            drop(guard);
227                        }
228                    }),
229                )
230            });
231
232        rx
233    }
234
235    /// Spawn a task that will get cancelled automatically on `TaskGroup`
236    /// shutdown.
237    pub fn spawn_cancellable<R>(
238        &self,
239        name: impl Into<String>,
240        future: impl Future<Output = R> + MaybeSend + 'static,
241    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
242    where
243        R: MaybeSend + 'static,
244    {
245        self.spawn(name, |handle| async move {
246            let value = handle.cancel_on_shutdown(future).await;
247            if value.is_err() {
248                // name will part of span
249                debug!(target: LOG_TASK, "task cancelled on shutdown");
250            }
251            value
252        })
253    }
254
255    pub fn spawn_cancellable_silent<R>(
256        &self,
257        name: impl Into<String>,
258        future: impl Future<Output = R> + MaybeSend + 'static,
259    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
260    where
261        R: MaybeSend + 'static,
262    {
263        self.spawn_silent(name, |handle| async move {
264            let value = handle.cancel_on_shutdown(future).await;
265            if value.is_err() {
266                // name will part of span
267                debug!(target: LOG_TASK, "task cancelled on shutdown");
268            }
269            value
270        })
271    }
272
273    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
274        let deadline = timeout.map(|timeout| now() + timeout);
275        let mut errors = vec![];
276
277        self.join_all_inner(deadline, &mut errors).await;
278
279        if errors.is_empty() {
280            Ok(())
281        } else {
282            let num_errors = errors.len();
283            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
284        }
285    }
286
287    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
288    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
289    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
290        self.inner.join_all(deadline, errors).await;
291    }
292}
293
/// Guard held by a spawned task for its whole lifetime.
///
/// If it is dropped before `completed` was set, the task group is told to
/// shut down (see the `Drop` impl).
struct TaskPanicGuard {
    /// Name of the guarded task, used in log messages
    name: String,
    /// Inner state of the task group the task belongs to
    inner: Arc<TaskGroupInner>,
    /// Did the future complete successfully (no panic)?
    completed: bool,
}
300
301impl Drop for TaskPanicGuard {
302    fn drop(&mut self) {
303        trace!(
304            target: LOG_TASK,
305            name = %self.name,
306            "Task drop"
307        );
308        if !self.completed {
309            info!(
310                target: LOG_TASK,
311                name = %self.name,
312                "Task shut down uncleanly"
313            );
314            self.inner.shutdown();
315        }
316    }
317}
318
/// Handle to a [`TaskGroup`] that is passed to the tasks spawned in it.
///
/// It shares the group's inner state and lets a task observe the shutdown
/// status of the group.
#[derive(Clone, Debug)]
pub struct TaskHandle {
    inner: Arc<TaskGroupInner>,
}
323
/// Error returned by [`TaskHandle::cancel_on_shutdown`] when the task group
/// started shutting down before the future completed.
#[derive(thiserror::Error, Debug, Clone)]
#[error("Task group is shutting down")]
#[non_exhaustive]
pub struct ShuttingDownError {}
328
329impl TaskHandle {
330    /// Is task group shutting down?
331    ///
332    /// Every task in a task group should detect and stop if `true`.
333    pub fn is_shutting_down(&self) -> bool {
334        self.inner.is_shutting_down()
335    }
336
337    /// Make a [`oneshot::Receiver`] that will fire on shutdown
338    ///
339    /// Tasks can use `select` on the return value to handle shutdown
340    /// signal during otherwise blocking operation.
341    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
342        self.inner.make_shutdown_rx()
343    }
344
345    /// Run the future or cancel it if the [`TaskGroup`] shuts down.
346    pub async fn cancel_on_shutdown<F: Future>(
347        &self,
348        fut: F,
349    ) -> Result<F::Output, ShuttingDownError> {
350        let rx = self.make_shutdown_rx();
351        match future::select(pin!(rx), pin!(fut)).await {
352            Either::Left(((), _)) => Err(ShuttingDownError {}),
353            Either::Right((value, _)) => Ok(value),
354        }
355    }
356}
357
/// A future that resolves once the shutdown signal it was created for fires.
///
/// Wraps a boxed, pinned future so the concrete future type stays hidden.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
359
360impl TaskShutdownToken {
361    fn new(mut rx: watch::Receiver<bool>) -> Self {
362        Self(Box::pin(async move {
363            let _ = rx.wait_for(|v| *v).await;
364        }))
365    }
366}
367
368impl Future for TaskShutdownToken {
369    type Output = ();
370
371    fn poll(
372        mut self: Pin<&mut Self>,
373        cx: &mut std::task::Context<'_>,
374    ) -> std::task::Poll<Self::Output> {
375        self.0.as_mut().poll(cx)
376    }
377}
378
/// `async_trait` that requires `Send` futures only on non-wasm targets
///
/// Applies `::async_trait::async_trait` to the wrapped item on non-wasm
/// targets, and `::async_trait::async_trait(?Send)` on wasm.
///
/// # Example
///
/// ```rust
/// use fedimint_core::{apply, async_trait_maybe_send};
/// #[apply(async_trait_maybe_send!)]
/// trait Foo {
///     // methods
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl Foo for () {
///     // methods
/// }
/// ```
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($item)*
    };
}
403
/// Adds a `Send` bound to a type on non-wasm targets
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`, so this macro
/// appends `+ Send` to the given tokens instead.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::{apply, maybe_add_send};
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)* + Send
    };
}
421
/// Leaves the type unchanged on wasm targets
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`; on wasm no `Send`
/// bound is wanted, so the given tokens are emitted as they are.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::maybe_add_send;
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
436
/// Like `maybe_add_send`, but appends `+ Send + Sync` on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)* + Send + Sync
    };
}
445
/// Like `maybe_add_send`: on wasm targets the tokens are emitted unchanged.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
454
/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm targets.
///
/// On wasm, most types don't implement `Send` because JS types can not be
/// sent between workers directly.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm targets.
///
/// On wasm, most types don't implement `Send` because JS types can not be
/// sent between workers directly.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

// Blanket impl: on non-wasm targets every (sized) `Send` type is `MaybeSend`.
#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}

// Blanket impl: on wasm every (sized) type is `MaybeSend`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
474
/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm targets.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

// Blanket impl: on non-wasm targets every (sized) `Sync` type is `MaybeSync`.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}

// Blanket impl: on wasm every (sized) type is `MaybeSync`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
488
489// Used in tests when sleep functionality is desired so it can be logged.
490// Must include comment describing the reason for sleeping.
491pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
492    info!(
493        target: LOG_TEST,
494        "Sleeping for {}.{:03} seconds because: {}",
495        duration.as_secs(),
496        duration.subsec_millis(),
497        comment.as_ref()
498    );
499    sleep(duration).await;
500}
501
/// An error used as a "cancelled" marker in [`Cancellable`].
///
/// It carries no data; its display message is "Operation cancelled".
#[derive(Error, Debug)]
#[error("Operation cancelled")]
pub struct Cancelled;
506
/// Result of an operation that can potentially get cancelled, returning no
/// result (e.g. on program shutdown).
///
/// `Err(Cancelled)` marks the cancelled case.
pub type Cancellable<T> = std::result::Result<T, Cancelled>;
510
511#[cfg(test)]
512mod tests;