1use crate::{
22 mio_util::{EventLoop, EventLoopBuilder, Handler, Sender},
23 worker::{SocketWorker, Work, WorkType, Worker},
24 IoError, IoHandler,
25};
26use lazy_static::lazy_static;
27use log::{debug, error, trace, warn};
28use metrics::{register_meter_with_group, Meter, MeterTimer};
29use mio::{Events, Poll, Registry, Token, Waker};
30use mio_timer::Timeout;
31use parking_lot::{Mutex, RwLock};
32use slab::Slab;
33use std::{
34 collections::HashMap,
35 sync::{Arc, Condvar as SCondvar, Mutex as SMutex, Weak},
36 thread::{self, JoinHandle},
37 time::Duration,
38};
39
/// Identifier of a timer, local to the handler that registered it.
pub type TimerToken = usize;
/// Identifier of a stream (socket), local to the handler that registered it.
pub type StreamToken = usize;
/// Key of a registered handler in the service's handler slab.
pub type HandlerId = usize;

/// Size of the token range reserved for each handler. Global mio tokens are
/// built as `token + handler_id * TOKENS_PER_HANDLER` and split back into
/// `(handler_id, token)` with `/` and `%` by this value.
pub const TOKENS_PER_HANDLER: usize = 16384;
/// Number of handler slots: the handler slab is created with this capacity
/// and user-message broadcasts visit ids `0..MAX_HANDLERS`.
const MAX_HANDLERS: usize = 8;
51
lazy_static! {
    /// Meter registered via `register_meter_with_group` with the strings
    /// "timer" and "service_mio::network_poll_thread". A `MeterTimer` on it
    /// is started once per iteration of the network poll loop in
    /// `IoService::start_network_poll`.
    static ref NET_POLL_THREAD_TIMER: Arc<dyn Meter> =
        register_meter_with_group("timer", "service_mio::network_poll_thread");
}
56
/// Control and data messages processed by the IO event loop
/// (`IoManager::notify`).
#[derive(Clone)]
pub enum IoMessage<Message: Send + Sized> {
    /// Drop the worker pool and shut the event loop down.
    Shutdown,
    /// Insert a new handler into the handler table and call its
    /// `initialize` with a fresh `IoContext`.
    AddHandler {
        handler: Arc<dyn IoHandler<Message> + Send>,
    },
    /// Remove a handler and clear every timer in its token range.
    RemoveHandler {
        handler_id: HandlerId,
    },
    /// Schedule timer `token` of handler `handler_id` to fire after `delay`.
    AddTimer {
        handler_id: HandlerId,
        token: TimerToken,
        delay: Duration,
        /// If true the timer is not re-armed after it fires.
        once: bool,
        /// For one-shot timers: when it fires, remove its table entry and
        /// clear its timeout.
        cancel_all: bool,
    },
    /// Remove timer `token` of handler `handler_id` and clear its timeout.
    RemoveTimer {
        handler_id: HandlerId,
        token: TimerToken,
    },
    /// Ask the handler to register stream `token` with the network poll
    /// registry.
    RegisterStream {
        handler_id: HandlerId,
        token: StreamToken,
    },
    /// Ask the handler to deregister stream `token`; a timer with the same
    /// global token is cleared as well.
    DeregisterStream {
        handler_id: HandlerId,
        token: StreamToken,
    },
    /// Ask the handler to update the registration of stream `token`.
    UpdateStreamRegistration {
        handler_id: HandlerId,
        token: StreamToken,
    },
    /// Broadcast a user message to every registered handler via the worker
    /// pool.
    UserMessage(Arc<Message>),
    /// Deliver `msg` to handler `handler_id` on the socket worker selected
    /// from `peer`.
    RemoteMessage {
        peer: StreamToken,
        handler_id: HandlerId,
        msg: Arc<Message>,
    },
}
102
/// Handle given to an `IoHandler` for issuing requests (timers, streams,
/// messages) to the IO service on behalf of one handler.
#[derive(Clone)]
pub struct IoContext<Message>
where Message: Send + Sync + 'static
{
    /// Channel used to reach the IO event loop.
    channel: IoChannel<Message>,
    /// Handler on whose behalf timer and stream requests are sent.
    handler_id: HandlerId,
}
112
113impl<Message> IoContext<Message>
114where Message: Send + Sync + 'static
115{
116 pub fn new(
119 channel: IoChannel<Message>, handler: HandlerId,
120 ) -> IoContext<Message> {
121 IoContext {
122 handler_id: handler,
123 channel,
124 }
125 }
126
127 pub fn register_timer(
130 &self, token: TimerToken, delay: Duration,
131 ) -> Result<(), IoError> {
132 self.channel.send_io(IoMessage::AddTimer {
133 token,
134 delay,
135 handler_id: self.handler_id,
136 once: false,
137 cancel_all: false,
138 })?;
139 Ok(())
140 }
141
142 pub fn register_timer_once(
145 &self, token: TimerToken, delay: Duration,
146 ) -> Result<(), IoError> {
147 self.channel.send_io(IoMessage::AddTimer {
148 token,
149 delay,
150 handler_id: self.handler_id,
151 once: true,
152 cancel_all: true,
153 })?;
154 Ok(())
155 }
156
157 pub fn register_timer_once_nocancel(
161 &self, token: TimerToken, delay: Duration,
162 ) -> Result<(), IoError> {
163 self.channel.send_io(IoMessage::AddTimer {
164 token,
165 delay,
166 handler_id: self.handler_id,
167 once: true,
168 cancel_all: false,
169 })?;
170 Ok(())
171 }
172
173 pub fn clear_timer(&self, token: TimerToken) -> Result<(), IoError> {
175 self.channel.send_io(IoMessage::RemoveTimer {
176 token,
177 handler_id: self.handler_id,
178 })?;
179 Ok(())
180 }
181
182 pub fn register_stream(&self, token: StreamToken) -> Result<(), IoError> {
184 self.channel.send_io(IoMessage::RegisterStream {
185 token,
186 handler_id: self.handler_id,
187 })?;
188 Ok(())
189 }
190
191 pub fn deregister_stream(&self, token: StreamToken) -> Result<(), IoError> {
193 self.channel.send_io(IoMessage::DeregisterStream {
194 token,
195 handler_id: self.handler_id,
196 })?;
197 Ok(())
198 }
199
200 pub fn update_registration(
202 &self, token: StreamToken,
203 ) -> Result<(), IoError> {
204 self.channel.send_io(IoMessage::UpdateStreamRegistration {
205 token,
206 handler_id: self.handler_id,
207 })?;
208 Ok(())
209 }
210
211 pub fn message(&self, message: Message) -> Result<(), IoError> {
213 self.channel.send(message)?;
214 Ok(())
215 }
216
217 pub fn handle(
218 &self, peer: usize, handler_id: HandlerId, msg: Message,
219 ) -> Result<(), IoError> {
220 self.channel.send_io(IoMessage::RemoteMessage {
221 peer,
222 handler_id,
223 msg: Arc::new(msg),
224 })
225 }
226
227 pub fn channel(&self) -> IoChannel<Message> { self.channel.clone() }
229
230 pub fn unregister_handler(&self) {
232 let _ = self.channel.send_io(IoMessage::RemoveHandler {
237 handler_id: self.handler_id,
238 });
239 }
240}
241
/// Bookkeeping for a timer registered through `IoMessage::AddTimer`.
#[derive(Clone)]
struct UserTimer {
    /// Interval used when re-arming a repeating timer.
    delay: Duration,
    /// Handle of a scheduled timeout for this timer; passed to
    /// `clear_timeout` when the timer is removed.
    timeout: Timeout,
    /// If true the timer is not re-armed after it fires.
    once: bool,

    /// For one-shot timers: when it fires, remove this entry from the timer
    /// table and clear its timeout.
    cancel_all: bool,
}
252
/// Event-loop `Handler` that owns the timer table, the handler table and the
/// worker pools of an `IoService`.
pub struct IoManager<Message>
where Message: Send + Sync
{
    /// Registered timers keyed by global timer id
    /// (`token + handler_id * TOKENS_PER_HANDLER`); the key is spelled
    /// `HandlerId` but holds that global id.
    timers: Arc<RwLock<HashMap<HandlerId, UserTimer>>>,
    /// Registered handlers indexed by `HandlerId`; shared with `IoService`.
    handlers: Arc<RwLock<Slab<Arc<dyn IoHandler<Message>>>>>,
    /// General-purpose workers that steal jobs from `worker_channel`.
    workers: Vec<Worker>,
    /// FIFO deque onto which timeout and broadcast jobs are pushed.
    worker_channel: crossbeam_deque::Worker<Work<Message>>,
    /// Condition variable notified after jobs are pushed.
    work_ready: Arc<SCondvar>,
    /// Socket workers, each paired with the sender that feeds it; remote
    /// messages are routed here by peer.
    socket_workers:
        Vec<(crossbeam_channel::Sender<Work<Message>>, SocketWorker)>,
    /// Registry of the network `Poll`, passed to handlers for stream
    /// (de)registration.
    network_poll_registry: Arc<Registry>,
}
266
267impl<Message> IoManager<Message>
268where Message: Send + Sync + 'static
269{
270 pub fn start(
272 event_loop: &mut EventLoop<IoManager<Message>>,
273 handlers: Arc<RwLock<Slab<Arc<dyn IoHandler<Message>>>>>,
274 network_poll_registry: Arc<Registry>,
275 ) -> Result<(), IoError> {
276 let worker = crossbeam_deque::Worker::new_fifo();
277 let stealer = worker.stealer();
278 let num_workers = 4;
279 let work_ready_mutex = Arc::new(SMutex::new(()));
280 let work_ready = Arc::new(SCondvar::new());
281 let workers = (0..num_workers)
282 .map(|i| {
283 Worker::new(
284 i,
285 stealer.clone(),
286 IoChannel::new(
287 event_loop.channel(),
288 Arc::downgrade(&handlers),
289 ),
290 work_ready.clone(),
291 work_ready_mutex.clone(),
292 )
293 })
294 .collect();
295
296 let num_socket_workers = 4;
297 let socket_workers = (0..num_socket_workers)
298 .map(|i| {
299 let (tx, rx) = crossbeam_channel::unbounded();
300 (
301 tx,
302 SocketWorker::new(
303 i,
304 rx,
305 IoChannel::new(
306 event_loop.channel(),
307 Arc::downgrade(&handlers),
308 ),
309 ),
310 )
311 })
312 .collect();
313
314 let mut io = IoManager {
315 timers: Arc::new(RwLock::new(HashMap::new())),
316 handlers,
317 worker_channel: worker,
318 workers,
319 work_ready,
320 socket_workers,
321 network_poll_registry,
322 };
323 event_loop.run(&mut io)?;
324 Ok(())
325 }
326}
327
328impl<Message> Handler for IoManager<Message>
329where Message: Send + Sync + 'static
330{
331 type Message = IoMessage<Message>;
332 type TimeoutState = Token;
333
334 fn timeout(&mut self, event_loop: &mut EventLoop<Self>, token: Token) {
339 let handler_index = token.0 / TOKENS_PER_HANDLER;
340 let token_id = token.0 % TOKENS_PER_HANDLER;
341 if let Some(handler) = self.handlers.read().get(handler_index) {
342 let maybe_timer = self.timers.read().get(&token.0).cloned();
343 if let Some(timer) = maybe_timer {
344 if timer.once {
345 if timer.cancel_all {
346 self.timers.write().remove(&token_id);
347 event_loop.clear_timeout(&timer.timeout);
348 }
349 } else {
350 event_loop.timeout(token, timer.delay);
351 }
352 self.worker_channel.push(Work {
353 work_type: WorkType::Timeout,
354 token: token_id,
355 handler: handler.clone(),
356 handler_id: handler_index,
357 });
358 self.work_ready.notify_all();
359 }
360 }
361 }
362
363 fn notify(&mut self, event_loop: &mut EventLoop<Self>, msg: Self::Message) {
364 match msg {
365 IoMessage::Shutdown => {
366 self.workers.clear();
367 event_loop.shutdown();
368 }
369 IoMessage::AddHandler { handler } => {
370 let handler_id = self.handlers.write().insert(handler.clone());
371 assert!(
372 handler_id <= MAX_HANDLERS,
373 "Too many handlers registered"
374 );
375 trace!("add handler {}", handler_id);
376 handler.initialize(&IoContext::new(
377 IoChannel::new(
378 event_loop.channel(),
379 Arc::downgrade(&self.handlers),
380 ),
381 handler_id,
382 ));
383 }
384 IoMessage::RemoveHandler { handler_id } => {
385 self.handlers.write().remove(handler_id);
387 let mut timers = self.timers.write();
389 let to_remove: Vec<_> = timers
390 .keys()
391 .cloned()
392 .filter(|timer_id| {
393 timer_id / TOKENS_PER_HANDLER == handler_id
394 })
395 .collect();
396 for timer_id in to_remove {
397 let timer = timers.remove(&timer_id).expect(
398 "to_remove only contains keys from timers; qed",
399 );
400 event_loop.clear_timeout(&timer.timeout);
401 }
402 }
403 IoMessage::AddTimer {
404 handler_id,
405 token,
406 delay,
407 once,
408 cancel_all,
409 } => {
410 let timer_id = token + handler_id * TOKENS_PER_HANDLER;
411 let timeout = event_loop.timeout(Token(timer_id), delay);
412 self.timers.write().insert(
413 timer_id,
414 UserTimer {
415 delay,
416 timeout,
417 once,
418 cancel_all,
419 },
420 );
421 }
422 IoMessage::RemoveTimer { handler_id, token } => {
423 let timer_id = token + handler_id * TOKENS_PER_HANDLER;
424 if let Some(timer) = self.timers.write().remove(&timer_id) {
425 event_loop.clear_timeout(&timer.timeout);
426 }
427 }
428 IoMessage::RegisterStream { handler_id, token } => {
429 trace!("register stream {} {}", handler_id, token);
430 if let Some(handler) = self.handlers.read().get(handler_id) {
431 trace!("do register stream {} {}", handler_id, token);
432 let registry = self.network_poll_registry.as_ref();
433 handler.register_stream(
434 token,
435 Token(token + handler_id * TOKENS_PER_HANDLER),
436 registry,
437 );
438 }
439 }
440 IoMessage::DeregisterStream { handler_id, token } => {
441 if let Some(handler) = self.handlers.read().get(handler_id) {
442 let registry = self.network_poll_registry.as_ref();
443 handler.deregister_stream(token, registry);
444 let timer_id = token + handler_id * TOKENS_PER_HANDLER;
446 if let Some(timer) = self.timers.write().remove(&timer_id) {
447 event_loop.clear_timeout(&timer.timeout);
448 }
449 }
450 }
451 IoMessage::UpdateStreamRegistration { handler_id, token } => {
452 if let Some(handler) = self.handlers.read().get(handler_id) {
453 let registry = self.network_poll_registry.as_ref();
454 handler.update_stream(
455 token,
456 Token(token + handler_id * TOKENS_PER_HANDLER),
457 registry,
458 );
459 }
460 }
461 IoMessage::UserMessage(data) => {
462 for id in 0..MAX_HANDLERS {
464 if let Some(h) = self.handlers.read().get(id) {
465 let handler = h.clone();
466 self.worker_channel.push(Work {
467 work_type: WorkType::Message(data.clone()),
468 token: 0,
469 handler,
470 handler_id: id,
471 });
472 }
473 }
474 self.work_ready.notify_all();
475 }
476 IoMessage::RemoteMessage {
477 peer,
478 handler_id,
479 msg,
480 } => {
481 let worker_id = peer % 4;
482 if let Some(handler) = self.handlers.read().get(handler_id) {
483 self.socket_workers[worker_id]
484 .0
485 .send(Work {
486 work_type: WorkType::Message(msg),
487 token: peer,
488 handler: handler.clone(),
489 handler_id,
490 })
491 .expect("fail to send message to socket_worker");
492 }
493 }
494 }
495 }
496}
497
498enum Handlers<Message: Send> {
499 SharedCollection(Weak<RwLock<Slab<Arc<dyn IoHandler<Message>>>>>),
500 Single(Weak<dyn IoHandler<Message>>),
501}
502
503impl<Message: Send> Clone for Handlers<Message> {
504 fn clone(&self) -> Self {
505 use self::Handlers::*;
506
507 match *self {
508 SharedCollection(ref w) => SharedCollection(w.clone()),
509 Single(ref w) => Single(w.clone()),
510 }
511 }
512}
513
/// Cloneable handle for sending messages to the IO event loop or, when there
/// is no event loop, directly to handlers.
pub struct IoChannel<Message: Send> {
    /// Sender into the event loop; `None` for channels built by
    /// `disconnected` or `to_handler`.
    channel: Option<Sender<IoMessage<Message>>>,
    /// Handlers invoked by `send_sync`, which `send` uses when `channel` is
    /// `None`.
    handlers: Handlers<Message>,
}
520
521impl<Message> Clone for IoChannel<Message>
522where Message: Send + Sync + 'static
523{
524 fn clone(&self) -> IoChannel<Message> {
525 IoChannel {
526 channel: self.channel.clone(),
527 handlers: self.handlers.clone(),
528 }
529 }
530}
531
532impl<Message> IoChannel<Message>
533where Message: Send + Sync + 'static
534{
535 pub fn send(&self, message: Message) -> Result<(), IoError> {
537 match self.channel {
538 Some(ref channel) => {
539 channel.send(IoMessage::UserMessage(Arc::new(message)))?
540 }
541 None => self.send_sync(message)?,
542 }
543 Ok(())
544 }
545
546 pub fn send_sync(&self, message: Message) -> Result<(), IoError> {
548 match self.handlers {
549 Handlers::SharedCollection(ref handlers) => {
550 if let Some(handlers) = handlers.upgrade() {
551 for id in 0..MAX_HANDLERS {
552 if let Some(h) = handlers.read().get(id) {
553 let handler = h.clone();
554 handler.message(
555 &IoContext::new(self.clone(), id),
556 &message,
557 );
558 }
559 }
560 }
561 }
562 Handlers::Single(ref handler) => {
563 if let Some(handler) = handler.upgrade() {
564 handler.message(&IoContext::new(self.clone(), 0), &message);
565 }
566 }
567 }
568 Ok(())
569 }
570
571 pub fn send_io(&self, message: IoMessage<Message>) -> Result<(), IoError> {
573 if let Some(ref channel) = self.channel {
574 if let Err(e) = channel.send(message) {
575 warn!("Error sending message to eventloop channel, err={}", e);
576 return Err(e.into());
577 }
578 }
579 Ok(())
580 }
581
582 pub fn disconnected() -> IoChannel<Message> {
584 IoChannel {
585 channel: None,
586 handlers: Handlers::SharedCollection(Weak::default()),
587 }
588 }
589
590 pub fn to_handler(
592 handler: Weak<dyn IoHandler<Message>>,
593 ) -> IoChannel<Message> {
594 IoChannel {
595 channel: None,
596 handlers: Handlers::Single(handler),
597 }
598 }
599
600 fn new(
601 channel: Sender<IoMessage<Message>>,
602 handlers: Weak<RwLock<Slab<Arc<dyn IoHandler<Message>>>>>,
603 ) -> IoChannel<Message> {
604 IoChannel {
605 channel: Some(channel),
606 handlers: Handlers::SharedCollection(handlers),
607 }
608 }
609}
610
/// IO service: an `io_service` thread running the `IoManager` event loop,
/// plus an optional network poll thread started by `start_network_poll`.
pub struct IoService<Message>
where Message: Send + Sync + 'static
{
    /// Event-loop thread; taken and joined by `stop`.
    thread: Mutex<Option<JoinHandle<()>>>,
    /// Sender into the event loop.
    host_channel: Mutex<Sender<IoMessage<Message>>>,
    /// Registered handlers, shared with the `IoManager`.
    handlers: Arc<RwLock<Slab<Arc<dyn IoHandler<Message>>>>>,
    /// Network poll thread, set by `start_network_poll`; joined by `stop`.
    network_poll_thread: Mutex<Option<JoinHandle<()>>>,
    /// Waker registered on `network_poll` with `stop_token`; woken by `stop`.
    network_poll_stopped: Arc<Waker>,
    /// Network `Poll` instance, shared with the poll thread.
    network_poll: Arc<Mutex<Poll>>,
    /// Token reserved for the waker; the poll thread returns when it sees it.
    stop_token: usize,
}
624
625impl<Message> IoService<Message>
626where Message: Send + Sync + 'static
627{
628 pub fn start(stop_token: usize) -> Result<IoService<Message>, IoError> {
630 debug!("start IoService");
631 let mut config = EventLoopBuilder::new();
632 config.messages_per_tick(1024);
633 config.notify_capacity(20960);
634 let mut event_loop = config.build().expect("Error creating event loop");
635 let channel = event_loop.channel();
636 let handlers = Arc::new(RwLock::new(Slab::with_capacity(MAX_HANDLERS)));
637 let h = handlers.clone();
638
639 let network_poll = Poll::new().expect("Failed to create Poll instance");
640 let registry = network_poll
641 .registry()
642 .try_clone()
643 .expect("Failed to clone registry for event loop");
644
645 let waker = Waker::new(network_poll.registry(), Token(stop_token))
646 .expect("Failed to create Waker");
647
648 let thread = thread::Builder::new()
649 .name("io_service".into())
650 .spawn(move || {
651 IoManager::<Message>::start(
652 &mut event_loop,
653 h,
654 Arc::new(registry),
655 )
656 .expect("Error starting IO service");
657 })
658 .expect("only one io_service thread, so it should not fail");
659 Ok(IoService {
660 thread: Mutex::new(Some(thread)),
661 host_channel: Mutex::new(channel),
662 handlers,
663 network_poll_thread: Mutex::new(None),
664 network_poll_stopped: Arc::new(waker),
665 network_poll: Arc::new(Mutex::new(network_poll)),
666 stop_token,
667 })
668 }
669
670 pub fn stop(&self) {
671 debug!("[IoService] Closing...");
672 self.network_poll_stopped
675 .wake()
676 .expect("Failed to wake network poll thread");
677
678 if let Some(thread) = self.network_poll_thread.lock().take() {
679 thread.join().unwrap_or_else(|e| match e.downcast_ref::<&'static str>() {
680 Some(e) => error!("Error joining network poll thread: {}", e),
681 None => error!("Error joining network poll thread: Unknown error: {:?}", e),
682 });
683 }
684 self.handlers.write().clear();
687 self.host_channel
688 .lock()
689 .send(IoMessage::Shutdown)
690 .unwrap_or_else(|e| warn!("Error on IO service shutdown: {:?}", e));
691 if let Some(thread) = self.thread.lock().take() {
692 thread.join().unwrap_or_else(|e| match e.downcast_ref::<&'static str>() {
693 Some(e) => error!("Error joining IO service event loop thread: {}", e),
694 None => error!("Error joining IO service event loop thread: Unknown error: {:?}", e),
695 });
696 }
697 debug!("[IoService] Closed.");
698 }
699
700 pub fn start_network_poll(
701 &self, handler: Arc<dyn IoHandler<Message>>, max_sessions: usize,
702 ) {
703 let main_event_loop_channel = self.channel().clone();
704 let network_poll = self.network_poll.clone();
705 let stop_token = self.stop_token;
706 let thread = thread::Builder::new()
707 .name("network_eventloop".into())
708 .spawn(move || {
709 let mut events = Events::with_capacity(max_sessions);
710 loop {
711 let _timer =
712 MeterTimer::time_func(NET_POLL_THREAD_TIMER.as_ref());
713
714 {
715 let mut poll = network_poll.lock();
716 poll.poll(&mut events, Some(Duration::from_secs(1)))
717 .expect("Network poll failure");
718 }
719
720 for event in &events {
721 if event.token().0 == stop_token {
722 return;
723 }
724
725 let handler_id = 0;
726 let token_id = event.token().0 % TOKENS_PER_HANDLER;
727 if event.is_readable() {
728 handler.stream_readable(
729 &IoContext::new(
730 main_event_loop_channel.clone(),
731 handler_id,
732 ),
733 token_id,
734 );
735 }
736 if event.is_writable() {
737 handler.stream_writable(
738 &IoContext::new(
739 main_event_loop_channel.clone(),
740 handler_id,
741 ),
742 token_id,
743 );
744 }
745 if event.is_read_closed() || event.is_write_closed() {
746 handler.stream_hup(
747 &IoContext::new(
748 main_event_loop_channel.clone(),
749 handler_id,
750 ),
751 token_id,
752 );
753 }
754 }
755 }
756 })
757 .expect("only one io_service thread, so it should not fail");
758 *self.network_poll_thread.lock() = Some(thread);
759 }
760
761 pub fn register_handler(
763 &self, handler: Arc<dyn IoHandler<Message> + Send>,
764 ) -> Result<(), IoError> {
765 self.host_channel
766 .lock()
767 .send(IoMessage::AddHandler { handler })?;
768 Ok(())
769 }
770
771 pub fn send_message(&self, message: Message) -> Result<(), IoError> {
774 self.host_channel
775 .lock()
776 .send(IoMessage::UserMessage(Arc::new(message)))?;
777 Ok(())
778 }
779
780 pub fn channel(&self) -> IoChannel<Message> {
782 IoChannel::new(
783 self.host_channel.lock().clone(),
784 Arc::downgrade(&self.handlers),
785 )
786 }
787}
788
/// Runs `IoService::stop` when the service is dropped, shutting down and
/// joining the service threads.
impl<Message> Drop for IoService<Message>
where Message: Send + Sync
{
    fn drop(&mut self) { self.stop() }
}