cfxcore_types/
channel.rs

1// Copyright 2019 Conflux Foundation. All rights reserved.
2// Conflux is free software and distributed under GNU General Public License.
3// See http://www.gnu.org/licenses/
4
5use crate::UniqueId;
6use cfx_types::H256;
7use log::warn;
8use parking_lot::RwLock;
9use std::{collections::BTreeMap, sync::Arc, time::Duration};
10use tokio::{runtime, sync::mpsc, time::timeout};
11
12pub use tokio::{sync::mpsc::error::TryRecvError, time::error::Elapsed};
13
14pub struct Receiver<T> {
15    pub id: u64,
16    receiver: mpsc::UnboundedReceiver<T>,
17}
18
19impl<T> Receiver<T> {
20    pub async fn recv(&mut self) -> Option<T> { self.receiver.recv().await }
21
22    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
23        self.receiver.try_recv()
24    }
25
26    pub fn recv_blocking(&mut self) -> Option<T> {
27        futures::executor::block_on(self.receiver.recv())
28    }
29
30    pub fn recv_with_timeout(
31        &mut self, wait_for: Duration,
32    ) -> Result<Option<T>, Elapsed> {
33        runtime::Builder::new_current_thread()
34            .enable_time()
35            .build()
36            .expect("Runtime can be created")
37            // this only works in an async block, see:
38            // https://users.rust-lang.org/t/tokio-interval-not-working-in-runtime/41260/2
39            .block_on(
40                async move { timeout(wait_for, self.receiver.recv()).await },
41            )
42    }
43
44    // NOTE: do not capture anything in `f` that might have references to
45    // `Notifications`, otherwise the loop might never terminate.
46    pub async fn for_each(mut self, f: impl Fn(T) -> ()) {
47        while let Some(t) = self.recv().await {
48            f(t);
49        }
50    }
51}
52
53/// Implements an unbounded SPMC broadcast channel.
54pub struct Channel<T> {
55    // Used for generating subscription ids unique to this channel.
56    id_allocator: UniqueId,
57
58    // Name of the current instance.
59    name: String,
60
61    // Set of subscriptions, represented as ID => Sender pairs.
62    subscriptions: RwLock<BTreeMap<u64, mpsc::UnboundedSender<T>>>,
63}
64
65impl<T: Clone> Channel<T> {
66    pub fn new(name: &str) -> Self {
67        Self {
68            id_allocator: UniqueId::new(),
69            name: name.to_owned(),
70            subscriptions: RwLock::new(BTreeMap::new()),
71        }
72    }
73
74    pub fn subscribe(&self) -> Receiver<T> {
75        let (sender, receiver) = mpsc::unbounded_channel();
76        let id = self.id_allocator.next();
77        self.subscriptions.write().insert(id, sender);
78        Receiver { id, receiver }
79    }
80
81    pub fn unsubscribe(&self, id: u64) -> bool {
82        self.subscriptions.write().remove(&id).is_some()
83    }
84
85    pub fn num_subscriptions(&self) -> usize { self.subscriptions.read().len() }
86
87    pub fn send(&self, t: T) -> bool {
88        let mut sent = false;
89        let mut invalid = vec![];
90
91        for (id, send) in &*self.subscriptions.write() {
92            match send.send(t.clone()) {
93                Ok(_) => sent = true,
94                Err(_e) => {
95                    warn!(
96                        "Channel {}::{} dropped without unsubscribe",
97                        self.name, id
98                    );
99                    invalid.push(*id);
100                }
101            }
102        }
103
104        for id in invalid {
105            self.unsubscribe(id);
106        }
107
108        sent
109    }
110}
111
112pub struct Notifications {
113    pub new_block_hashes: Arc<Channel<H256>>,
114    pub epochs_ordered: Arc<Channel<(u64, Vec<H256>)>>,
115    pub blame_verification_results: Arc<Channel<(u64, Option<u64>)>>, /* <height, witness> */
116}
117
118impl Notifications {
119    pub fn init() -> Arc<Self> {
120        Arc::new(Notifications {
121            new_block_hashes: Arc::new(Channel::new("new-block-hashes")),
122            epochs_ordered: Arc::new(Channel::new("epochs-executed")),
123            blame_verification_results: Arc::new(Channel::new(
124                "blame-verification-results",
125            )),
126        })
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::{Channel, Receiver};
    use futures::future::join3;
    use rand::Rng;
    use tokio::runtime::Runtime;

    /// Asserts that `rec` yields exactly `expected`, in order, using
    /// blocking receives.
    fn expect_next(rec: &mut Receiver<u64>, expected: &[u64]) {
        for &value in expected {
            assert_eq!(rec.recv_blocking(), Some(value));
        }
    }

    /// Receives from `rec` until its channel closes and returns every value
    /// in arrival order.
    async fn drain(mut rec: Receiver<u64>) -> Vec<u64> {
        let mut received = vec![];
        while let Some(value) = rec.recv().await {
            received.push(value);
        }
        received
    }

    /// Forwards every value from `rec` into `out` until `rec` closes; `out`
    /// is dropped on return, which closes its own subscribers in turn.
    async fn forward(mut rec: Receiver<u64>, out: Channel<u64>) {
        while let Some(value) = rec.recv().await {
            out.send(value);
        }
    }

    #[test]
    fn test_sync() {
        let chan = Channel::<u64>::new("test-chan");

        // Nobody is listening yet, so the send reports failure.
        assert!(!chan.send(1001));

        // A single subscriber receives the value.
        let mut rec1 = chan.subscribe();
        assert_eq!(chan.num_subscriptions(), 1);
        assert!(chan.send(1002));
        expect_next(&mut rec1, &[1002]);

        // Two subscribers each see every value, in order.
        let mut rec2 = chan.subscribe();
        assert_eq!(chan.num_subscriptions(), 2);
        assert!(chan.send(1003));
        assert!(chan.send(1004));
        expect_next(&mut rec1, &[1003, 1004]);
        expect_next(&mut rec2, &[1003, 1004]);

        // Removing the first subscriber leaves the second one working.
        assert!(chan.unsubscribe(rec1.id));
        assert_eq!(chan.num_subscriptions(), 1);
        assert!(chan.send(1005));
        expect_next(&mut rec2, &[1005]);

        // With no subscribers left, sends fail again.
        assert!(chan.unsubscribe(rec2.id));
        assert_eq!(chan.num_subscriptions(), 0);
        assert!(!chan.send(1005));
    }

    #[test]
    fn test_drop_receivers() {
        let chan = Channel::<u64>::new("test-chan");
        let first = chan.subscribe();
        let mut second = chan.subscribe();

        // Dropping a receiver does not unsubscribe it right away...
        drop(first);
        assert_eq!(chan.num_subscriptions(), 2);

        // ...the dead subscription is pruned by the next send.
        assert!(chan.send(1004));
        assert_eq!(chan.num_subscriptions(), 1);
        expect_next(&mut second, &[1004]);

        // Same for the last one: still registered until a send prunes it.
        drop(second);
        assert_eq!(chan.num_subscriptions(), 1);
        assert!(!chan.send(1005));
        assert_eq!(chan.num_subscriptions(), 0);
    }

    #[test]
    fn test_drop_sender() {
        let chan = Channel::<u64>::new("test-chan");
        let mut receivers = vec![chan.subscribe(), chan.subscribe()];

        assert!(chan.send(1001));
        for rec in receivers.iter_mut() {
            assert_eq!(rec.recv_blocking(), Some(1001));
        }

        // Dropping the channel closes every subscription.
        drop(chan);
        for rec in receivers.iter_mut() {
            assert_eq!(rec.recv_blocking(), None);
        }
    }

    #[test]
    fn test_async() {
        let chan = Channel::<u64>::new("test-chan");
        let fut1 = drain(chan.subscribe());
        let fut2 = drain(chan.subscribe());

        // The sender owns `chan`; it is dropped when the block completes,
        // which closes both receivers and lets the drains finish.
        let fut3 = async move {
            let mut rng = rand::rng();
            let sent: Vec<u64> = (0..100).map(|_| rng.random()).collect();
            for &value in &sent {
                chan.send(value);
            }
            sent
        };

        let runtime = Runtime::new().expect("Unable to create a runtime");
        let (res1, res2, res3) = runtime.block_on(join3(fut1, fut2, fut3));

        assert_eq!(res1, res3);
        assert_eq!(res2, res3);
    }

    #[test]
    fn test_ring() {
        // Values travel a -> b -> c -> a.
        let send_a = Channel::<u64>::new("test-chan-ab");
        let send_b = Channel::<u64>::new("test-chan-bc");
        let send_c = Channel::<u64>::new("test-chan-ca");

        let rec_b = send_a.subscribe();
        let rec_c = send_b.subscribe();
        let mut rec_a = send_c.subscribe();

        // `a` injects each value and waits for it to come back around.
        let fut_a = async move {
            let mut rng = rand::rng();

            for _ in 0..100 {
                let t: u64 = rng.random();
                send_a.send(t);
                let t2 = rec_a.recv().await;

                if t2 != Some(t) {
                    return Err(format!("Not equal: {:?}, {:?}", t2, Some(t)));
                }
            }

            Ok(())
        };

        let runtime = Runtime::new().expect("Unable to create a runtime");
        let (res, (), ()) = runtime.block_on(join3(
            fut_a,
            forward(rec_b, send_b),
            forward(rec_c, send_c),
        ));
        assert_eq!(res, Ok(()));
    }
}