1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
// Copyright (c) The Diem Core Contributors
// SPDX-License-Identifier: Apache-2.0
// Copyright 2021 Conflux Foundation. All rights reserved.
// Conflux is free software and distributed under GNU General Public License.
// See http://www.gnu.org/licenses/
#![forbid(unsafe_code)]
//! A bounded tokio [`Handle`]. Only a bounded number of tasks can run
//! concurrently when spawned through this executor, defined by the initial
//! `capacity`.
use futures::future::{Future, FutureExt};
use std::sync::Arc;
use tokio::{
runtime::Handle,
sync::{OwnedSemaphorePermit, Semaphore},
task::JoinHandle,
};
/// A wrapper around a tokio [`Handle`] that limits how many of the tasks
/// spawned through it may be running at the same time.
///
/// Every spawned task holds one permit from a shared [`Semaphore`] until its
/// future finishes (or is dropped). Clones share the same semaphore, so all
/// clones of one `BoundedExecutor` draw from the same `capacity`.
#[derive(Clone, Debug)]
pub struct BoundedExecutor {
    /// One permit per in-flight task; sized by `capacity` in [`Self::new`].
    semaphore: Arc<Semaphore>,
    /// The runtime handle tasks are actually spawned on.
    executor: Handle,
}

impl BoundedExecutor {
    /// Create a new `BoundedExecutor` from an existing tokio [`Handle`]
    /// with a maximum concurrent task capacity of `capacity`.
    ///
    /// With a `capacity` of `0` no permit is ever available, so
    /// [`spawn`](Self::spawn) never resolves and
    /// [`try_spawn`](Self::try_spawn) always returns `Err`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` exceeds tokio's `Semaphore::MAX_PERMITS`.
    pub fn new(capacity: usize, executor: Handle) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(capacity)),
            executor,
        }
    }

    /// Spawn a [`Future`] on the `BoundedExecutor`.
    ///
    /// If the executor is at capacity, the returned future waits (without
    /// blocking the thread) until one of the previously spawned tasks
    /// completes and frees a slot. Resolves to a [`JoinHandle`] that the
    /// caller can `.await` on for the result of `f`.
    pub async fn spawn<F>(&self, f: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let permit = self
            .semaphore
            .clone()
            .acquire_owned()
            .await
            // Acquisition only fails on a closed semaphore, and this type
            // never closes its semaphore.
            .expect("BoundedExecutor semaphore is never closed");
        self.spawn_with_permit(f, permit)
    }

    /// Try to spawn a [`Future`] on the `BoundedExecutor` without waiting.
    ///
    /// If the executor is at capacity, returns `Err(f)`, handing back the
    /// future the caller attempted to spawn. Otherwise spawns it and returns
    /// a [`JoinHandle`] that the caller can `.await` on for its result.
    pub fn try_spawn<F>(&self, f: F) -> Result<JoinHandle<F::Output>, F>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        match self.semaphore.clone().try_acquire_owned() {
            Ok(permit) => Ok(self.spawn_with_permit(f, permit)),
            // No free permit (or a closed semaphore, which cannot happen
            // here): give the future back to the caller untouched.
            Err(_) => Err(f),
        }
    }

    /// Spawn `f` on the underlying runtime, tying `spawn_permit` to the
    /// lifetime of the spawned task.
    fn spawn_with_permit<F>(
        &self, f: F, spawn_permit: OwnedSemaphorePermit,
    ) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        // The closure owns the permit. It is released explicitly as soon as
        // `f` resolves. If the wrapped future is instead dropped before it
        // completes (for example, the task is aborted), the closure and the
        // permit are dropped with it, so the slot is still returned.
        let f = f.map(move |ret| {
            drop(spawn_permit);
            ret
        });
        self.executor.spawn(f)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::{channel::oneshot, executor::block_on, future::Future};
    use std::{
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };
    use tokio::{runtime::Runtime, time::sleep};

    #[test]
    fn try_spawn() {
        let rt = Runtime::new().unwrap();
        let executor = rt.handle().clone();
        let executor = BoundedExecutor::new(1, executor);

        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        // executor has a free slot, spawn should succeed
        let f1 = executor.try_spawn(rx1).unwrap();

        // executor is full, try_spawn should return err and give back the task
        // we attempted to spawn
        let rx2 = executor.try_spawn(rx2).unwrap_err();

        // complete f1 future, should open a free slot in executor
        tx1.send(()).unwrap();
        block_on(f1).unwrap().unwrap();

        // should successfully spawn a new task now that the first is complete
        let f2 = executor.try_spawn(rx2).unwrap();

        // cleanup
        tx2.send(()).unwrap();
        block_on(f2).unwrap().unwrap();
    }

    /// A short sleep that forces the current task to yield back to the
    /// tokio scheduler.
    fn yield_task() -> impl Future<Output = ()> {
        sleep(Duration::from_millis(1)).map(|_| ())
    }

    // Spawn NUM_TASKS futures on a BoundedExecutor, ensuring that no more than
    // MAX_WORKERS ever enter the critical section.
    #[test]
    fn concurrent_bounded_executor() {
        const MAX_WORKERS: u32 = 20;
        const NUM_TASKS: u32 = 1000;
        static WORKERS: AtomicU32 = AtomicU32::new(0);
        static COMPLETED_TASKS: AtomicU32 = AtomicU32::new(0);

        let rt = Runtime::new().unwrap();
        let executor = rt.handle().clone();
        let executor = BoundedExecutor::new(MAX_WORKERS as usize, executor);

        let handles: Vec<_> = (0..NUM_TASKS)
            .map(|_| {
                block_on(executor.spawn(async move {
                    // acquired permit, there should only ever be MAX_WORKERS in
                    // this critical section
                    let prev_workers = WORKERS.fetch_add(1, Ordering::SeqCst);
                    assert!(prev_workers < MAX_WORKERS);

                    // yield back to the tokio scheduler
                    yield_task().await;

                    let prev_workers = WORKERS.fetch_sub(1, Ordering::SeqCst);
                    assert!(prev_workers > 0 && prev_workers <= MAX_WORKERS);

                    COMPLETED_TASKS.fetch_add(1, Ordering::SeqCst);
                }))
            })
            .collect();

        // Join every task rather than spinning on the counter. A failed
        // assertion inside a spawned task surfaces as a `JoinError` here and
        // fails the test. Previously it left the counter short and the test
        // spun forever.
        for handle in handles {
            block_on(handle).unwrap();
        }
        assert_eq!(COMPLETED_TASKS.load(Ordering::SeqCst), NUM_TASKS);
    }
}