1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{debug, error, info, trace};
23
24use crate::runtime;
25pub use crate::runtime::*;
28#[derive(Clone, Default, Debug)]
39pub struct TaskGroup {
40 inner: Arc<TaskGroupInner>,
41}
42
43impl TaskGroup {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn make_handle(&self) -> TaskHandle {
49 TaskHandle {
50 inner: self.inner.clone(),
51 }
52 }
53
54 pub fn make_subgroup(&self) -> Self {
67 let new_tg = Self::new();
68 self.inner.add_subgroup(new_tg.clone());
69 new_tg
70 }
71
72 pub fn is_shutting_down(&self) -> bool {
74 self.inner.is_shutting_down()
75 }
76
77 pub fn shutdown(&self) {
80 self.inner.shutdown();
81 }
82
83 pub async fn shutdown_join_all(
85 self,
86 join_timeout: impl Into<Option<Duration>>,
87 ) -> Result<(), anyhow::Error> {
88 self.shutdown();
89 self.join_all(join_timeout.into()).await
90 }
91
92 #[cfg(not(target_family = "wasm"))]
95 pub fn install_kill_handler(&self) {
96 async fn wait_for_shutdown_signal() {
98 use tokio::signal;
99
100 let ctrl_c = async {
101 signal::ctrl_c()
102 .await
103 .expect("failed to install Ctrl+C handler");
104 };
105
106 #[cfg(unix)]
107 let terminate = async {
108 signal::unix::signal(signal::unix::SignalKind::terminate())
109 .expect("failed to install signal handler")
110 .recv()
111 .await;
112 };
113
114 #[cfg(not(unix))]
115 let terminate = std::future::pending::<()>();
116
117 tokio::select! {
118 () = ctrl_c => {},
119 () = terminate => {},
120 }
121 }
122
123 runtime::spawn("kill handlers", {
124 let task_group = self.clone();
125 async move {
126 wait_for_shutdown_signal().await;
127 info!(
128 target: LOG_TASK,
129 "signal received, starting graceful shutdown"
130 );
131 task_group.shutdown();
132 }
133 });
134 }
135
136 pub fn spawn<Fut, R>(
137 &self,
138 name: impl Into<String>,
139 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
140 ) -> oneshot::Receiver<R>
141 where
142 Fut: Future<Output = R> + MaybeSend + 'static,
143 R: MaybeSend + 'static,
144 {
145 self.spawn_inner(name, f, false)
146 }
147
148 pub fn spawn_silent<Fut, R>(
152 &self,
153 name: impl Into<String>,
154 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
155 ) -> oneshot::Receiver<R>
156 where
157 Fut: Future<Output = R> + MaybeSend + 'static,
158 R: MaybeSend + 'static,
159 {
160 self.spawn_inner(name, f, true)
161 }
162
163 fn spawn_inner<Fut, R>(
164 &self,
165 name: impl Into<String>,
166 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
167 quiet: bool,
168 ) -> oneshot::Receiver<R>
169 where
170 Fut: Future<Output = R> + MaybeSend + 'static,
171 R: MaybeSend + 'static,
172 {
173 let name = name.into();
174 let mut guard = TaskPanicGuard {
175 name: name.clone(),
176 inner: self.inner.clone(),
177 completed: false,
178 };
179 let handle = self.make_handle();
180
181 let (tx, rx) = oneshot::channel();
182 self.inner
183 .active_tasks_join_handles
184 .lock()
185 .expect("Locking failed")
186 .insert_with_key(move |task_key| {
187 (
188 name.clone(),
189 crate::runtime::spawn(&name, {
190 let name = name.clone();
191 async move {
192 defer! {
193 if handle
198 .inner
199 .active_tasks_join_handles
200 .lock()
201 .expect("Locking failed")
202 .remove(task_key)
203 .is_none() {
204 trace!(target: LOG_TASK, %name, "Task already canceled");
205 }
206 }
207 if quiet {
209 trace!(target: LOG_TASK, %name, "Starting task");
210 } else {
211 debug!(target: LOG_TASK, %name, "Starting task");
212 }
213 let r = f(handle.clone()).await;
214 guard.completed = true;
215
216 if quiet {
217 trace!(target: LOG_TASK, %name, "Finished task");
218 } else {
219 debug!(target: LOG_TASK, %name, "Finished task");
220 }
221 let _ = tx.send(r);
223
224 drop(guard);
227 }
228 }),
229 )
230 });
231
232 rx
233 }
234
235 pub fn spawn_cancellable<R>(
238 &self,
239 name: impl Into<String>,
240 future: impl Future<Output = R> + MaybeSend + 'static,
241 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
242 where
243 R: MaybeSend + 'static,
244 {
245 self.spawn(name, |handle| async move {
246 let value = handle.cancel_on_shutdown(future).await;
247 if value.is_err() {
248 debug!(target: LOG_TASK, "task cancelled on shutdown");
250 }
251 value
252 })
253 }
254
255 pub fn spawn_cancellable_silent<R>(
256 &self,
257 name: impl Into<String>,
258 future: impl Future<Output = R> + MaybeSend + 'static,
259 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
260 where
261 R: MaybeSend + 'static,
262 {
263 self.spawn_silent(name, |handle| async move {
264 let value = handle.cancel_on_shutdown(future).await;
265 if value.is_err() {
266 debug!(target: LOG_TASK, "task cancelled on shutdown");
268 }
269 value
270 })
271 }
272
273 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
274 let deadline = timeout.map(|timeout| now() + timeout);
275 let mut errors = vec![];
276
277 self.join_all_inner(deadline, &mut errors).await;
278
279 if errors.is_empty() {
280 Ok(())
281 } else {
282 let num_errors = errors.len();
283 bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
284 }
285 }
286
287 #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
288 #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
289 pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
290 self.inner.join_all(deadline, errors).await;
291 }
292}
293
294struct TaskPanicGuard {
295 name: String,
296 inner: Arc<TaskGroupInner>,
297 completed: bool,
299}
300
301impl Drop for TaskPanicGuard {
302 fn drop(&mut self) {
303 trace!(
304 target: LOG_TASK,
305 name = %self.name,
306 "Task drop"
307 );
308 if !self.completed {
309 info!(
310 target: LOG_TASK,
311 name = %self.name,
312 "Task shut down uncleanly"
313 );
314 self.inner.shutdown();
315 }
316 }
317}
318
319#[derive(Clone, Debug)]
320pub struct TaskHandle {
321 inner: Arc<TaskGroupInner>,
322}
323
324#[derive(thiserror::Error, Debug, Clone)]
325#[error("Task group is shutting down")]
326#[non_exhaustive]
327pub struct ShuttingDownError {}
328
329impl TaskHandle {
330 pub fn is_shutting_down(&self) -> bool {
334 self.inner.is_shutting_down()
335 }
336
337 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
342 self.inner.make_shutdown_rx()
343 }
344
345 pub async fn cancel_on_shutdown<F: Future>(
347 &self,
348 fut: F,
349 ) -> Result<F::Output, ShuttingDownError> {
350 let rx = self.make_shutdown_rx();
351 match future::select(pin!(rx), pin!(fut)).await {
352 Either::Left(((), _)) => Err(ShuttingDownError {}),
353 Either::Right((value, _)) => Ok(value),
354 }
355 }
356}
357
358pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
359
360impl TaskShutdownToken {
361 fn new(mut rx: watch::Receiver<bool>) -> Self {
362 Self(Box::pin(async move {
363 let _ = rx.wait_for(|v| *v).await;
364 }))
365 }
366}
367
368impl Future for TaskShutdownToken {
369 type Output = ();
370
371 fn poll(
372 mut self: Pin<&mut Self>,
373 cx: &mut std::task::Context<'_>,
374 ) -> std::task::Poll<Self::Output> {
375 self.0.as_mut().poll(cx)
376 }
377}
378
/// Applies `#[async_trait]` to the wrapped item (e.g. a trait or impl).
///
/// On wasm targets `#[async_trait(?Send)]` is used instead, so the generated
/// futures are not required to be `Send`.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($item)*
    };
}
403
/// Appends a `+ Send` bound to the given type on non-wasm targets, e.g.
/// `maybe_add_send!(dyn Future<Output = ()>)`.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)* + Send
    };
}

/// On wasm targets the given type is passed through unchanged (no `Send`).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
436
/// Appends `+ Send + Sync` bounds to the given type on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)* + Send + Sync
    };
}

/// On wasm targets the given type is passed through unchanged.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
454
/// On native targets, `MaybeSend` requires `Send` and is implemented for
/// every sized `Send` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}

/// On wasm targets, `MaybeSend` imposes no bound and is implemented for every
/// sized type.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
474
/// On native targets, `MaybeSync` requires `Sync` and is implemented for
/// every sized `Sync` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}

/// On wasm targets, `MaybeSync` imposes no bound and is implemented for every
/// sized type.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
488
489pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
492 info!(
493 target: LOG_TEST,
494 "Sleeping for {}.{:03} seconds because: {}",
495 duration.as_secs(),
496 duration.subsec_millis(),
497 comment.as_ref()
498 );
499 sleep(duration).await;
500}
501
502#[derive(Error, Debug)]
504#[error("Operation cancelled")]
505pub struct Cancelled;
506
507pub type Cancellable<T> = std::result::Result<T, Cancelled>;
510
511#[cfg(test)]
512mod tests;