1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{debug, error, info, trace};
23
24use crate::runtime;
25pub use crate::runtime::*;
/// A group of spawned tasks sharing one `TaskGroupInner`.
///
/// The only field is an `Arc`, so cloning a `TaskGroup` yields another
/// reference to the same inner state rather than an independent group.
#[derive(Clone, Default, Debug)]
pub struct TaskGroup {
    /// State shared by every clone of this group and by the [`TaskHandle`]s
    /// created through [`TaskGroup::make_handle`].
    inner: Arc<TaskGroupInner>,
}
42
43impl TaskGroup {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn make_handle(&self) -> TaskHandle {
49 TaskHandle {
50 inner: self.inner.clone(),
51 }
52 }
53
54 pub fn make_subgroup(&self) -> Self {
67 let new_tg = Self::new();
68 self.inner.add_subgroup(new_tg.clone());
69 new_tg
70 }
71
    /// Requests shutdown of this group by delegating to
    /// `TaskGroupInner::shutdown`.
    ///
    /// This call is synchronous and does not wait for tasks; use
    /// [`Self::join_all`] or [`Self::shutdown_join_all`] to wait.
    pub fn shutdown(&self) {
        self.inner.shutdown();
    }
77
78 pub async fn shutdown_join_all(
80 self,
81 join_timeout: impl Into<Option<Duration>>,
82 ) -> Result<(), anyhow::Error> {
83 self.shutdown();
84 self.join_all(join_timeout.into()).await
85 }
86
87 #[cfg(not(target_family = "wasm"))]
90 pub fn install_kill_handler(&self) {
91 async fn wait_for_shutdown_signal() {
93 use tokio::signal;
94
95 let ctrl_c = async {
96 signal::ctrl_c()
97 .await
98 .expect("failed to install Ctrl+C handler");
99 };
100
101 #[cfg(unix)]
102 let terminate = async {
103 signal::unix::signal(signal::unix::SignalKind::terminate())
104 .expect("failed to install signal handler")
105 .recv()
106 .await;
107 };
108
109 #[cfg(not(unix))]
110 let terminate = std::future::pending::<()>();
111
112 tokio::select! {
113 () = ctrl_c => {},
114 () = terminate => {},
115 }
116 }
117
118 runtime::spawn("kill handlers", {
119 let task_group = self.clone();
120 async move {
121 wait_for_shutdown_signal().await;
122 info!(
123 target: LOG_TASK,
124 "signal received, starting graceful shutdown"
125 );
126 task_group.shutdown();
127 }
128 });
129 }
130
131 pub fn spawn<Fut, R>(
132 &self,
133 name: impl Into<String>,
134 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
135 ) -> oneshot::Receiver<R>
136 where
137 Fut: Future<Output = R> + MaybeSend + 'static,
138 R: MaybeSend + 'static,
139 {
140 self.spawn_inner(name, f, false)
141 }
142
143 pub fn spawn_silent<Fut, R>(
147 &self,
148 name: impl Into<String>,
149 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
150 ) -> oneshot::Receiver<R>
151 where
152 Fut: Future<Output = R> + MaybeSend + 'static,
153 R: MaybeSend + 'static,
154 {
155 self.spawn_inner(name, f, true)
156 }
157
    /// Shared implementation behind [`Self::spawn`] and
    /// [`Self::spawn_silent`].
    ///
    /// Spawns the future produced by `f` on the runtime, records it in
    /// `active_tasks_join_handles`, and returns a receiver that is sent the
    /// future's output when it finishes.
    ///
    /// * `name` - label used for the runtime task and in log lines.
    /// * `f` - called with a [`TaskHandle`] for this group to build the
    ///   future.
    /// * `quiet` - when `true`, start/finish are logged at `trace` instead
    ///   of `debug`.
    ///
    /// # Panics
    ///
    /// Panics if the `active_tasks_join_handles` mutex is poisoned.
    fn spawn_inner<Fut, R>(
        &self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
        quiet: bool,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + MaybeSend + 'static,
        R: MaybeSend + 'static,
    {
        let name = name.into();
        // Moved into the task below. If it is dropped while `completed` is
        // still `false`, its `Drop` impl shuts the whole group down.
        let mut guard = TaskPanicGuard {
            name: name.clone(),
            inner: self.inner.clone(),
            completed: false,
        };
        let handle = self.make_handle();

        let (tx, rx) = oneshot::channel();
        // The registry lock is held for the whole `insert_with_key` call, so
        // the task is spawned while the lock is held and needs the key it is
        // about to be stored under (`task_key`) to remove itself later.
        self.inner
            .active_tasks_join_handles
            .lock()
            .expect("Locking failed")
            .insert_with_key(move |task_key| {
                (
                    name.clone(),
                    crate::runtime::spawn(&name, {
                        let name = name.clone();
                        async move {
                            // Runs when this async block's scope is left:
                            // remove this task's registry entry. A missing
                            // entry means something else already removed it.
                            defer! {
                                if handle
                                    .inner
                                    .active_tasks_join_handles
                                    .lock()
                                    .expect("Locking failed")
                                    .remove(task_key)
                                    .is_none() {
                                    trace!(target: LOG_TASK, %name, "Task already canceled");
                                }
                            }
                            if quiet {
                                trace!(target: LOG_TASK, %name, "Starting task");
                            } else {
                                debug!(target: LOG_TASK, %name, "Starting task");
                            }
                            let r = f(handle.clone()).await;
                            // Only reached when the future ran to completion;
                            // disarms the guard's group-wide shutdown.
                            guard.completed = true;

                            if quiet {
                                trace!(target: LOG_TASK, %name, "Finished task");
                            } else {
                                debug!(target: LOG_TASK, %name, "Finished task");
                            }
                            // A send error is deliberately ignored: the
                            // result is simply discarded in that case.
                            let _ = tx.send(r);

                            // Explicit drop: emits the guard's "Task drop"
                            // trace here, after the result has been sent.
                            drop(guard);
                        }
                    }),
                )
            });

        rx
    }
229
230 pub fn spawn_cancellable<R>(
233 &self,
234 name: impl Into<String>,
235 future: impl Future<Output = R> + MaybeSend + 'static,
236 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
237 where
238 R: MaybeSend + 'static,
239 {
240 self.spawn(name, |handle| async move {
241 let value = handle.cancel_on_shutdown(future).await;
242 if value.is_err() {
243 debug!(target: LOG_TASK, "task cancelled on shutdown");
245 }
246 value
247 })
248 }
249
250 pub fn spawn_cancellable_silent<R>(
251 &self,
252 name: impl Into<String>,
253 future: impl Future<Output = R> + MaybeSend + 'static,
254 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
255 where
256 R: MaybeSend + 'static,
257 {
258 self.spawn_silent(name, |handle| async move {
259 let value = handle.cancel_on_shutdown(future).await;
260 if value.is_err() {
261 debug!(target: LOG_TASK, "task cancelled on shutdown");
263 }
264 value
265 })
266 }
267
268 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
269 let deadline = timeout.map(|timeout| now() + timeout);
270 let mut errors = vec![];
271
272 self.join_all_inner(deadline, &mut errors).await;
273
274 if errors.is_empty() {
275 Ok(())
276 } else {
277 let num_errors = errors.len();
278 bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
279 }
280 }
281
    /// Joins this group by delegating to `TaskGroupInner::join_all`.
    ///
    /// * `deadline` - optional absolute point in time, forwarded unchanged.
    /// * `errors` - output vector forwarded to the inner join;
    ///   [`Self::join_all`] treats a non-empty vector as failure.
    ///
    /// The function is wrapped with `async_recursion` (the `?Send` flavour on
    /// wasm).
    // NOTE(review): presumably `TaskGroupInner::join_all` calls back into
    // this method for subgroups, which is why the recursion attribute is
    // needed even though this body does not recurse directly — confirm.
    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
        self.inner.join_all(deadline, errors).await;
    }
287}
288
/// Drop guard owned by a spawned task.
///
/// If it is dropped while `completed` is still `false`, its `Drop` impl
/// shuts down the task group it belongs to.
struct TaskPanicGuard {
    /// Task name, used only in log output.
    name: String,
    /// Inner state of the group the task was spawned in.
    inner: Arc<TaskGroupInner>,
    /// Set to `true` by the task once its future has returned.
    completed: bool,
}
295
296impl Drop for TaskPanicGuard {
297 fn drop(&mut self) {
298 trace!(
299 target: LOG_TASK,
300 name = %self.name,
301 "Task drop"
302 );
303 if !self.completed {
304 info!(
305 target: LOG_TASK,
306 name = %self.name,
307 "Task shut down uncleanly"
308 );
309 self.inner.shutdown();
310 }
311 }
312}
313
/// Handle passed to a spawned task, giving it access to the shutdown state
/// of the group it runs in.
///
/// Created by `TaskGroup::make_handle`; shares the group's inner state
/// through an `Arc`.
#[derive(Clone, Debug)]
pub struct TaskHandle {
    /// Inner state shared with the owning `TaskGroup`.
    inner: Arc<TaskGroupInner>,
}
318
/// Error returned by `TaskHandle::cancel_on_shutdown` when the shutdown
/// signal resolved before the wrapped future did.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot construct
/// it with a struct literal.
#[derive(thiserror::Error, Debug, Clone)]
#[error("Task group is shutting down")]
#[non_exhaustive]
pub struct ShuttingDownError {}
323
324impl TaskHandle {
325 pub fn is_shutting_down(&self) -> bool {
329 self.inner.is_shutting_down()
330 }
331
332 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
337 self.inner.make_shutdown_rx()
338 }
339
340 pub async fn cancel_on_shutdown<F: Future>(
342 &self,
343 fut: F,
344 ) -> Result<F::Output, ShuttingDownError> {
345 let rx = self.make_shutdown_rx();
346 match future::select(pin!(rx), pin!(fut)).await {
347 Either::Left(((), _)) => Err(ShuttingDownError {}),
348 Either::Right((value, _)) => Ok(value),
349 }
350 }
351}
352
/// A boxed, pinned `Send` future with output `()`.
///
/// `TaskShutdownToken::new` builds it from a `watch::Receiver<bool>`, and the
/// `Future` impl simply polls the wrapped future.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
354
355impl TaskShutdownToken {
356 fn new(mut rx: watch::Receiver<bool>) -> Self {
357 Self(Box::pin(async move {
358 let _ = rx.wait_for(|v| *v).await;
359 }))
360 }
361}
362
363impl Future for TaskShutdownToken {
364 type Output = ();
365
366 fn poll(
367 mut self: Pin<&mut Self>,
368 cx: &mut std::task::Context<'_>,
369 ) -> std::task::Poll<Self::Output> {
370 self.0.as_mut().poll(cx)
371 }
372}
373
/// Applies `#[async_trait]` to the wrapped item, using the `?Send` flavour
/// when compiling for a wasm target and the default flavour otherwise.
///
/// The tokens passed to the macro are emitted unchanged below the
/// attributes.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($item)*
    };
}
398
/// Appends `+ Send` to the given tokens (non-wasm targets).
///
/// Intended for type positions, e.g.
/// `Box<maybe_add_send!(dyn Trait)>` becomes `Box<dyn Trait + Send>`.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)* + Send
    };
}

/// Emits the given tokens unchanged (wasm targets): no `Send` bound is
/// added.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
431
/// Appends `+ Send + Sync` to the given tokens (non-wasm targets).
///
/// Intended for type positions, e.g.
/// `Box<maybe_add_send_sync!(dyn Trait)>` becomes
/// `Box<dyn Trait + Send + Sync>`.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)* + Send + Sync
    };
}

/// Emits the given tokens unchanged (wasm targets): neither `Send` nor
/// `Sync` is added.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
449
/// Marker trait that requires `Send` on non-wasm targets.
///
/// Blanket-implemented for every `T: Send`, so it can be used as a bound in
/// place of `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}

/// Marker trait with no requirements on wasm targets.
///
/// Blanket-implemented for every `T`, so bounds written as `MaybeSend` are
/// always satisfied there.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
469
/// Marker trait that requires `Sync` on non-wasm targets.
///
/// Blanket-implemented for every `T: Sync`, so it can be used as a bound in
/// place of `Sync`.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}

/// Marker trait with no requirements on wasm targets.
///
/// Blanket-implemented for every `T`, so bounds written as `MaybeSync` are
/// always satisfied there.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
483
484pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
487 info!(
488 target: LOG_TEST,
489 "Sleeping for {}.{:03} seconds because: {}",
490 duration.as_secs(),
491 duration.subsec_millis(),
492 comment.as_ref()
493 );
494 sleep(duration).await;
495}
496
/// Unit error type signalling that an operation was cancelled.
///
/// Its `Display` output is `"Operation cancelled"`.
#[derive(Error, Debug)]
#[error("Operation cancelled")]
pub struct Cancelled;
501
/// Result alias for operations that either produce a `T` or fail with
/// [`Cancelled`].
pub type Cancellable<T> = std::result::Result<T, Cancelled>;
505
506#[cfg(test)]
507mod tests;