1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::SystemTime;
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{Span, debug, info, trace};
23
24use crate::runtime;
25pub use crate::runtime::*;
28#[derive(Clone, Default, Debug)]
39pub struct TaskGroup {
40 inner: Arc<TaskGroupInner>,
41}
42
43impl TaskGroup {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn make_handle(&self) -> TaskHandle {
49 TaskHandle {
50 inner: self.inner.clone(),
51 }
52 }
53
54 pub fn make_subgroup(&self) -> Self {
67 let new_tg = Self::new();
68 self.inner.add_subgroup(new_tg.clone());
69 new_tg
70 }
71
72 pub fn is_shutting_down(&self) -> bool {
74 self.inner.is_shutting_down()
75 }
76
77 pub fn shutdown(&self) {
80 self.inner.shutdown();
81 }
82
83 pub async fn shutdown_join_all(
85 self,
86 join_timeout: impl Into<Option<Duration>>,
87 ) -> Result<(), anyhow::Error> {
88 self.shutdown();
89 self.join_all(join_timeout.into()).await
90 }
91
92 #[cfg(not(target_family = "wasm"))]
95 pub fn install_kill_handler(&self) {
96 async fn wait_for_shutdown_signal() {
98 use tokio::signal;
99
100 let ctrl_c = async {
101 signal::ctrl_c()
102 .await
103 .expect("failed to install Ctrl+C handler");
104 };
105
106 #[cfg(unix)]
107 let terminate = async {
108 signal::unix::signal(signal::unix::SignalKind::terminate())
109 .expect("failed to install signal handler")
110 .recv()
111 .await;
112 };
113
114 #[cfg(not(unix))]
115 let terminate = std::future::pending::<()>();
116
117 tokio::select! {
118 () = ctrl_c => {},
119 () = terminate => {},
120 }
121 }
122
123 runtime::spawn("kill handlers", {
124 let task_group = self.clone();
125 async move {
126 wait_for_shutdown_signal().await;
127 info!(
128 target: LOG_TASK,
129 "signal received, starting graceful shutdown"
130 );
131 task_group.shutdown();
132 }
133 });
134 }
135
136 pub fn spawn<Fut, R>(
137 &self,
138 name: impl Into<String>,
139 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
140 ) -> oneshot::Receiver<R>
141 where
142 Fut: Future<Output = R> + MaybeSend + 'static,
143 R: MaybeSend + 'static,
144 {
145 self.spawn_inner(name, f, false, None)
146 }
147
148 pub fn spawn_silent<Fut, R>(
152 &self,
153 name: impl Into<String>,
154 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
155 ) -> oneshot::Receiver<R>
156 where
157 Fut: Future<Output = R> + MaybeSend + 'static,
158 R: MaybeSend + 'static,
159 {
160 self.spawn_inner(name, f, true, None)
161 }
162
163 pub fn spawn_with_span<Fut, R>(
167 &self,
168 parent_span: Span,
169 name: impl Into<String>,
170 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
171 ) -> oneshot::Receiver<R>
172 where
173 Fut: Future<Output = R> + MaybeSend + 'static,
174 R: MaybeSend + 'static,
175 {
176 self.spawn_inner(name, f, false, Some(parent_span))
177 }
178
179 fn spawn_inner<Fut, R>(
180 &self,
181 name: impl Into<String>,
182 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
183 quiet: bool,
184 parent_span: Option<Span>,
185 ) -> oneshot::Receiver<R>
186 where
187 Fut: Future<Output = R> + MaybeSend + 'static,
188 R: MaybeSend + 'static,
189 {
190 let name = name.into();
191 let mut guard = TaskPanicGuard {
192 name: name.clone(),
193 inner: self.inner.clone(),
194 completed: false,
195 };
196 let handle = self.make_handle();
197
198 let (tx, rx) = oneshot::channel();
199 self.inner
200 .active_tasks_join_handles
201 .lock()
202 .expect("Locking failed")
203 .insert_with_key(move |task_key| {
204 let task_future = {
205 let name = name.clone();
206 async move {
207 defer! {
208 if handle
213 .inner
214 .active_tasks_join_handles
215 .lock()
216 .expect("Locking failed")
217 .remove(task_key)
218 .is_none() {
219 trace!(target: LOG_TASK, %name, "Task already canceled");
220 }
221 }
222 if quiet {
224 trace!(target: LOG_TASK, %name, "Starting task");
225 } else {
226 debug!(target: LOG_TASK, %name, "Starting task");
227 }
228 let r = f(handle.clone()).await;
229 guard.completed = true;
230
231 if quiet {
232 trace!(target: LOG_TASK, %name, "Finished task");
233 } else {
234 debug!(target: LOG_TASK, %name, "Finished task");
235 }
236 let _ = tx.send(r);
238
239 drop(guard);
242 }
243 };
244 let join_handle = match parent_span.as_ref() {
245 Some(parent) => crate::runtime::spawn_with_span(parent, &name, task_future),
246 None => crate::runtime::spawn(&name, task_future),
247 };
248 (name, join_handle)
249 });
250
251 rx
252 }
253
254 pub fn spawn_cancellable<R>(
257 &self,
258 name: impl Into<String>,
259 future: impl Future<Output = R> + MaybeSend + 'static,
260 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
261 where
262 R: MaybeSend + 'static,
263 {
264 self.spawn(name, |handle| async move {
265 let value = handle.cancel_on_shutdown(future).await;
266 if value.is_err() {
267 debug!(target: LOG_TASK, "task cancelled on shutdown");
269 }
270 value
271 })
272 }
273
274 pub fn spawn_cancellable_with_span<R>(
277 &self,
278 parent_span: Span,
279 name: impl Into<String>,
280 future: impl Future<Output = R> + MaybeSend + 'static,
281 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
282 where
283 R: MaybeSend + 'static,
284 {
285 self.spawn_with_span(parent_span, name, |handle| async move {
286 let value = handle.cancel_on_shutdown(future).await;
287 if value.is_err() {
288 debug!(target: LOG_TASK, "task cancelled on shutdown");
290 }
291 value
292 })
293 }
294
295 pub fn spawn_cancellable_silent<R>(
296 &self,
297 name: impl Into<String>,
298 future: impl Future<Output = R> + MaybeSend + 'static,
299 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
300 where
301 R: MaybeSend + 'static,
302 {
303 self.spawn_silent(name, |handle| async move {
304 let value = handle.cancel_on_shutdown(future).await;
305 if value.is_err() {
306 debug!(target: LOG_TASK, "task cancelled on shutdown");
308 }
309 value
310 })
311 }
312
313 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
314 let deadline = timeout.map(|timeout| now() + timeout);
315 let mut errors = vec![];
316
317 self.join_all_inner(deadline, &mut errors).await;
318
319 if errors.is_empty() {
320 Ok(())
321 } else {
322 let num_errors = errors.len();
323 bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
324 }
325 }
326
327 #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
328 #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
329 pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
330 self.inner.join_all(deadline, errors).await;
331 }
332}
333
334struct TaskPanicGuard {
335 name: String,
336 inner: Arc<TaskGroupInner>,
337 completed: bool,
339}
340
341impl Drop for TaskPanicGuard {
342 fn drop(&mut self) {
343 trace!(
344 target: LOG_TASK,
345 name = %self.name,
346 "Task drop"
347 );
348 if !self.completed {
349 info!(
350 target: LOG_TASK,
351 name = %self.name,
352 "Task shut down uncleanly"
353 );
354 self.inner.shutdown();
355 }
356 }
357}
358
359#[derive(Clone, Debug)]
360pub struct TaskHandle {
361 inner: Arc<TaskGroupInner>,
362}
363
364#[derive(thiserror::Error, Debug, Clone)]
365#[error("Task group is shutting down")]
366#[non_exhaustive]
367pub struct ShuttingDownError {}
368
369impl TaskHandle {
370 pub fn is_shutting_down(&self) -> bool {
374 self.inner.is_shutting_down()
375 }
376
377 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
382 self.inner.make_shutdown_rx()
383 }
384
385 pub async fn cancel_on_shutdown<F: Future>(
387 &self,
388 fut: F,
389 ) -> Result<F::Output, ShuttingDownError> {
390 let rx = self.make_shutdown_rx();
391 match future::select(pin!(rx), pin!(fut)).await {
392 Either::Left(((), _)) => Err(ShuttingDownError {}),
393 Either::Right((value, _)) => Ok(value),
394 }
395 }
396}
397
398pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
399
400impl TaskShutdownToken {
401 fn new(mut rx: watch::Receiver<bool>) -> Self {
402 Self(Box::pin(async move {
403 let _ = rx.wait_for(|v| *v).await;
404 }))
405 }
406}
407
408impl Future for TaskShutdownToken {
409 type Output = ();
410
411 fn poll(
412 mut self: Pin<&mut Self>,
413 cx: &mut std::task::Context<'_>,
414 ) -> std::task::Poll<Self::Output> {
415 self.0.as_mut().poll(cx)
416 }
417}
418
/// Applies `async_trait` to the given item, with the `Send` requirement
/// relaxed on wasm.
///
/// Non-wasm targets get `#[::async_trait::async_trait]`. Wasm targets get
/// `#[::async_trait::async_trait(?Send)]`, matching `MaybeSend`, which imposes
/// no `Send` bound there.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        $($item)*
    };
}
443
/// Wasm variant: returns the given bound tokens unchanged (no `Send`).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bound:tt)*) => {
        $($bound)*
    };
}

/// Appends `+ Send` to the given bound tokens on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bound:tt)*) => {
        $($bound)* + Send
    };
}
476
/// Wasm variant: returns the given bound tokens unchanged (no `Send`/`Sync`).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bound:tt)*) => {
        $($bound)*
    };
}

/// Appends `+ Send + Sync` to the given bound tokens on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bound:tt)*) => {
        $($bound)* + Send + Sync
    };
}
494
/// `Send` on non-wasm targets: implemented for every sized `Send` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}

/// Wasm variant: no `Send` requirement; implemented for every sized type.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
514
/// `Sync` on non-wasm targets: implemented for every sized `Sync` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}

/// Wasm variant: no `Sync` requirement; implemented for every sized type.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
528
529pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
532 info!(
533 target: LOG_TEST,
534 "Sleeping for {}.{:03} seconds because: {}",
535 duration.as_secs(),
536 duration.subsec_millis(),
537 comment.as_ref()
538 );
539 sleep(duration).await;
540}
541
542#[derive(Error, Debug)]
544#[error("Operation cancelled")]
545pub struct Cancelled;
546
547pub type Cancellable<T> = std::result::Result<T, Cancelled>;
550
551#[cfg(test)]
552mod tests;