1use crate::message_queues::{PerKeyQueue, QueueStyle};
15use anyhow::{ensure, Result};
16use diem_infallible::{Mutex, NonZeroUsize};
17use diem_metrics::IntCounterVec;
18use futures::{
19 channel::oneshot,
20 stream::{FusedStream, Stream},
21};
22use std::{
23 fmt::{Debug, Formatter},
24 hash::Hash,
25 pin::Pin,
26 sync::Arc,
27 task::{Context, Poll, Waker},
28};
29
30#[derive(Debug)]
33struct SharedState<K: Eq + Hash + Clone, M> {
34 internal_queue:
36 PerKeyQueue<K, (M, Option<oneshot::Sender<ElementStatus<M>>>)>,
37 waker: Option<Waker>,
43 num_senders: usize,
46 receiver_dropped: bool,
48 stream_terminated: bool,
52}
53
54#[derive(Debug)]
56pub struct Sender<K: Eq + Hash + Clone, M> {
57 shared_state: Arc<Mutex<SharedState<K, M>>>,
58}
59
/// What eventually happened to a message pushed with a feedback channel
/// (see `Sender::push_with_feedback`).
#[derive(PartialEq, Eq)]
pub enum ElementStatus<M> {
    /// The message was popped by the `Receiver`'s `poll_next`.
    Dequeued,
    /// The message was evicted from its per-key queue while another message
    /// was being pushed; the evicted message is handed back here.
    Dropped(M),
}

// `Debug` is written by hand (matching the derived output) so the payload is
// formatted through `debug_tuple`, which honours formatter flags such as the
// alternate `{:#?}` mode instead of always printing the payload with `{:?}`.
impl<M: Debug> Debug for ElementStatus<M> {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        match self {
            ElementStatus::Dequeued => f.write_str("Dequeued"),
            ElementStatus::Dropped(v) => f.debug_tuple("Dropped").field(v).finish(),
        }
    }
}
89
90impl<K: Eq + Hash + Clone, M> Sender<K, M> {
91 pub fn push(&self, key: K, message: M) -> Result<()> {
94 self.push_with_feedback(key, message, None)
95 }
96
97 pub fn push_with_feedback(
101 &self, key: K, message: M,
102 status_ch: Option<oneshot::Sender<ElementStatus<M>>>,
103 ) -> Result<()> {
104 let mut shared_state = self.shared_state.lock();
105 ensure!(!shared_state.receiver_dropped, "Channel is closed");
106 debug_assert!(shared_state.num_senders > 0);
107
108 let dropped =
109 shared_state.internal_queue.push(key, (message, status_ch));
110 if let Some((dropped_val, Some(dropped_status_ch))) = dropped {
114 let _err =
116 dropped_status_ch.send(ElementStatus::Dropped(dropped_val));
117 }
118 if let Some(w) = shared_state.waker.take() {
119 w.wake();
120 }
121 Ok(())
122 }
123}
124
125impl<K: Eq + Hash + Clone, M> Clone for Sender<K, M> {
126 fn clone(&self) -> Self {
127 let shared_state = self.shared_state.clone();
128 {
129 let mut shared_state_lock = shared_state.lock();
130 debug_assert!(shared_state_lock.num_senders > 0);
131 shared_state_lock.num_senders += 1;
132 }
133 Sender { shared_state }
134 }
135}
136
137impl<K: Eq + Hash + Clone, M> Drop for Sender<K, M> {
138 fn drop(&mut self) {
139 let mut shared_state = self.shared_state.lock();
140
141 debug_assert!(shared_state.num_senders > 0);
142 shared_state.num_senders -= 1;
143
144 if shared_state.num_senders == 0 {
145 if let Some(waker) = shared_state.waker.take() {
146 waker.wake();
147 }
148 }
149 }
150}
151
152pub struct Receiver<K: Eq + Hash + Clone, M> {
154 shared_state: Arc<Mutex<SharedState<K, M>>>,
155}
156
157impl<K: Eq + Hash + Clone, M> Receiver<K, M> {
158 pub fn clear(&mut self) {
162 let mut shared_state = self.shared_state.lock();
163 shared_state.internal_queue.clear();
164 }
165}
166
167impl<K: Eq + Hash + Clone, M> Drop for Receiver<K, M> {
168 fn drop(&mut self) {
169 let mut shared_state = self.shared_state.lock();
170 debug_assert!(!shared_state.receiver_dropped);
171 shared_state.receiver_dropped = true;
172 }
173}
174
175impl<K: Eq + Hash + Clone, M> Stream for Receiver<K, M> {
176 type Item = M;
177
178 fn poll_next(
183 self: Pin<&mut Self>, cx: &mut Context<'_>,
184 ) -> Poll<Option<Self::Item>> {
185 let mut shared_state = self.shared_state.lock();
186 if let Some((val, status_ch)) = shared_state.internal_queue.pop() {
187 if let Some(status_ch) = status_ch {
188 let _err = status_ch.send(ElementStatus::Dequeued);
189 }
190 Poll::Ready(Some(val))
191 } else if shared_state.num_senders == 0 {
193 shared_state.stream_terminated = true;
194 Poll::Ready(None)
195 } else {
196 shared_state.waker = Some(cx.waker().clone());
197 Poll::Pending
198 }
199 }
200}
201
202impl<K: Eq + Hash + Clone, M> FusedStream for Receiver<K, M> {
203 fn is_terminated(&self) -> bool {
204 self.shared_state.lock().stream_terminated
205 }
206}
207
208pub fn new<K: Eq + Hash + Clone, M>(
210 queue_style: QueueStyle, max_queue_size_per_key: usize,
211 counters: Option<&'static IntCounterVec>,
212) -> (Sender<K, M>, Receiver<K, M>) {
213 let max_queue_size_per_key = NonZeroUsize!(
214 max_queue_size_per_key,
215 "diem_channel cannot be of size 0"
216 );
217 let shared_state = Arc::new(Mutex::new(SharedState {
218 internal_queue: PerKeyQueue::new(
219 queue_style,
220 max_queue_size_per_key,
221 counters,
222 ),
223 waker: None,
224 num_senders: 1,
225 receiver_dropped: false,
226 stream_terminated: false,
227 }));
228 let shared_state_clone = Arc::clone(&shared_state);
229 (
230 Sender { shared_state },
231 Receiver {
232 shared_state: shared_state_clone,
233 },
234 )
235}