use crate::{
iolib::{IoContext, StreamToken},
throttling::THROTTLING_SERVICE,
Error,
};
use bytes::{Bytes, BytesMut};
use lazy_static::lazy_static;
use log::{debug, trace};
use metrics::{
register_meter_with_group, Gauge, GaugeUsize, Histogram, Meter, Sample,
};
use mio::{tcp::*, *};
use priority_send_queue::{PrioritySendQueue, SendQueuePriority};
use serde::Deserialize;
use serde_derive::Serialize;
use std::{
io::{self, Read, Write},
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering as AtomicOrdering},
Arc,
},
time::Instant,
};
lazy_static! {
    // Byte-volume meters: every `mark` call on these passes a byte count.

    /// Bytes obtained from socket reads in `readable`.
    static ref READ_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "read");
    /// Bytes accepted by socket writes (queued packets and raw writes).
    static ref WRITE_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "write");
    /// Payload bytes accepted by `send`, across all priorities.
    static ref SEND_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "send");
    /// Payload bytes accepted by `send` at `SendQueuePriority::Low`.
    static ref SEND_LOW_PRIORITY_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "send_low");
    /// Payload bytes accepted by `send` at `SendQueuePriority::Normal`.
    static ref SEND_NORMAL_PRIORITY_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "send_normal");
    /// Payload bytes accepted by `send` at `SendQueuePriority::High`.
    static ref SEND_HIGH_PRIORITY_METER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "send_high");

    // Packet lifetime histograms, updated from `Packet`'s `Drop` impl with
    // the time elapsed since the packet was created.

    /// Creation-to-drop time of packets enqueued at high priority.
    static ref HIGH_PACKET_SEND_TO_WRITE_ELAPSED_TIME: Arc<dyn Histogram> =
        Sample::ExpDecay(0.015)
            .register_with_group("network_system_data", "high_packet_wait_time", 1024);
    /// Creation-to-drop time of every other (normal or low priority) packet.
    static ref LOW_PACKET_SEND_TO_WRITE_ELAPSED_TIME: Arc<dyn Histogram> =
        Sample::ExpDecay(0.015)
            .register_with_group("network_system_data", "low_packet_wait_time", 1024);

    // Write-path event counters: each `mark(1)` records one event.

    /// Socket write attempts made for queued packets.
    static ref WRITABLE_COUNTER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "writable_counter");
    /// NOTE(review): not marked anywhere in this module; possibly unused.
    static ref WRITABLE_YIELD_SEND_RIGHT_COUNTER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "writable_yield_send_right_counter");
    /// Queued-packet writes that transferred zero bytes.
    static ref WRITABLE_ZERO_COUNTER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "writable_zero_counter");
    /// Queued packets whose bytes were written out completely.
    static ref WRITABLE_PACKET_COUNTER: Arc<dyn Meter> =
        register_meter_with_group("network_system_data", "writable_packet_counter");

    /// Send-queue length, sampled at the end of each `writable` call.
    static ref NETWORK_SEND_QUEUE_SIZE: Arc<dyn Gauge<usize>> =
        GaugeUsize::register_with_group("network_system_data", "send_queue_size");
}
/// Outcome of one attempt to push queued data to the socket.
///
/// This is a plain tag with no payload, so it is `Copy`: callers can compare
/// or store it without moving it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteStatus {
    /// The packet currently being sent still has unwritten bytes.
    Ongoing,
    /// The current packet was fully written, or there was nothing to send.
    Complete,
}
/// Largest frame size accepted by the default `PacketWithLenAssembler`: the
/// biggest value a 3-byte length prefix can hold (2^24 - 1).
///
/// Despite the name, this caps the whole frame, prefix included. The largest
/// framable payload is therefore this value minus the 3 prefix bytes.
const MAX_PAYLOAD_SIZE: usize = 0x00FF_FFFF;
/// A byte stream a connection can read from and write to.
///
/// This is a marker trait: it adds nothing beyond `Read + Write`. It lets
/// `GenericConnection` run over mio TCP streams in production and over
/// in-memory sockets in tests.
pub trait GenericSocket: io::Read + io::Write {}

/// Production connections run over mio's TCP stream.
impl GenericSocket for TcpStream {}
/// Frames outgoing payloads for the wire and recovers payloads from
/// received bytes.
///
/// Implementations are held behind a `Box<dyn PacketAssembler>` and must be
/// `Send + Sync`.
pub trait PacketAssembler: Send + Sync {
    /// Returns `true` when a payload of `len` bytes is too large to frame.
    fn is_oversized(&self, len: usize) -> bool;

    /// Rewrites `data` in place into its framed, on-the-wire form.
    ///
    /// # Errors
    /// Returns an error when the payload cannot be framed (for
    /// `PacketWithLenAssembler`, when it is oversized).
    fn assemble(&self, data: &mut Vec<u8>) -> Result<(), Error>;

    /// Attempts to take one complete payload off the front of `buf`.
    ///
    /// Expected to return `None` while `buf` does not yet contain a whole
    /// frame; `readable` calls this after each batch of socket reads.
    fn load(&self, buf: &mut BytesMut) -> Option<BytesMut>;
}
/// An outgoing packet plus its send progress and bookkeeping.
struct Packet {
    /// Bytes to send; the assembler frames them in place before the first write.
    data: Vec<u8>,
    /// Index of the first byte of `data` not yet written.
    sending_pos: usize,
    /// Whether the packet was enqueued with `SendQueuePriority::High`.
    original_is_high_priority: bool,
    /// Length reported to the throttling service on enqueue, so the same
    /// amount is released on drop (framing may grow `data` afterwards).
    throttling_size: usize,
    /// Creation instant, used for the wait-time histograms.
    creation_time: Instant,
}
impl Packet {
fn new(data: Vec<u8>, priority: SendQueuePriority) -> Result<Self, Error> {
let throttling_size = data.len();
THROTTLING_SERVICE
.write()
.on_enqueue(throttling_size, priority == SendQueuePriority::High)?;
let is_high_priority = priority == SendQueuePriority::High;
Ok(Packet {
data,
sending_pos: 0,
original_is_high_priority: is_high_priority,
throttling_size,
creation_time: Instant::now(),
})
}
fn write(&mut self, writer: &mut dyn Write) -> Result<usize, Error> {
if self.is_send_completed() {
return Ok(0);
}
let size = writer.write(&self.data[self.sending_pos..])?;
self.sending_pos += size;
Ok(size)
}
fn is_send_completed(&self) -> bool { self.sending_pos >= self.data.len() }
}
impl Drop for Packet {
fn drop(&mut self) {
THROTTLING_SERVICE
.write()
.on_dequeue(self.throttling_size, self.original_is_high_priority);
if self.original_is_high_priority {
HIGH_PACKET_SEND_TO_WRITE_ELAPSED_TIME
.update_since(self.creation_time);
} else {
LOW_PACKET_SEND_TO_WRITE_ELAPSED_TIME
.update_since(self.creation_time)
}
}
}
/// Snapshot of a connection's send queue, returned by `send`.
// `queue_length` is never read inside this module; the allow suppresses the
// resulting dead-code warning while keeping the field in the returned value.
#[allow(dead_code)]
pub struct SendQueueStatus {
    /// Packets waiting in the queue after the send. Any packet already in
    /// flight is not counted.
    queue_length: usize,
}
/// A length-framed connection with a prioritized send queue, generic over
/// the underlying byte stream.
pub struct GenericConnection<Socket: GenericSocket> {
    /// Identifier of this connection within the I/O context.
    token: StreamToken,
    /// The underlying byte stream.
    socket: Socket,
    /// Bytes read from the socket that do not yet form a complete packet.
    recv_buf: BytesMut,
    /// Packets waiting to be sent, each tagged with its `SendQueuePriority`.
    send_queue: PrioritySendQueue<Packet>,
    /// Packet taken off the queue and currently being written, if any.
    sending_packet: Option<Packet>,
    /// Readiness set used when (re)registering the socket with the poller.
    interest: Ready,
    /// Set once socket registration has been attempted (even if it failed).
    registered: AtomicBool,
    /// Framing strategy applied to outgoing and incoming packets.
    assembler: Box<dyn PacketAssembler>,
}
impl<Socket: GenericSocket> GenericConnection<Socket> {
pub fn readable(&mut self) -> io::Result<Option<Bytes>> {
let mut buf: [u8; 1024] = [0; 1024];
loop {
match self.socket.read(&mut buf) {
Ok(size) => {
trace!(
"Succeed to read socket data, token = {}, size = {}",
self.token,
size
);
READ_METER.mark(size);
if size == 0 {
break;
}
self.recv_buf.extend_from_slice(&buf[0..size]);
}
Err(e) => match e.kind() {
io::ErrorKind::Interrupted => continue,
io::ErrorKind::WouldBlock => break,
_ => {
debug!("Failed to read socket data, token = {}, err = {:?}", self.token, e);
return Err(e);
}
},
}
}
let packet = self.assembler.load(&mut self.recv_buf);
if let Some(ref p) = packet {
trace!(
"Packet received, token = {}, size = {}",
self.token,
p.len()
);
}
Ok(packet.map(|p| p.freeze()))
}
pub fn write_raw_data(
&mut self, mut data: Vec<u8>,
) -> Result<usize, Error> {
trace!(
"Sending raw buffer, token = {} data_len = {}, data = {:?}",
self.token,
data.len(),
data
);
self.assembler.assemble(&mut data)?;
let size = self.socket.write(&data)?;
trace!(
"Succeed to send socket data, token = {}, size = {}",
self.token,
size
);
WRITE_METER.mark(size);
Ok(size)
}
fn write_next_from_queue(&mut self) -> Result<WriteStatus, Error> {
if self.sending_packet.is_none() {
let (mut packet, _) = match self.send_queue.pop_front() {
Some(item) => item,
None => return Ok(WriteStatus::Complete),
};
self.assembler.assemble(&mut packet.data)?;
trace!(
"Packet ready for sent, token = {}, size = {}",
self.token,
packet.data.len()
);
self.sending_packet = Some(packet);
}
let packet = self
.sending_packet
.as_mut()
.expect("should pop packet from send queue");
let size = packet.write(&mut self.socket)?;
if size == 0 {
WRITABLE_ZERO_COUNTER.mark(1);
}
trace!(
"Succeed to send socket data, token = {}, size = {}",
self.token,
size
);
WRITE_METER.mark(size);
WRITABLE_COUNTER.mark(1);
if packet.is_send_completed() {
trace!("Packet sent, token = {}", self.token);
self.sending_packet = None;
WRITABLE_PACKET_COUNTER.mark(1);
Ok(WriteStatus::Complete)
} else {
Ok(WriteStatus::Ongoing)
}
}
pub fn writable<Message: Sync + Send + Clone + 'static>(
&mut self, io: &IoContext<Message>,
) -> Result<WriteStatus, Error> {
let status = self.write_next_from_queue()?;
if self.sending_packet.is_none() && self.send_queue.is_empty() {
self.interest.remove(Ready::writable());
}
NETWORK_SEND_QUEUE_SIZE.update(self.send_queue.len());
io.update_registration(self.token)?;
Ok(status)
}
pub fn send<Message: Sync + Send + Clone + 'static>(
&mut self, io: &IoContext<Message>, data: Vec<u8>,
priority: SendQueuePriority,
) -> Result<SendQueueStatus, Error> {
if !data.is_empty() {
let size = data.len();
if self.assembler.is_oversized(size) {
return Err(Error::OversizedPacket.into());
}
trace!("Sending packet, token = {}, size = {}", self.token, size);
let packet = Packet::new(data, priority)?;
self.send_queue.push_back(packet, priority);
SEND_METER.mark(size);
match priority {
SendQueuePriority::High => {
SEND_HIGH_PRIORITY_METER.mark(size);
}
SendQueuePriority::Normal => {
SEND_NORMAL_PRIORITY_METER.mark(size);
}
SendQueuePriority::Low => {
SEND_LOW_PRIORITY_METER.mark(size);
}
}
if !self.interest.is_writable() {
self.interest.insert(Ready::writable());
}
io.update_registration(self.token).ok();
}
Ok(SendQueueStatus {
queue_length: self.send_queue.len(),
})
}
pub fn is_sending(&self) -> bool { self.interest.is_writable() }
}
pub type Connection = GenericConnection<TcpStream>;
impl Connection {
pub fn new(token: StreamToken, socket: TcpStream) -> Self {
Connection {
token,
socket,
recv_buf: BytesMut::new(),
send_queue: PrioritySendQueue::default(),
sending_packet: None,
interest: Ready::hup() | Ready::readable(),
registered: AtomicBool::new(false),
assembler: Box::new(PacketWithLenAssembler::default()),
}
}
pub fn register_socket(
&self, reg: Token, event_loop: &Poll,
) -> io::Result<()> {
if self.registered.load(AtomicOrdering::SeqCst) {
return Ok(());
}
trace!(
"Connection register, token = {}, reg = {:?}",
self.token,
reg
);
if let Err(e) = event_loop.register(
&self.socket,
reg,
self.interest,
PollOpt::edge(),
) {
trace!(
"Failed to register socket, token = {}, reg = {:?}, err = {:?}",
self.token,
reg,
e
);
}
self.registered.store(true, AtomicOrdering::SeqCst);
Ok(())
}
pub fn update_socket(
&self, reg: Token, event_loop: &Poll,
) -> io::Result<()> {
trace!(
"Connection reregister, token = {}, reg = {:?}",
self.token,
reg
);
if !self.registered.load(AtomicOrdering::SeqCst) {
self.register_socket(reg, event_loop)
} else {
event_loop
.reregister(&self.socket, reg, self.interest, PollOpt::edge())
.unwrap_or_else(|e| {
trace!("Failed to reregister socket, token = {}, reg = {:?}, err = {:?}", self.token, reg, e);
});
Ok(())
}
}
pub fn deregister_socket(&self, event_loop: &Poll) -> io::Result<()> {
trace!("Connection deregister, token = {}", self.token);
event_loop.deregister(&self.socket).ok();
Ok(())
}
pub fn token(&self) -> StreamToken { self.token }
pub fn remote_addr(&self) -> io::Result<SocketAddr> {
self.socket.peer_addr()
}
pub fn remote_addr_str(&self) -> String {
self.remote_addr()
.map(|a| a.to_string())
.unwrap_or_else(|_| "Unknown".to_owned())
}
pub fn details(&self) -> ConnectionDetails {
ConnectionDetails {
token: self.token,
recv_buf: self.recv_buf.len(),
sending_buf: self
.sending_packet
.as_ref()
.map_or(0, |p| p.data.len() - p.sending_pos),
priority_queue_normal: self
.send_queue
.len_by_priority(SendQueuePriority::Normal),
priority_queue_high: self
.send_queue
.len_by_priority(SendQueuePriority::High),
interest: format!("{:?}", self.interest),
registered: self.registered.load(AtomicOrdering::SeqCst),
}
}
}
/// Serializable diagnostic snapshot of a connection, built by
/// `Connection::details`. Field names serialize in camelCase.
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionDetails {
    /// The connection's stream token.
    pub token: StreamToken,
    /// Bytes received but not yet assembled into a packet.
    pub recv_buf: usize,
    /// Unwritten bytes of the in-flight packet (0 when none is in flight).
    pub sending_buf: usize,
    /// Send-queue length at `SendQueuePriority::Normal`.
    pub priority_queue_normal: usize,
    /// Send-queue length at `SendQueuePriority::High`.
    pub priority_queue_high: usize,
    /// `Debug` rendering of the readiness interest.
    pub interest: String,
    /// Whether socket registration has been attempted.
    pub registered: bool,
}
/// Frames packets with a little-endian length prefix `data_len_bytes` wide.
///
/// To avoid shifting the whole payload, the prefix overwrites the first
/// (up to) `data_len_bytes` payload bytes, and those bytes are appended to
/// the end of the frame instead. `load` reverses this.
pub struct PacketWithLenAssembler {
    /// Width of the length prefix in bytes (1..=3).
    data_len_bytes: usize,
    /// Largest payload (prefix excluded) that may be framed.
    max_data_len: usize,
}

impl PacketWithLenAssembler {
    /// Creates an assembler with a `data_len_bytes`-byte length prefix.
    ///
    /// `max_packet_len` caps the total frame size, prefix included. It
    /// defaults to the largest value the prefix can encode.
    ///
    /// # Panics
    /// Panics if `data_len_bytes` is not in `1..=3`, or if `max_packet_len`
    /// is not larger than the prefix or exceeds what the prefix can encode.
    fn new(data_len_bytes: usize, max_packet_len: Option<usize>) -> Self {
        assert!(data_len_bytes > 0 && data_len_bytes <= 3);
        // Largest value representable in `data_len_bytes` bytes. It is
        // computed from the prefix width, not from the width of `usize`, so
        // it stays correct on 32-bit targets (at most a 24-bit shift here).
        let max = (1usize << (8 * data_len_bytes)) - 1;
        let max_packet_len = max_packet_len.unwrap_or(max);
        assert!(max_packet_len > data_len_bytes && max_packet_len <= max);
        PacketWithLenAssembler {
            data_len_bytes,
            max_data_len: max_packet_len - data_len_bytes,
        }
    }
}
impl Default for PacketWithLenAssembler {
    /// A 3-byte length prefix, with frames capped at `MAX_PAYLOAD_SIZE`.
    fn default() -> Self {
        let prefix_width = 3;
        Self::new(prefix_width, Some(MAX_PAYLOAD_SIZE))
    }
}
impl PacketAssembler for PacketWithLenAssembler {
    #[inline]
    fn is_oversized(&self, len: usize) -> bool { len > self.max_data_len }

    /// Frames `data` in place as `prefix || payload[n..] || payload[..n]`.
    ///
    /// Here `n` is `data_len_bytes` and `prefix` is the little-endian payload
    /// length. If the payload is shorter than `n`, the whole payload follows
    /// the prefix.
    ///
    /// # Errors
    /// Returns `OversizedPacket` if the payload exceeds `max_data_len`.
    fn assemble(&self, data: &mut Vec<u8>) -> Result<(), Error> {
        if self.is_oversized(data.len()) {
            return Err(Error::OversizedPacket.into());
        }
        let header_len = self.data_len_bytes;
        let payload_len = data.len();
        // The prefix overwrites the first payload bytes, so save them first;
        // they are re-appended at the end of the frame.
        let displaced = data[..payload_len.min(header_len)].to_vec();
        data.resize(payload_len + header_len, 0);
        data[..header_len]
            .copy_from_slice(&payload_len.to_le_bytes()[..header_len]);
        let tail_start = data.len() - displaced.len();
        data[tail_start..].copy_from_slice(&displaced);
        Ok(())
    }

    /// Takes one frame off the front of `buf` and undoes `assemble`.
    ///
    /// Returns `None` and leaves `buf` untouched until both the prefix and
    /// the full payload it announces are buffered.
    fn load(&self, buf: &mut BytesMut) -> Option<BytesMut> {
        let header_len = self.data_len_bytes;
        if buf.len() < header_len {
            return None;
        }
        // Decode the little-endian prefix byte by byte. This avoids assuming
        // an 8-byte `usize`, which `usize::from_le_bytes([u8; 8])` required.
        let data_size = buf[..header_len]
            .iter()
            .rev()
            .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte));
        let frame_len = header_len + data_size;
        if buf.len() < frame_len {
            return None;
        }
        let mut packet = buf.split_to(frame_len);
        if data_size >= header_len {
            // The last `header_len` bytes are the payload head that the
            // prefix displaced; move them back over the prefix.
            let displaced = packet.split_off(data_size);
            packet[..header_len].copy_from_slice(&displaced);
        } else {
            // Short payload: it follows the prefix unchanged.
            let _ = packet.split_to(header_len);
        }
        Some(packet)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iolib::*;
    use mio::Ready;
    use std::{
        cmp,
        io::{Read, Result, Write},
    };

    /// In-memory socket. Reads are served from `read_buf` starting at
    /// `cursor`. Writes are appended to `write_buf`, capped at `buf_size`
    /// bytes per call when `buf_size` is non-zero.
    struct TestSocket {
        read_buf: Vec<u8>,
        write_buf: Vec<u8>,
        cursor: usize,
        buf_size: usize,
    }

    impl TestSocket {
        fn new() -> Self { TestSocket::with_buf(0) }

        fn with_buf(buf_size: usize) -> Self {
            TestSocket {
                read_buf: vec![],
                write_buf: vec![],
                cursor: 0,
                buf_size,
            }
        }
    }

    impl Read for TestSocket {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            // Everything consumed: report end-of-stream.
            if self.cursor >= self.read_buf.len() {
                return Ok(0);
            }
            let end = cmp::min(self.read_buf.len(), self.cursor + buf.len());
            let len = end - self.cursor;
            buf[..len].copy_from_slice(&self.read_buf[self.cursor..end]);
            self.cursor = end;
            Ok(len)
        }
    }

    impl Write for TestSocket {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            let accepted = if self.buf_size == 0 || buf.len() < self.buf_size {
                buf.len()
            } else {
                self.buf_size
            };
            self.write_buf.extend_from_slice(&buf[..accepted]);
            Ok(accepted)
        }

        fn flush(&mut self) -> Result<()> { Ok(()) }
    }

    impl GenericSocket for TestSocket {}

    type TestConnection = GenericConnection<TestSocket>;

    impl TestConnection {
        fn new() -> Self {
            TestConnection {
                token: 1_234_567_890usize,
                socket: TestSocket::new(),
                send_queue: PrioritySendQueue::default(),
                sending_packet: None,
                recv_buf: BytesMut::new(),
                interest: Ready::hup() | Ready::readable(),
                registered: AtomicBool::new(false),
                assembler: Box::new(PacketWithLenAssembler::new(1, None)),
            }
        }
    }

    fn test_io() -> IoContext<i32> {
        IoContext::new(IoChannel::disconnected(), 0)
    }

    #[test]
    fn connection_write_empty() {
        let mut connection = TestConnection::new();
        let status = connection
            .writable(&test_io())
            .expect("writing an empty queue should succeed");
        assert_eq!(status, WriteStatus::Complete);
    }

    #[test]
    fn connection_write_is_buffered() {
        let mut connection = TestConnection::new();
        connection.socket = TestSocket::with_buf(10);
        let packet = Packet::new(vec![0; 60], SendQueuePriority::High).unwrap();
        connection
            .send_queue
            .push_back(packet, SendQueuePriority::High);
        let status = connection
            .writable(&test_io())
            .expect("partial write should succeed");
        assert_eq!(status, WriteStatus::Ongoing);
        assert_eq!(connection.send_queue.len(), 0);
        let sending_packet = connection
            .sending_packet
            .take()
            .expect("packet should still be in flight");
        // 60 payload bytes plus the 1-byte length prefix.
        assert_eq!(sending_packet.data.len(), 61);
        assert_eq!(sending_packet.sending_pos, 10);
    }

    #[test]
    fn connection_read() {
        let mut connection = TestConnection::new();
        let mut data = vec![1, 3, 5, 7];
        connection.assembler.assemble(&mut data).unwrap();

        // Only part of the frame has arrived.
        connection.socket.read_buf = data[..2].to_vec();
        assert!(connection.readable().expect("read should succeed").is_none());

        // The rest arrives: the packet is complete.
        connection.socket.read_buf.extend_from_slice(&data[2..]);
        let packet = connection
            .readable()
            .expect("read should succeed")
            .expect("complete frame should be loaded");
        assert_eq!(&packet[..], &[1, 3, 5, 7]);

        // Nothing left.
        assert!(connection.readable().expect("read should succeed").is_none());
    }

    #[test]
    fn test_assembler_oversized() {
        let assembler = PacketWithLenAssembler::default();
        assert!(!assembler.is_oversized(MAX_PAYLOAD_SIZE - 4));
        assert!(!assembler.is_oversized(MAX_PAYLOAD_SIZE - 3));
        assert!(assembler.is_oversized(MAX_PAYLOAD_SIZE - 2));
    }

    #[test]
    fn test_assembler_assemble() {
        let assembler = PacketWithLenAssembler::default();
        let mut data = vec![1, 2, 3, 4, 5];
        assembler.assemble(&mut data).unwrap();
        assert_eq!(data, vec![5, 0, 0, 4, 5, 1, 2, 3]);
        let mut data = vec![1, 2, 3];
        assembler.assemble(&mut data).unwrap();
        assert_eq!(data, vec![3, 0, 0, 1, 2, 3]);
        let mut data = vec![1, 2];
        assembler.assemble(&mut data).unwrap();
        assert_eq!(data, vec![2, 0, 0, 1, 2]);
        let mut data: Vec<u8> = vec![];
        assembler.assemble(&mut data).unwrap();
        assert_eq!(data, vec![0, 0, 0]);
    }

    #[test]
    fn test_assembler_load() {
        let assembler = PacketWithLenAssembler::default();
        assert_eq!(assembler.load(&mut BytesMut::from(&[5u8][..])), None);
        assert_eq!(
            assembler.load(&mut BytesMut::from(&[5u8, 0, 0][..])),
            None
        );
        assert_eq!(
            assembler.load(&mut BytesMut::from(&[5u8, 0, 0, 4, 5, 1, 2][..])),
            None
        );
        let mut buf = BytesMut::from(&[5u8, 0, 0, 4, 5, 1, 2, 3][..]);
        assert_eq!(&assembler.load(&mut buf).unwrap()[..], &[1, 2, 3, 4, 5]);
        assert!(buf.is_empty());
        let mut buf = BytesMut::from(&[2u8, 0, 0, 1, 2][..]);
        assert_eq!(&assembler.load(&mut buf).unwrap()[..], &[1, 2]);
        assert!(buf.is_empty());
        let mut buf = BytesMut::from(&[5u8, 0, 0, 4, 5, 1, 2, 3, 6, 7][..]);
        assert_eq!(&assembler.load(&mut buf).unwrap()[..], &[1, 2, 3, 4, 5]);
        assert_eq!(&buf[..], &[6, 7]);
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);
        assert!(assembler.load(&mut buf).unwrap().is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_assembler_round_trip() {
        let assembler = PacketWithLenAssembler::default();
        for len in 0..8u8 {
            let payload: Vec<u8> = (1..=len).collect();
            let mut framed = payload.clone();
            assembler.assemble(&mut framed).unwrap();
            assert_eq!(framed.len(), payload.len() + 3);
            let mut buf = BytesMut::from(&framed[..]);
            let loaded = assembler
                .load(&mut buf)
                .expect("complete frame should load");
            assert_eq!(&loaded[..], &payload[..]);
            assert!(buf.is_empty());
        }
    }
}