1use futures_util::{
23 future::{FusedFuture, Shared},
24 FutureExt,
25};
26use std::{
27 future::Future,
28 pin::Pin,
29 sync::{atomic::AtomicUsize, Arc},
30 task::{ready, Context, Poll},
31};
32use tokio::sync::oneshot;
33
34#[derive(Debug)]
36pub struct GracefulShutdown {
37 shutdown: Shutdown,
38 guard: Option<GracefulShutdownGuard>,
39}
40
41impl GracefulShutdown {
42 pub(crate) const fn new(
43 shutdown: Shutdown, guard: GracefulShutdownGuard,
44 ) -> Self {
45 Self {
46 shutdown,
47 guard: Some(guard),
48 }
49 }
50
51 pub fn ignore_guard(
57 self,
58 ) -> impl Future<Output = ()> + Send + Sync + Unpin + 'static {
59 self.map(drop)
60 }
61}
62
63impl Future for GracefulShutdown {
64 type Output = GracefulShutdownGuard;
65
66 fn poll(
67 mut self: Pin<&mut Self>, cx: &mut Context<'_>,
68 ) -> Poll<Self::Output> {
69 ready!(self.shutdown.poll_unpin(cx));
70 Poll::Ready(
71 self.get_mut()
72 .guard
73 .take()
74 .expect("Future polled after completion"),
75 )
76 }
77}
78
79impl Clone for GracefulShutdown {
80 fn clone(&self) -> Self {
81 Self {
82 shutdown: self.shutdown.clone(),
83 guard: self
84 .guard
85 .as_ref()
86 .map(|g| GracefulShutdownGuard::new(Arc::clone(&g.0))),
87 }
88 }
89}
90
/// Keeps a shared counter incremented for as long as it is alive.
///
/// Creating a guard adds one to the counter and dropping it subtracts one.
/// Assuming nothing else writes to the counter, it equals the number of live
/// guards sharing it.
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
pub struct GracefulShutdownGuard(Arc<AtomicUsize>);

impl GracefulShutdownGuard {
    /// Registers a new guard on `counter`, incrementing it by one.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
        let guard = Self(counter);
        guard.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        guard
    }
}

impl Drop for GracefulShutdownGuard {
    /// Undoes the increment performed in [`GracefulShutdownGuard::new`].
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
    }
}
110
111#[derive(Debug, Clone)]
113pub struct Shutdown(Shared<oneshot::Receiver<()>>);
114
115impl Future for Shutdown {
116 type Output = ();
117
118 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119 let pin = self.get_mut();
120 if pin.0.is_terminated() || pin.0.poll_unpin(cx).is_ready() {
121 Poll::Ready(())
122 } else {
123 Poll::Pending
124 }
125 }
126}
127
128#[derive(Debug)]
130pub struct Signal(oneshot::Sender<()>);
131
132impl Signal {
133 pub fn fire(self) { let _ = self.0.send(()); }
135}
136
137pub fn signal() -> (Signal, Shutdown) {
139 let (sender, receiver) = oneshot::channel();
140 (Signal(sender), Shutdown(receiver.shared()))
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future::join_all;
    use std::{sync::atomic::Ordering, time::Duration};

    /// Delay that gives the waiting side time to start waiting before the
    /// signal is dropped.
    const SIGNAL_DELAY: Duration = Duration::from_millis(500);

    #[tokio::test(flavor = "multi_thread")]
    async fn test_shutdown() { let (_signal, _shutdown) = signal(); }

    /// Explicitly firing the signal resolves the shutdown future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fire_signal() {
        let (signal, shutdown) = signal();
        signal.fire();
        shutdown.await;
    }

    /// Dropping the signal from another task resolves the shutdown future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal() {
        let (signal, shutdown) = signal();

        let task = tokio::task::spawn(async move {
            tokio::time::sleep(SIGNAL_DELAY).await;
            drop(signal)
        });

        shutdown.await;
        task.await.expect("signal task panicked");
    }

    /// Every clone of a `Shutdown` resolves when the signal is dropped.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_shutdowns() {
        let (signal, shutdown) = signal();

        let tasks: Vec<_> = (0..100)
            .map(|_| {
                let shutdown = shutdown.clone();
                tokio::task::spawn(async move { shutdown.await })
            })
            .collect();

        drop(signal);

        // `join_all` only collects the `JoinHandle` results. Check each one,
        // otherwise a panic inside a spawned task would go unnoticed.
        for result in join_all(tasks).await {
            result.expect("shutdown task panicked");
        }
    }

    /// A completed `Shutdown` can be polled again and cloned without
    /// panicking, even though the inner `Shared` is terminated.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_repoll_after_completion() {
        let (signal, mut shutdown) = signal();
        signal.fire();

        (&mut shutdown).await;
        (&mut shutdown).await;
        shutdown.clone().await;
    }

    /// Guards are counted through `new`, `clone`, `await` and `ignore_guard`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_guard_counting() {
        let counter = Arc::new(AtomicUsize::new(0));
        let (signal, shutdown) = signal();

        let graceful = GracefulShutdown::new(
            shutdown,
            GracefulShutdownGuard::new(Arc::clone(&counter)),
        );
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let ignored = graceful.clone().ignore_guard();
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        signal.fire();

        // The guard is moved out to us, not dropped, so the count is unchanged.
        let guard = graceful.await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        drop(guard);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        ignored.await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    /// Dropping the signal from a plain OS thread resolves the shutdown
    /// future.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_drop_signal_from_thread() {
        let (signal, shutdown) = signal();

        let thread = std::thread::spawn(move || {
            std::thread::sleep(SIGNAL_DELAY);
            drop(signal)
        });

        shutdown.await;
        thread.join().expect("signal thread panicked");
    }
}