cfx_tasks/
shutdown.rs

// Copyright 2023-2024 Paradigm.xyz
// This file is part of reth.
// Reth is a modular, contributor-friendly and blazing-fast implementation of
// the Ethereum protocol

// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:

// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.

//! Helper for shutdown signals

use futures_util::{
    future::{FusedFuture, Shared},
    FutureExt,
};
use std::{
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{ready, Context, Poll},
};
use tokio::sync::oneshot;

34/// A Future that resolves when the shutdown event has been fired.
35#[derive(Debug)]
36pub struct GracefulShutdown {
37    shutdown: Shutdown,
38    guard: Option<GracefulShutdownGuard>,
39}
40
41impl GracefulShutdown {
42    pub(crate) const fn new(
43        shutdown: Shutdown, guard: GracefulShutdownGuard,
44    ) -> Self {
45        Self {
46            shutdown,
47            guard: Some(guard),
48        }
49    }
50
51    /// Returns a new shutdown future that is ignores the returned
52    /// [`GracefulShutdownGuard`].
53    ///
54    /// This just maps the return value of the future to `()`, it does not drop
55    /// the guard.
56    pub fn ignore_guard(
57        self,
58    ) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
59        self.map(drop)
60    }
61}
62
63impl Future for GracefulShutdown {
64    type Output = GracefulShutdownGuard;
65
66    fn poll(
67        mut self: Pin<&mut Self>, cx: &mut Context<'_>,
68    ) -> Poll<Self::Output> {
69        ready!(self.shutdown.poll_unpin(cx));
70        Poll::Ready(
71            self.get_mut()
72                .guard
73                .take()
74                .expect("Future polled after completion"),
75        )
76    }
77}
78
79impl Clone for GracefulShutdown {
80    fn clone(&self) -> Self {
81        Self {
82            shutdown: self.shutdown.clone(),
83            guard: self
84                .guard
85                .as_ref()
86                .map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
87        }
88    }
89}
90
/// A guard that, once dropped, signals the
/// [`TaskManager`](crate::TaskManager) that the [`GracefulShutdown`] has
/// completed.
///
/// Creating a guard increments the shared counter and dropping it decrements
/// the counter again, so the counter equals the number of live guards.
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);

impl GracefulShutdownGuard {
    /// Registers a new guard on `counter`, incrementing it by one.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        Self(counter)
    }
}

impl Drop for GracefulShutdownGuard {
    /// Releases the guard, decrementing the shared counter by one.
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

111/// A Future that resolves when the shutdown event has been fired.
112#[derive(Debug, Clone)]
113pub struct Shutdown(Shared<oneshot::Receiver<()>>);
114
115impl Future for Shutdown {
116    type Output = ();
117
118    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119        let pin = self.get_mut();
120        if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
121            Poll::Ready(())
122        } else {
123            Poll::Pending
124        }
125    }
126}
127
128/// Shutdown signal that fires either manually or on drop by closing the channel
129#[derive(Debug)]
130pub struct Signal(oneshot::Sender<()>);
131
132impl Signal {
133    /// Fire the signal manually.
134    pub fn fire(self) { let _ = self.0.send(()); }
135}
136
137/// Create a channel pair that's used to propagate shutdown event
138pub fn signal() -> (Signal, Shutdown) {
139    let (sender, receiver) = oneshot::channel();
140    (Signal(sender), Shutdown(receiver.shared()))
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::join_all;
    use std::{sync::atomic::Ordering, time::Duration};

    /// Firing the signal manually resolves the shutdown future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown() {
        let (signal, shutdown) = signal();
        signal.fire();
        shutdown.await;
    }

    /// Dropping the signal from another task resolves the shutdown future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal() {
        let (signal, shutdown) = signal();

        tokio::task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            drop(signal)
        });

        shutdown.await;
    }

    /// A single signal notifies every clone of the shutdown future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_shutdowns() {
        let (signal, shutdown) = signal();

        let mut tasks = Vec::with_capacity(100);
        for _ in 0..100 {
            let shutdown = shutdown.clone();
            let task = tokio::task::spawn(async move {
                shutdown.await;
            });
            tasks.push(task);
        }

        drop(signal);

        join_all(tasks).await;
    }

    /// Dropping the signal from a plain OS thread resolves the future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal_from_thread() {
        let (signal, shutdown) = signal();

        let thread = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(500));
            drop(signal)
        });

        shutdown.await;
        // Join so a panic on the helper thread fails the test.
        thread.join().expect("signal thread panicked");
    }

    /// Polling a shutdown future again after it completed still resolves
    /// instead of re-polling the exhausted `Shared`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown_awaited_twice() {
        let (signal, mut shutdown) = signal();
        signal.fire();
        (&mut shutdown).await;
        shutdown.await;
    }

    /// Guards are counted on creation and clone, and released on drop.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_graceful_shutdown_guard_counter() {
        let (signal, shutdown) = signal();
        let counter = Arc::new(AtomicUsize::new(0));
        let graceful = GracefulShutdown::new(
            shutdown,
            GracefulShutdownGuard::new(Arc::clone(&counter)),
        );
        let cloned = graceful.clone();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        signal.fire();

        let guard = graceful.await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        drop(guard);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        cloned.ignore_guard().await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}