use std::{collections::HashMap, fmt::Debug, mem::discriminant, sync::Arc};
use keccak_hash::keccak;
use parking_lot::RwLock;
use serde::Deserialize;
use cfx_types::H256;
use consensus_types::{
epoch_retrieval::EpochRetrievalRequest, proposal_msg::ProposalMsg,
sync_info::SyncInfo, vote_msg::VoteMsg,
};
use diem_types::{
account_address::{from_consensus_public_key, AccountAddress},
epoch_change::EpochChangeProof,
validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey},
};
use io::TimerToken;
use network::{
node_table::NodeId, service::ProtocolVersion, NetworkContext,
NetworkProtocolHandler, NetworkService, UpdateNodeOperation,
};
use crate::{
message::{Message, MsgId},
pos::{
consensus::network::{
ConsensusMsg, NetworkTask as ConsensusNetworkTask,
},
mempool::network::{MempoolSyncMsg, NetworkTask as MempoolNetworkTask},
protocol::{
message::{
block_retrieval::BlockRetrievalRpcRequest,
block_retrieval_response::BlockRetrievalRpcResponse, msgid,
},
network_event::NetworkEvent,
request_manager::{
request_handler::AsAny, RequestManager, RequestMessage,
},
},
},
sync::{Error, ProtocolConfiguration, CHECK_RPC_REQUEST_TIMER},
};
use super::{HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION};
/// Per-peer bookkeeping kept by the PoS sync protocol for each connection.
#[derive(Default)]
pub struct PeerState {
    /// Network-level node id of the peer.
    id: NodeId,
    /// Key under which this peer is stored in `Peers`
    /// (the protocol handler uses `keccak(node_id)`).
    peer_hash: H256,
    /// The peer's PoS consensus and VRF public keys, if they are known.
    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
}

impl PeerState {
    /// Builds a peer state from its node id, its peer hash and its optional
    /// PoS public keys.
    pub fn new(
        id: NodeId, peer_hash: H256,
        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
    ) -> Self {
        PeerState {
            pos_public_key,
            peer_hash,
            id,
        }
    }

    /// Overwrites the stored PoS public keys (passing `None` clears them).
    pub fn set_pos_public_key(
        &mut self,
        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
    ) {
        self.pos_public_key = pos_public_key;
    }

    /// Returns the peer's network node id.
    pub fn get_id(&self) -> NodeId {
        let PeerState { id, .. } = self;
        *id
    }
}
/// Thread-safe registry of connected peers, keyed by peer hash.
///
/// Each entry is individually lockable (`Arc<RwLock<PeerState>>`), so callers
/// can keep a handle to one peer without holding the map lock.
#[derive(Default)]
pub struct Peers(RwLock<HashMap<H256, Arc<RwLock<PeerState>>>>);

impl Peers {
    /// Creates an empty registry.
    pub fn new() -> Peers { Self::default() }

    /// Returns a shared handle to the state of `peer`, if registered.
    pub fn get(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
        self.0.read().get(peer).cloned()
    }

    /// Registers `peer` if it is not already present.
    ///
    /// An existing entry is left untouched. The new state is only allocated
    /// when the key is actually vacant.
    pub fn insert(
        &self, peer: H256, id: NodeId,
        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
    ) {
        self.0.write().entry(peer).or_insert_with(|| {
            Arc::new(RwLock::new(PeerState::new(id, peer, pos_public_key)))
        });
    }

    /// Number of registered peers.
    pub fn len(&self) -> usize { self.0.read().len() }

    /// Whether no peer is registered.
    pub fn is_empty(&self) -> bool { self.0.read().is_empty() }

    /// Whether `peer` is registered.
    pub fn contains(&self, peer: &H256) -> bool {
        self.0.read().contains_key(peer)
    }

    /// Unregisters `peer`, returning its state if it was present.
    pub fn remove(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
        self.0.write().remove(peer)
    }

    /// Returns the hashes of all peers whose state satisfies `predicate`.
    ///
    /// The map is read-locked for the whole scan, and each peer's state is
    /// write-locked while `predicate` runs on it (the predicate may mutate
    /// it).
    pub fn all_peers_satisfying<F>(&self, mut predicate: F) -> Vec<H256>
    where F: FnMut(&mut PeerState) -> bool {
        self.0
            .read()
            .iter()
            .filter_map(|(id, state)| {
                if predicate(&mut *state.write()) {
                    Some(*id)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Folds `f` over every peer state handle.
    ///
    /// Only a read lock on the map is taken; iterating `values()` never
    /// mutates it. `f` must not call methods that write-lock this map
    /// (`insert`/`remove`), since the lock is not reentrant.
    pub fn fold<B, F>(&self, init: B, f: F) -> B
    where F: FnMut(B, &Arc<RwLock<PeerState>>) -> B {
        self.0.read().values().fold(init, f)
    }
}
pub struct Context<'a> {
pub io: &'a dyn NetworkContext,
pub peer: NodeId,
pub peer_hash: H256,
pub manager: &'a HotStuffSynchronizationProtocol,
}
impl<'a> Context<'a> {
pub fn match_request(
&self, request_id: u64,
) -> Result<RequestMessage, Error> {
self.manager
.request_manager
.match_request(self.io, &self.peer, request_id)
}
pub fn send_response(&self, response: &dyn Message) -> Result<(), Error> {
response.send(self.io, &self.peer)?;
Ok(())
}
pub fn get_peer_account_address(&self) -> Result<AccountAddress, Error> {
let k = self.get_pos_public_key().ok_or(Error::UnknownPeer)?;
Ok(from_consensus_public_key(&k.0, &k.1))
}
fn get_pos_public_key(
&self,
) -> Option<(ConsensusPublicKey, ConsensusVRFPublicKey)> {
self.manager
.peers
.get(&self.peer_hash)
.as_ref()?
.read()
.pos_public_key
.clone()
}
}
/// Network protocol handler for the PoS (HotStuff) synchronization protocol.
///
/// Tracks connected peers, decodes incoming messages and routes them to their
/// `Handleable` implementations, and maps PoS account addresses to peers.
pub struct HotStuffSynchronizationProtocol {
    /// Protocol settings (e.g. the RPC check period used by the timer).
    pub protocol_config: ProtocolConfiguration,
    /// Hash identifying this node. It is used as the peer hash for
    /// self-delivered messages and in dial tie-breaking (presumably
    /// `keccak(own node id)`, matching how peer hashes are derived — confirm
    /// with the caller).
    pub own_node_hash: H256,
    /// Connected peers keyed by `keccak(node_id)`.
    pub peers: Arc<Peers>,
    /// Bookkeeping for outstanding RPC requests.
    pub request_manager: Arc<RequestManager>,
    pub consensus_network_task: ConsensusNetworkTask,
    pub mempool_network_task: MempoolNetworkTask,
    /// PoS account address -> peer hash, for peers that supplied PoS keys.
    pub pos_peer_mapping: RwLock<HashMap<AccountAddress, H256>>,
}

impl HotStuffSynchronizationProtocol {
    /// Creates a protocol handler with a fresh, empty peer registry.
    pub fn new(
        own_node_hash: H256, consensus_network_task: ConsensusNetworkTask,
        mempool_network_task: MempoolNetworkTask,
        protocol_config: ProtocolConfiguration,
    ) -> Self {
        Self::with_peers(
            protocol_config,
            own_node_hash,
            consensus_network_task,
            mempool_network_task,
            Arc::new(Peers::new()),
        )
    }

    /// Creates a protocol handler that uses the supplied (possibly shared)
    /// peer registry.
    pub fn with_peers(
        protocol_config: ProtocolConfiguration, own_node_hash: H256,
        consensus_network_task: ConsensusNetworkTask,
        mempool_network_task: MempoolNetworkTask, peers: Arc<Peers>,
    ) -> Self {
        let request_manager = Arc::new(RequestManager::new(&protocol_config));
        HotStuffSynchronizationProtocol {
            protocol_config,
            own_node_hash,
            peers,
            request_manager,
            consensus_network_task,
            mempool_network_task,
            pos_peer_mapping: RwLock::new(Default::default()),
        }
    }

    /// Registers this handler with `network` under the HSB protocol id and
    /// version.
    pub fn register(
        self: Arc<Self>, network: Arc<NetworkService>,
    ) -> Result<(), String> {
        network
            .register_protocol(self, HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION)
            .map_err(|e| {
                format!(
                    "failed to register HotStuffSynchronizationProtocol: {:?}",
                    e
                )
            })
    }

    /// Times out expired requests and then resends waiting ones.
    pub fn remove_expired_flying_request(&self, io: &dyn NetworkContext) {
        self.request_manager.process_timeout_requests(io);
        self.request_manager.resend_waiting_requests(io);
    }

    /// Decides, for two simultaneous connections to the same peer, whether
    /// the new connection should replace the existing one (`true`) or be
    /// dropped (`false`).
    ///
    /// `existing_origin` / `new_origin` are the values reported by
    /// `NetworkContext::get_peer_connection_origin` (presumably `true` means
    /// this node dialed the connection — confirm). When the two origins
    /// differ, comparing the two hashes makes both ends pick the same
    /// surviving connection.
    fn simultaneous_dial_tie_breaking(
        own_peer_id: H256, remote_peer_id: H256, existing_origin: bool,
        new_origin: bool,
    ) -> bool {
        match (existing_origin, new_origin) {
            (false, false) => true,
            (false, true) => remote_peer_id < own_peer_id,
            (true, false) => own_peer_id < remote_peer_id,
            (true, true) => false,
        }
    }

    /// Logs `e` and, depending on its kind, disconnects `peer` with an
    /// appropriate node-table update.
    fn handle_error(
        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId, e: Error,
    ) {
        let mut disconnect = true;
        // NOTE(review): `warn` starts as `false` and no arm ever sets it to
        // `true`, so the `warn!` branch below is currently unreachable.
        // Confirm whether the default was meant to be `true`.
        let mut warn = false;
        let reason = format!("{}", e);
        let error_reason = format!("{:?}", e);
        let mut op = None;
        // Deliberately no wildcard arm: the compiler flags new error variants.
        match e {
            Error::InvalidBlock => op = Some(UpdateNodeOperation::Demotion),
            Error::InvalidGetBlockTxn(_) => {
                op = Some(UpdateNodeOperation::Demotion)
            }
            Error::InvalidStatus(_) => op = Some(UpdateNodeOperation::Failure),
            Error::InvalidMessageFormat => {
                op = Some(UpdateNodeOperation::Remove)
            }
            Error::UnknownPeer => {
                warn = false;
                op = Some(UpdateNodeOperation::Failure)
            }
            Error::UnexpectedResponse => disconnect = true,
            Error::RequestNotFound => {
                warn = false;
                disconnect = false;
            }
            Error::InCatchUpMode(_) => {
                disconnect = false;
                warn = false;
            }
            Error::TooManyTrans => {}
            Error::InvalidTimestamp => op = Some(UpdateNodeOperation::Demotion),
            Error::InvalidSnapshotManifest(_) => {
                op = Some(UpdateNodeOperation::Demotion)
            }
            Error::InvalidSnapshotChunk(_) => {
                op = Some(UpdateNodeOperation::Demotion)
            }
            Error::AlreadyThrottled(_) => {
                op = Some(UpdateNodeOperation::Remove)
            }
            Error::EmptySnapshotChunk => disconnect = false,
            Error::Throttled(_, msg) => {
                disconnect = false;
                // Tell the peer it is throttled; if even that fails, drop it.
                if let Err(e) = msg.send(io, peer) {
                    error!("failed to send throttled packet: {:?}", e);
                    disconnect = true;
                }
            }
            Error::Decoder(_) => op = Some(UpdateNodeOperation::Remove),
            Error::Io(_) => disconnect = false,
            Error::Network(kind) => match kind {
                network::Error::AddressParse => disconnect = false,
                network::Error::AddressResolve(_) => disconnect = false,
                network::Error::Auth => disconnect = false,
                network::Error::BadProtocol => {
                    op = Some(UpdateNodeOperation::Remove)
                }
                network::Error::BadAddr => disconnect = false,
                network::Error::Decoder(_) => {
                    op = Some(UpdateNodeOperation::Remove)
                }
                network::Error::Expired => disconnect = false,
                network::Error::Disconnect(_) => disconnect = false,
                network::Error::InvalidNodeId => disconnect = false,
                network::Error::OversizedPacket => disconnect = false,
                network::Error::Io(_) => disconnect = false,
                network::Error::Throttling(_) => disconnect = false,
                network::Error::SocketIo(_) => {
                    op = Some(UpdateNodeOperation::Failure)
                }
                network::Error::Msg(_) => {
                    op = Some(UpdateNodeOperation::Failure)
                }
                network::Error::MessageDeprecated { .. } => {
                    op = Some(UpdateNodeOperation::Failure)
                }
                network::Error::SendUnsupportedMessage { .. } => {
                    op = Some(UpdateNodeOperation::Failure)
                }
            },
            Error::Storage(_) => {}
            Error::Msg(_) => op = Some(UpdateNodeOperation::Failure),
            Error::InternalError(_) => {}
            Error::RpcTimeout => {}
            Error::RpcCancelledByDisconnection => {}
            Error::UnexpectedMessage(_) => {
                op = Some(UpdateNodeOperation::Remove)
            }
            Error::NotSupported(_) => disconnect = false,
        }
        if warn {
            warn!(
                "Error while handling message, peer={}, msgid={:?}, error={}",
                peer, msg_id, error_reason
            );
        } else {
            debug!(
                "Minor error while handling message, peer={}, msgid={:?}, error={}",
                peer, msg_id, error_reason
            );
        }
        if disconnect {
            io.disconnect_peer(peer, op, reason.as_str());
        }
    }

    /// Resolves the sender's peer hash, builds a `Context` and hands `msg` to
    /// `handle_serialized_message`.
    ///
    /// Returns `Error::UnknownPeer` for the default (all-zero) node id or for
    /// a peer that is not registered. Messages with an unknown id disconnect
    /// the peer with `UpdateNodeOperation::Remove`.
    fn dispatch_message(
        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId,
        msg: &[u8],
    ) -> Result<(), Error> {
        trace!("Dispatching message: peer={:?}, msg_id={:?}", peer, msg_id);
        let peer_hash = if io.is_peer_self(peer) {
            self.own_node_hash
        } else {
            if *peer == NodeId::default() {
                return Err(Error::UnknownPeer);
            }
            let peer_hash = keccak(peer);
            if !self.peers.contains(&peer_hash) {
                return Err(Error::UnknownPeer);
            }
            peer_hash
        };
        let ctx = Context {
            peer_hash,
            peer: *peer,
            io,
            manager: self,
        };
        if !handle_serialized_message(msg_id, &ctx, msg)? {
            warn!("Unknown message: peer={:?} msgid={:?}", peer, msg_id);
            let reason =
                format!("unknown sync protocol message id {:?}", msg_id);
            io.disconnect_peer(
                peer,
                Some(UpdateNodeOperation::Remove),
                reason.as_str(),
            );
        }
        Ok(())
    }
}
/// Decodes and handles `msg` according to its message id.
///
/// Returns `Ok(true)` when the id is known and the handler succeeded, and
/// `Ok(false)` when the id is not recognized (the caller decides what to do).
/// Decoding or handler failures are propagated as `Err`.
pub fn handle_serialized_message(
    id: MsgId, ctx: &Context, msg: &[u8],
) -> Result<bool, Error> {
    let outcome = match id {
        msgid::PROPOSAL => handle_message::<ProposalMsg>(ctx, msg),
        msgid::VOTE => handle_message::<VoteMsg>(ctx, msg),
        msgid::SYNC_INFO => handle_message::<SyncInfo>(ctx, msg),
        msgid::BLOCK_RETRIEVAL => {
            handle_message::<BlockRetrievalRpcRequest>(ctx, msg)
        }
        msgid::BLOCK_RETRIEVAL_RESPONSE => {
            handle_message::<BlockRetrievalRpcResponse>(ctx, msg)
        }
        msgid::EPOCH_RETRIEVAL => {
            handle_message::<EpochRetrievalRequest>(ctx, msg)
        }
        msgid::EPOCH_CHANGE => handle_message::<EpochChangeProof>(ctx, msg),
        msgid::CONSENSUS_MSG => handle_message::<ConsensusMsg>(ctx, msg),
        msgid::MEMPOOL_SYNC_MSG => handle_message::<MempoolSyncMsg>(ctx, msg),
        // Unrecognized id: report "not handled" rather than an error.
        _ => return Ok(false),
    };
    outcome.map(|()| true)
}
/// Deserializes `msg` as BCS-encoded `M` and runs its handler.
///
/// A BCS decoding failure is converted into `Error` via `From<bcs::Error>`.
/// Handler failures are logged at `info` level and returned unchanged.
fn handle_message<'a, M>(ctx: &Context, msg: &'a [u8]) -> Result<(), Error>
where M: Deserialize<'a> + Handleable + Message {
    let decoded: M = bcs::from_bytes(msg)?;
    // Capture metadata up front: `handle` consumes the message.
    let (msg_id, msg_name, req_id) = (
        decoded.msg_id(),
        decoded.msg_name(),
        decoded.get_request_id(),
    );
    trace!(
        "handle sync protocol message, peer = {:?}, id = {}, name = {}, request_id = {:?}",
        ctx.peer_hash, msg_id, msg_name, req_id,
    );
    decoded.handle(ctx).map_err(|e| {
        info!(
            "failed to handle sync protocol message, peer = {}, id = {}, name = {}, request_id = {:?}, error_kind = {:?}",
            ctx.peer, msg_id, msg_name, req_id, e,
        );
        e
    })
}
impl NetworkProtocolHandler for HotStuffSynchronizationProtocol {
fn minimum_supported_version(&self) -> ProtocolVersion {
ProtocolVersion(0)
}
fn initialize(&self, io: &dyn NetworkContext) {
io.register_timer(
CHECK_RPC_REQUEST_TIMER,
self.protocol_config.check_request_period,
)
.expect("Error registering check rpc request timer");
}
fn on_message(&self, io: &dyn NetworkContext, peer: &NodeId, raw: &[u8]) {
let len = raw.len();
if len < 2 {
return self.handle_error(
io,
peer,
msgid::INVALID,
Error::InvalidMessageFormat.into(),
);
}
let msg_id = raw[len - 1];
debug!("on_message: peer={:?}, msgid={:?}", peer, msg_id);
let msg = &raw[0..raw.len() - 1];
self.dispatch_message(io, peer, msg_id.into(), msg)
.unwrap_or_else(|e| self.handle_error(io, peer, msg_id.into(), e));
}
fn on_peer_connected(
&self, io: &dyn NetworkContext, node_id: &NodeId,
_peer_protocol_version: ProtocolVersion,
pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
) {
let new_originated = io.get_peer_connection_origin(node_id);
if new_originated.is_none() {
debug!("Peer does not exist when just connected");
return;
}
let new_originated = new_originated.unwrap();
let peer_hash = keccak(node_id);
let add_new_peer = if let Some(old_peer) = self.peers.remove(&peer_hash)
{
let old_peer_id = &old_peer.read().id;
let old_originated = io.get_peer_connection_origin(old_peer_id);
if old_originated.is_none() {
debug!("Old session does not exist.");
true
} else {
let old_originated = old_originated.unwrap();
if Self::simultaneous_dial_tie_breaking(
self.own_node_hash.clone(),
peer_hash.clone(),
old_originated,
new_originated,
) {
io.disconnect_peer(
old_peer_id,
Some(UpdateNodeOperation::Failure),
"remove old peer connection",
);
true
} else {
false
}
}
} else {
true
};
if add_new_peer {
self.peers.insert(peer_hash.clone(), *node_id, None);
if let Some(state) = self.peers.get(&peer_hash) {
let mut state = state.write();
state.id = *node_id;
state.peer_hash = peer_hash;
self.request_manager.on_peer_connected(node_id);
} else {
warn!(
"PeerState is missing for peer: peer_hash={:?}",
peer_hash
);
}
} else {
io.disconnect_peer(
node_id,
Some(UpdateNodeOperation::Failure),
"remove new peer connection",
);
}
if let Some(public_key) = pos_public_key {
self.pos_peer_mapping.write().insert(
from_consensus_public_key(&public_key.0, &public_key.1),
peer_hash,
);
if add_new_peer {
let event = NetworkEvent::PeerConnected;
if let Err(e) = self
.mempool_network_task
.network_events_tx
.push((*node_id, discriminant(&event)), (*node_id, event))
{
warn!("error sending PeerConnected: e={:?}", e);
}
}
if let Some(state) = self.peers.get(&peer_hash) {
state.write().set_pos_public_key(Some(public_key));
} else {
warn!(
"PeerState is missing for peer: peer_hash={:?}",
peer_hash
);
}
} else {
info!(
"pos public key is not provided for peer peer_hash={:?}",
peer_hash
);
}
debug!(
"hsb on_peer_connected: peer {:?}, peer_hash {:?}, peer count {}",
node_id,
peer_hash,
self.peers.len()
);
}
fn on_peer_disconnected(&self, io: &dyn NetworkContext, peer: &NodeId) {
let peer_hash = keccak(*peer);
if let Some(peer_state) = self.peers.remove(&peer_hash) {
if let Some(pos_public_key) = &peer_state.read().pos_public_key {
self.pos_peer_mapping.write().remove(
&from_consensus_public_key(
&pos_public_key.0,
&pos_public_key.1,
),
);
}
}
let event = NetworkEvent::PeerDisconnected;
if let Err(e) = self
.mempool_network_task
.network_events_tx
.push((*peer, discriminant(&event)), (*peer, event))
{
warn!("error sending PeerDisconnected: e={:?}", e);
}
self.request_manager.on_peer_disconnected(io, peer);
debug!(
"hsb on_peer_disconnected: peer={}, peer count {}",
peer,
self.peers.len()
);
}
fn on_timeout(&self, io: &dyn NetworkContext, timer: TimerToken) {
trace!("hsb protocol timeout: timer={:?}", timer);
match timer {
CHECK_RPC_REQUEST_TIMER => {
self.remove_expired_flying_request(io);
}
_ => warn!("hsb protocol: unknown timer {} triggered.", timer),
}
}
fn send_local_message(&self, _io: &dyn NetworkContext, _message: Vec<u8>) {
todo!()
}
fn on_work_dispatch(&self, _io: &dyn NetworkContext, _work_type: u8) {
todo!()
}
}
pub trait Handleable {
fn handle(self, ctx: &Context) -> Result<(), Error>;
}
pub trait RpcResponse: Send + Sync + Debug + AsAny {}
impl From<bcs::Error> for Error {
fn from(_: bcs::Error) -> Self { Error::InvalidMessageFormat.into() }
}
impl From<anyhow::Error> for Error {
fn from(error: anyhow::Error) -> Self {
Error::InternalError(format!("{}", error)).into()
}
}