#![cfg_attr(target_family = "wasm", allow(dead_code))]
mod inner;
pub mod jit;
pub mod waiter;
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use anyhow::bail;
use fedimint_core::time::now;
use fedimint_logging::{LOG_TASK, LOG_TEST};
use futures::future::{self, Either};
use inner::TaskGroupInner;
use thiserror::Error;
use tokio::sync::{oneshot, watch};
use tracing::{debug, error, info};
use crate::runtime;
pub use crate::runtime::*;
/// A group of spawned tasks that can be shut down and joined together.
///
/// Cloning a `TaskGroup` is cheap: every clone shares the same inner state
/// through an [`Arc`].
#[derive(Debug, Default, Clone)]
pub struct TaskGroup {
    inner: Arc<TaskGroupInner>,
}
impl TaskGroup {
    /// Creates an empty task group.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a [`TaskHandle`] that shares this group's inner state.
    ///
    /// Tasks use the handle to check for, or wait on, shutdown of the group.
    pub fn make_handle(&self) -> TaskHandle {
        TaskHandle {
            inner: self.inner.clone(),
        }
    }

    /// Creates a new task group and registers it as a subgroup of `self`.
    ///
    /// Presumably shutdown and joining of the parent propagate to the subgroup
    /// (the `shutdown_task_subgroup_*` tests rely on this). The exact behavior
    /// lives in `TaskGroupInner::add_subgroup`.
    pub fn make_subgroup(&self) -> Self {
        let new_tg = Self::new();
        self.inner.add_subgroup(new_tg.clone());
        new_tg
    }

    /// Signals shutdown to the group without waiting for its tasks to finish.
    pub fn shutdown(&self) {
        self.inner.shutdown();
    }

    /// Signals shutdown, then waits for the group's tasks via [`Self::join_all`].
    ///
    /// `join_timeout` accepts a [`Duration`] or `None`. With `None`, no join
    /// deadline is computed.
    ///
    /// # Errors
    ///
    /// Returns an error if any task did not finish cleanly (see [`Self::join_all`]).
    pub async fn shutdown_join_all(
        self,
        join_timeout: impl Into<Option<Duration>>,
    ) -> Result<(), anyhow::Error> {
        self.shutdown();
        self.join_all(join_timeout.into()).await
    }

    /// Spawns a detached task that shuts the group down on a termination signal.
    ///
    /// The task waits for Ctrl+C, and on Unix also for `SIGTERM`.
    ///
    /// # Panics
    ///
    /// The spawned task panics (via `expect`) if a signal handler cannot be
    /// installed.
    #[cfg(not(target_family = "wasm"))]
    pub fn install_kill_handler(&self) {
        /// Resolves on whichever termination signal arrives first.
        async fn wait_for_shutdown_signal() {
            use tokio::signal;

            let ctrl_c = async {
                signal::ctrl_c()
                    .await
                    .expect("failed to install Ctrl+C handler");
            };

            #[cfg(unix)]
            let terminate = async {
                signal::unix::signal(signal::unix::SignalKind::terminate())
                    .expect("failed to install signal handler")
                    .recv()
                    .await;
            };

            // Non-Unix targets have no SIGTERM: this branch never completes.
            #[cfg(not(unix))]
            let terminate = std::future::pending::<()>();

            tokio::select! {
                () = ctrl_c => {},
                () = terminate => {},
            }
        }

        runtime::spawn("kill handlers", {
            let task_group = self.clone();
            async move {
                wait_for_shutdown_signal().await;
                info!(
                    target: LOG_TASK,
                    "signal received, starting graceful shutdown"
                );
                task_group.shutdown();
            }
        });
    }

    /// Spawns `f` as a named task belonging to this group.
    ///
    /// `f` receives a [`TaskHandle`] for observing shutdown. The returned
    /// receiver yields the task's result. If the receiver has been dropped,
    /// the result is silently discarded.
    ///
    /// If anything between creating the guard and registering the join handle
    /// unwinds, the `TaskPanicGuard` drop shuts down the whole group.
    pub fn spawn<Fut, R>(
        &self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + MaybeSend + 'static,
        R: MaybeSend + 'static,
    {
        let name = name.into();
        let mut guard = TaskPanicGuard {
            name: name.clone(),
            inner: self.inner.clone(),
            completed: false,
        };
        let task_handle = self.make_handle();
        let (tx, rx) = oneshot::channel();
        let join_handle = crate::runtime::spawn(&name, {
            let name = name.clone();
            async move {
                debug!(target: LOG_TASK, "Starting task {name}");
                let r = f(task_handle).await;
                debug!(target: LOG_TASK, "Finished task {name}");
                // The caller may have dropped the receiver; discarding is fine.
                let _ = tx.send(r);
            }
        });
        self.inner.add_join_handle(name, join_handle);
        // Registration succeeded: disarm the guard so its drop does nothing.
        guard.completed = true;
        rx
    }

    /// Spawns `future` as a task that is cancelled when the group shuts down.
    ///
    /// The receiver yields `Ok(output)` if the future finishes first, or
    /// `Err(ShuttingDownError)` if shutdown wins.
    pub fn spawn_cancellable<R>(
        &self,
        name: impl Into<String>,
        future: impl Future<Output = R> + MaybeSend + 'static,
    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
    where
        R: MaybeSend + 'static,
    {
        let name = name.into();
        self.spawn(name.clone(), move |handle| async move {
            let value = handle.cancel_on_shutdown(future).await;
            if value.is_err() {
                debug!(target: LOG_TASK, "Task {name} cancelled on shutdown");
            }
            value
        })
    }

    /// Waits for every task in the group, optionally bounded by `timeout`.
    ///
    /// The deadline is computed once, up front, as `now() + timeout`.
    ///
    /// # Errors
    ///
    /// Returns an error listing every collected `JoinError` if any task did
    /// not finish cleanly.
    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
        let deadline = timeout.map(|timeout| now() + timeout);
        let mut errors = vec![];
        self.join_all_inner(deadline, &mut errors).await;
        if errors.is_empty() {
            Ok(())
        } else {
            let num_errors = errors.len();
            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
        }
    }

    /// Joins this group's tasks, appending any failures to `errors`.
    ///
    /// `async_recursion` boxes the future. Presumably the inner join calls back
    /// into this method for subgroups. The `?Send` variant is used on wasm.
    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
        self.inner.join_all(deadline, errors).await;
    }
}
/// Drop guard used by [`TaskGroup::spawn`] while a task is being registered.
///
/// `spawn` sets `completed` right after adding the task's join handle.
/// The guard therefore only fires if something unwinds during spawning
/// itself, not when the spawned task's body panics later.
struct TaskPanicGuard {
    /// Task name, used in the log message.
    name: String,
    /// Group to shut down if the guard is dropped while still armed.
    inner: Arc<TaskGroupInner>,
    /// `true` once spawning finished; disarms the guard.
    completed: bool,
}

impl Drop for TaskPanicGuard {
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        info!(
            target: LOG_TASK,
            "Task {} shut down uncleanly. Shutting down task group.", self.name
        );
        self.inner.shutdown();
    }
}
/// Cloneable handle given to tasks of a [`TaskGroup`].
///
/// Tasks use it to observe the group's shutdown.
#[derive(Debug, Clone)]
pub struct TaskHandle {
    inner: Arc<TaskGroupInner>,
}
#[derive(thiserror::Error, Debug, Clone)]
#[error("Task group is shutting down")]
#[non_exhaustive]
pub struct ShuttingDownError {}
impl TaskHandle {
    /// Returns `true` once the owning task group has started shutting down.
    pub fn is_shutting_down(&self) -> bool {
        self.inner.is_shutting_down()
    }

    /// Returns a future that resolves when the owning task group shuts down.
    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
        self.inner.make_shutdown_rx()
    }

    /// Drives `fut` until it completes or the group shuts down, whichever comes first.
    ///
    /// NOTE: the shutdown token is the first argument to `future::select`.
    /// As implemented in futures 0.3, `select` polls that argument first, so
    /// shutdown wins a tie.
    ///
    /// # Errors
    ///
    /// Returns [`ShuttingDownError`] if shutdown is observed before `fut` finishes.
    pub async fn cancel_on_shutdown<F: Future>(
        &self,
        fut: F,
    ) -> Result<F::Output, ShuttingDownError> {
        let shutdown = pin!(self.make_shutdown_rx());
        let work = pin!(fut);
        match future::select(shutdown, work).await {
            Either::Left(_) => Err(ShuttingDownError {}),
            Either::Right((output, _)) => Ok(output),
        }
    }
}
/// Future that completes once the owning task group begins shutting down.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);

impl TaskShutdownToken {
    /// Builds a token that resolves when the watched flag becomes `true`.
    fn new(mut rx: watch::Receiver<bool>) -> Self {
        let wait = async move {
            // An `Err` means the sender was dropped.
            // The token then resolves too, exactly as on shutdown.
            let _ = rx.wait_for(|&flag| flag).await;
        };
        Self(Box::pin(wait))
    }
}

impl Future for TaskShutdownToken {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // The token is `Unpin` (it only holds a `Pin<Box<_>>`), so projecting
        // through `get_mut` is safe.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll(cx)
    }
}
/// Applies `#[async_trait]` to the wrapped item.
///
/// - Non-wasm targets: expands to `#[::async_trait::async_trait]`.
/// - wasm: expands to `#[::async_trait::async_trait(?Send)]`, so the
///   generated futures are not required to be `Send`.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($item)*
    };
}
/// Appends `+ Send` to the given type or bounds (non-wasm targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)* + Send
    };
}

/// Passes the given type or bounds through unchanged (wasm targets).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
/// Appends `+ Send + Sync` to the given type or bounds (non-wasm targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)* + Send + Sync
    };
}

/// Passes the given type or bounds through unchanged (wasm targets).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
/// Marker trait that requires [`Send`] on native targets.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend
where
    Self: Send,
{
}
#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}

/// Marker trait implemented for every type on wasm targets.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}

/// Marker trait that requires [`Sync`] on native targets.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync
where
    Self: Sync,
{
}
#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}

/// Marker trait implemented for every type on wasm targets.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
/// Sleeps for `duration`, first logging (under `LOG_TEST`) why the test waits.
///
/// The duration is logged as seconds with millisecond precision.
pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
    let reason = comment.as_ref();
    let (secs, millis) = (duration.as_secs(), duration.subsec_millis());
    info!(
        target: LOG_TEST,
        "Sleeping for {}.{:03} seconds because: {}", secs, millis, reason
    );
    sleep(duration).await;
}
#[derive(Error, Debug)]
#[error("Operation cancelled")]
pub struct Cancelled;
pub type Cancellable<T> = std::result::Result<T, Cancelled>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Delay that orders the shutdown call relative to the waiter creating its token.
    const ORDERING_DELAY: Duration = Duration::from_millis(10);

    /// Shutdown is requested after the waiter has (very likely) started waiting.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let group = TaskGroup::new();
        group.spawn("shutdown waiter", |task| async move {
            task.make_shutdown_rx().await;
        });
        sleep(ORDERING_DELAY).await;
        group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// Shutdown is requested before the waiter creates its shutdown token.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let group = TaskGroup::new();
        group.spawn("shutdown waiter", |task| async move {
            sleep(ORDERING_DELAY).await;
            task.make_shutdown_rx().await;
        });
        group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// As [`shutdown_task_group_after`], but the waiter lives in a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let group = TaskGroup::new();
        group
            .make_subgroup()
            .spawn("shutdown waiter", |task| async move {
                task.make_shutdown_rx().await;
            });
        sleep(ORDERING_DELAY).await;
        group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// As [`shutdown_task_group_before`], but the waiter lives in a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let group = TaskGroup::new();
        group
            .make_subgroup()
            .spawn("shutdown waiter", |task| async move {
                sleep(ORDERING_DELAY).await;
                task.make_shutdown_rx().await;
            });
        group.shutdown_join_all(None).await?;
        Ok(())
    }
}