1use crate::shutdown::{
23 signal, GracefulShutdown, GracefulShutdownGuard, Shutdown, Signal,
24};
25use dyn_clone::DynClone;
26use futures_util::{
27 future::{select, BoxFuture},
28 Future, FutureExt, TryFutureExt,
29};
30use std::{
31 any::Any,
32 fmt::{Display, Formatter},
33 pin::{pin, Pin},
34 sync::{
35 atomic::{AtomicUsize, Ordering},
36 Arc,
37 },
38 task::{ready, Context, Poll},
39};
40use tokio::{
41 runtime::Handle,
42 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
43 task::JoinHandle,
44};
45
46#[cfg(feature = "rayon")]
47pub mod pool;
48pub mod shutdown;
49
50#[auto_impl::auto_impl(&, Arc)]
51pub trait TaskSpawner:
52 Send + Sync + Unpin + std::fmt::Debug + DynClone
53{
54 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
57
58 fn spawn_critical(
60 &self, name: &'static str, fut: BoxFuture<'static, ()>,
61 ) -> JoinHandle<()>;
62
63 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()>;
65
66 fn spawn_critical_blocking(
68 &self, name: &'static str, fut: BoxFuture<'static, ()>,
69 ) -> JoinHandle<()>;
70}
71
72dyn_clone::clone_trait_object!(TaskSpawner);
73
74#[derive(Debug, Clone, Default)]
76#[non_exhaustive]
77pub struct TokioTaskExecutor;
78
79impl TokioTaskExecutor {
80 pub fn boxed(self) -> Box<dyn TaskSpawner + 'static> { Box::new(self) }
82}
83
84impl TaskSpawner for TokioTaskExecutor {
85 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
86 tokio::task::spawn(fut)
87 }
88
89 fn spawn_critical(
90 &self, _name: &'static str, fut: BoxFuture<'static, ()>,
91 ) -> JoinHandle<()> {
92 tokio::task::spawn(fut)
93 }
94
95 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
96 tokio::task::spawn_blocking(move || {
97 tokio::runtime::Handle::current().block_on(fut)
98 })
99 }
100
101 fn spawn_critical_blocking(
102 &self, _name: &'static str, fut: BoxFuture<'static, ()>,
103 ) -> JoinHandle<()> {
104 tokio::task::spawn_blocking(move || {
105 tokio::runtime::Handle::current().block_on(fut)
106 })
107 }
108}
109
110#[derive(Debug)]
126#[must_use = "TaskManager must be polled to monitor critical tasks"]
127pub struct TaskManager {
128 handle: Handle,
132 panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
134 panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
136 signal: Option<Signal>,
140 on_shutdown: Shutdown,
142 graceful_tasks: Arc<AtomicUsize>,
144}
145
146impl TaskManager {
149 pub fn current() -> Self {
155 let handle = Handle::current();
156 Self::new(handle)
157 }
158
159 pub fn new(handle: Handle) -> Self {
161 let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
162 let (signal, on_shutdown) = signal();
163 Self {
164 handle,
165 panicked_tasks_tx,
166 panicked_tasks_rx,
167 signal: Some(signal),
168 on_shutdown,
169 graceful_tasks: Arc::new(AtomicUsize::new(0)),
170 }
171 }
172
173 pub fn executor(&self) -> TaskExecutor {
176 TaskExecutor {
177 handle: self.handle.clone(),
178 on_shutdown: self.on_shutdown.clone(),
179 panicked_tasks_tx: self.panicked_tasks_tx.clone(),
180 graceful_tasks: Arc::clone(&self.graceful_tasks),
181 }
182 }
183
184 pub fn graceful_shutdown(self) { let _ = self.do_graceful_shutdown(None); }
186
187 pub fn graceful_shutdown_with_timeout(
191 self, timeout: std::time::Duration,
192 ) -> bool {
193 self.do_graceful_shutdown(Some(timeout))
194 }
195
196 fn do_graceful_shutdown(
197 self, timeout: Option<std::time::Duration>,
198 ) -> bool {
199 drop(self.signal);
200 let when = timeout.map(|t| std::time::Instant::now() + t);
201 while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
202 if when
203 .map(|when| std::time::Instant::now() > when)
204 .unwrap_or(false)
205 {
206 return false;
207 }
208 std::hint::spin_loop();
209 }
210
211 true
212 }
213}
214
215impl Future for TaskManager {
219 type Output = PanickedTaskError;
220
221 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222 let err = ready!(self.get_mut().panicked_tasks_rx.poll_recv(cx));
223 Poll::Ready(err.expect("stream can not end"))
224 }
225}
226
227#[derive(Debug, thiserror::Error)]
230pub struct PanickedTaskError {
231 task_name: &'static str,
232 error: Option<String>,
233}
234
235impl Display for PanickedTaskError {
236 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237 let task_name = self.task_name;
238 if let Some(error) = &self.error {
239 write!(f, "Critical task `{task_name}` panicked: `{error}`")
240 } else {
241 write!(f, "Critical task `{task_name}` panicked")
242 }
243 }
244}
245
246impl PanickedTaskError {
247 fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
248 let error = match error.downcast::<String>() {
249 Ok(value) => Some(*value),
250 Err(error) => match error.downcast::<&str>() {
251 Ok(value) => Some(value.to_string()),
252 Err(_) => None,
253 },
254 };
255
256 Self { task_name, error }
257 }
258}
259
260#[derive(Debug, Clone)]
262pub struct TaskExecutor {
263 handle: Handle,
267 on_shutdown: Shutdown,
269 panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
271 graceful_tasks: Arc<AtomicUsize>,
275}
276
277impl TaskExecutor {
280 pub const fn handle(&self) -> &Handle { &self.handle }
282
283 pub const fn on_shutdown_signal(&self) -> &Shutdown { &self.on_shutdown }
285
286 fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
288 where F: Future<Output = ()> + Send + 'static {
289 match task_kind {
290 TaskKind::Default => self.handle.spawn(fut),
291 TaskKind::Blocking => {
292 let handle = self.handle.clone();
293 self.handle.spawn_blocking(move || handle.block_on(fut))
294 }
295 }
296 }
297
298 fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
300 where F: Future<Output = ()> + Send + 'static {
301 let on_shutdown = self.on_shutdown.clone();
302
303 let task = {
306 async move {
307 let fut = pin!(fut);
308 let _ = select(on_shutdown, fut).await;
309 }
310 };
311
312 self.spawn_on_rt(task, task_kind)
313 }
314
315 pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
320 where F: Future<Output = ()> + Send + 'static {
321 self.spawn_task_as(fut, TaskKind::Default)
322 }
323
324 pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
329 where F: Future<Output = ()> + Send + 'static {
330 self.spawn_task_as(fut, TaskKind::Blocking)
331 }
332
333 pub fn spawn_with_signal<F>(
338 &self, f: impl FnOnce(Shutdown) -> F,
339 ) -> JoinHandle<()>
340 where F: Future<Output = ()> + Send + 'static {
341 let on_shutdown = self.on_shutdown.clone();
342 let fut = f(on_shutdown);
343
344 let task = fut;
345
346 self.handle.spawn(task)
347 }
348
349 fn spawn_critical_as<F>(
351 &self, name: &'static str, fut: F, task_kind: TaskKind,
352 ) -> JoinHandle<()>
353 where F: Future<Output = ()> + Send + 'static {
354 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
355 let on_shutdown = self.on_shutdown.clone();
356
357 let task = std::panic::AssertUnwindSafe(fut).catch_unwind().map_err(
359 move |error| {
360 let task_error = PanickedTaskError::new(name, error);
361 let _ = panicked_tasks_tx.send(task_error);
362 },
363 );
364
365 let task = async move {
366 let task = pin!(task);
367 let _ = select(on_shutdown, task).await;
368 };
369
370 self.spawn_on_rt(task, task_kind)
371 }
372
373 pub fn spawn_critical_blocking<F>(
378 &self, name: &'static str, fut: F,
379 ) -> JoinHandle<()>
380 where F: Future<Output = ()> + Send + 'static {
381 self.spawn_critical_as(name, fut, TaskKind::Blocking)
382 }
383
384 pub fn spawn_critical<F>(
389 &self, name: &'static str, fut: F,
390 ) -> JoinHandle<()>
391 where F: Future<Output = ()> + Send + 'static {
392 self.spawn_critical_as(name, fut, TaskKind::Default)
393 }
394
395 pub fn spawn_critical_with_shutdown_signal<F>(
399 &self, name: &'static str, f: impl FnOnce(Shutdown) -> F,
400 ) -> JoinHandle<()>
401 where F: Future<Output = ()> + Send + 'static {
402 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
403 let on_shutdown = self.on_shutdown.clone();
404 let fut = f(on_shutdown);
405
406 let task = std::panic::AssertUnwindSafe(fut)
408 .catch_unwind()
409 .map_err(move |error| {
410 let task_error = PanickedTaskError::new(name, error);
411 let _ = panicked_tasks_tx.send(task_error);
412 })
413 .map(drop);
414
415 self.handle.spawn(task)
416 }
417
418 pub fn spawn_critical_with_graceful_shutdown_signal<F>(
443 &self, name: &'static str, f: impl FnOnce(GracefulShutdown) -> F,
444 ) -> JoinHandle<()>
445 where F: Future<Output = ()> + Send + 'static {
446 let panicked_tasks_tx = self.panicked_tasks_tx.clone();
447 let on_shutdown = GracefulShutdown::new(
448 self.on_shutdown.clone(),
449 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
450 );
451 let fut = f(on_shutdown);
452
453 let task = std::panic::AssertUnwindSafe(fut)
455 .catch_unwind()
456 .map_err(move |error| {
457 let task_error = PanickedTaskError::new(name, error);
458 let _ = panicked_tasks_tx.send(task_error);
459 })
460 .map(drop);
461
462 self.handle.spawn(task)
463 }
464
465 pub fn spawn_with_graceful_shutdown_signal<F>(
486 &self, f: impl FnOnce(GracefulShutdown) -> F,
487 ) -> JoinHandle<()>
488 where F: Future<Output = ()> + Send + 'static {
489 let on_shutdown = GracefulShutdown::new(
490 self.on_shutdown.clone(),
491 GracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
492 );
493 let fut = f(on_shutdown);
494
495 self.handle.spawn(fut)
496 }
497}
498
499impl TaskSpawner for TaskExecutor {
500 fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
501 self.spawn(fut)
502 }
503
504 fn spawn_critical(
505 &self, name: &'static str, fut: BoxFuture<'static, ()>,
506 ) -> JoinHandle<()> {
507 Self::spawn_critical(self, name, fut)
508 }
509
510 fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
511 self.spawn_blocking(fut)
512 }
513
514 fn spawn_critical_blocking(
515 &self, name: &'static str, fut: BoxFuture<'static, ()>,
516 ) -> JoinHandle<()> {
517 Self::spawn_critical_blocking(self, name, fut)
518 }
519}
520
521#[auto_impl::auto_impl(&, Arc)]
523pub trait TaskSpawnerExt:
524 Send + Sync + Unpin + std::fmt::Debug + DynClone
525{
526 fn spawn_critical_with_graceful_shutdown_signal<F>(
532 &self, name: &'static str, f: impl FnOnce(GracefulShutdown) -> F,
533 ) -> JoinHandle<()>
534 where F: Future<Output = ()> + Send + 'static;
535
536 fn spawn_with_graceful_shutdown_signal<F>(
541 &self, f: impl FnOnce(GracefulShutdown) -> F,
542 ) -> JoinHandle<()>
543 where F: Future<Output = ()> + Send + 'static;
544}
545
546impl TaskSpawnerExt for TaskExecutor {
547 fn spawn_critical_with_graceful_shutdown_signal<F>(
548 &self, name: &'static str, f: impl FnOnce(GracefulShutdown) -> F,
549 ) -> JoinHandle<()>
550 where F: Future<Output = ()> + Send + 'static {
551 Self::spawn_critical_with_graceful_shutdown_signal(self, name, f)
552 }
553
554 fn spawn_with_graceful_shutdown_signal<F>(
555 &self, f: impl FnOnce(GracefulShutdown) -> F,
556 ) -> JoinHandle<()>
557 where F: Future<Output = ()> + Send + 'static {
558 Self::spawn_with_graceful_shutdown_signal(self, f)
559 }
560}
561
/// Where [`TaskExecutor`] runs a spawned future.
enum TaskKind {
    /// As a regular async task on the runtime.
    Default,
    /// On the runtime's blocking pool, driven to completion with `block_on`.
    Blocking,
}
569
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::atomic::AtomicBool, time::Duration};

    #[test]
    fn test_cloneable() {
        #[derive(Clone)]
        struct ExecutorWrapper {
            _e: Box<dyn TaskSpawner>,
        }

        let executor: Box<dyn TaskSpawner> = Box::<TokioTaskExecutor>::default();
        let _e = dyn_clone::clone_box(&*executor);

        let e = ExecutorWrapper { _e };
        // Actually exercise `Clone` for the wrapper (and thereby for
        // `Box<dyn TaskSpawner>`) instead of merely moving it.
        let _e2 = e.clone();
    }

    #[test]
    fn test_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        executor.spawn_critical("this is a critical task", async {
            panic!("intentionally panic")
        });

        runtime.block_on(async move {
            let err = manager.await;
            assert_eq!(err.task_name, "this is a critical task");
            assert_eq!(err.error, Some("intentionally panic".to_string()));
        })
    }

    #[test]
    fn test_manager_shutdown_critical() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn_critical("this is a critical task", async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        });

        // Dropping the manager cancels the task, which drops `signal`.
        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle.clone());
        let executor = manager.executor();

        let (signal, shutdown) = signal();

        executor.spawn(Box::pin(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(signal);
        }));

        // Dropping the manager cancels the task, which drops `signal`.
        drop(manager);

        handle.block_on(shutdown);
    }

    #[test]
    fn test_manager_graceful_shutdown() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            c.store(true, Ordering::Relaxed);
        });

        manager.graceful_shutdown();
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_many() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let counter = Arc::new(AtomicUsize::new(0));
        let num = 10;
        for _ in 0..num {
            let c = counter.clone();
            executor.spawn_critical_with_graceful_shutdown_signal(
                "grace",
                move |shutdown| async move {
                    let _guard = shutdown.await;
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        manager.graceful_shutdown();
        assert_eq!(counter.load(Ordering::Relaxed), num);
    }

    #[test]
    fn test_manager_graceful_shutdown_with_timeout_completes() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let val = Arc::new(AtomicBool::new(false));
        let c = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(Duration::from_millis(50)).await;
            c.store(true, Ordering::Relaxed);
        });

        assert!(
            manager.graceful_shutdown_with_timeout(Duration::from_secs(10)),
            "graceful task should finish well before the timeout"
        );
        assert!(val.load(Ordering::Relaxed));
    }

    #[test]
    fn test_manager_graceful_shutdown_timeout() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let handle = runtime.handle().clone();
        let manager = TaskManager::new(handle);
        let executor = manager.executor();

        let timeout = Duration::from_millis(500);
        let val = Arc::new(AtomicBool::new(false));
        let val2 = val.clone();
        executor.spawn_critical_with_graceful_shutdown_signal("grace", |shutdown| async move {
            let _guard = shutdown.await;
            tokio::time::sleep(timeout * 3).await;
            val2.store(true, Ordering::Relaxed);
            unreachable!("should not be reached");
        });

        assert!(
            !manager.graceful_shutdown_with_timeout(timeout),
            "graceful shutdown should report that it timed out"
        );
        assert!(!val.load(Ordering::Relaxed));
    }
}