1use std::{collections::HashMap, fmt::Debug, mem::discriminant, sync::Arc};
6
7use keccak_hash::keccak;
8use parking_lot::RwLock;
9use serde::Deserialize;
10
11use cfx_types::H256;
12use consensus_types::{
13 epoch_retrieval::EpochRetrievalRequest, proposal_msg::ProposalMsg,
14 sync_info::SyncInfo, vote_msg::VoteMsg,
15};
16use diem_types::{
17 account_address::{from_consensus_public_key, AccountAddress},
18 epoch_change::EpochChangeProof,
19 validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey},
20};
21use network::{
22 node_table::NodeId, service::ProtocolVersion, NetworkContext,
23 NetworkProtocolHandler, NetworkService, UpdateNodeOperation,
24};
25
26use crate::{
27 message::{Message, MsgId},
28 pos::{
29 consensus::network::{
30 ConsensusMsg, NetworkTask as ConsensusNetworkTask,
31 },
32 mempool::network::{MempoolSyncMsg, NetworkTask as MempoolNetworkTask},
33 protocol::{
34 message::{
35 block_retrieval::BlockRetrievalRpcRequest,
36 block_retrieval_response::BlockRetrievalRpcResponse, msgid,
37 },
38 network_event::NetworkEvent,
39 request_manager::{
40 request_handler::AsAny, RequestManager, RequestMessage,
41 },
42 },
43 },
44 sync::{Error, ProtocolConfiguration, CHECK_RPC_REQUEST_TIMER},
45};
46
47use super::{HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION};
48
/// Identifier of a timer registered with the network layer; it is handed back
/// to `on_timeout` when the timer fires.
type TimerToken = usize;
50
/// Per-peer bookkeeping kept by the HotStuff synchronization protocol.
#[derive(Default)]
pub struct PeerState {
    /// Network-layer identifier of the peer.
    id: NodeId,
    /// Key under which this peer is stored in `Peers`; callers in this file
    /// pass the `keccak` hash of the node id.
    peer_hash: H256,
    /// PoS consensus and VRF public keys of the peer, when known.
    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
}
58
59impl PeerState {
60 pub fn new(
61 id: NodeId, peer_hash: H256,
62 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
63 ) -> Self {
64 Self {
65 id,
66 peer_hash,
67 pos_public_key,
68 }
69 }
70
71 pub fn set_pos_public_key(
72 &mut self,
73 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
74 ) {
75 self.pos_public_key = pos_public_key
76 }
77
78 pub fn get_id(&self) -> NodeId { self.id }
79}
80
/// Shared table of connected peers, keyed by peer hash.
///
/// The outer lock guards the map itself; every `PeerState` has its own lock
/// so a single peer can be updated while the map is only read-locked.
#[derive(Default)]
pub struct Peers(RwLock<HashMap<H256, Arc<RwLock<PeerState>>>>);
83
84impl Peers {
85 pub fn new() -> Peers { Self::default() }
86
87 pub fn get(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
88 self.0.read().get(peer).cloned()
89 }
90
91 pub fn insert(
92 &self, peer: H256, id: NodeId,
93 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
94 ) {
95 self.0.write().entry(peer).or_insert(Arc::new(RwLock::new(
96 PeerState::new(id, peer, pos_public_key),
97 )));
98 }
99
100 pub fn len(&self) -> usize { self.0.read().len() }
101
102 pub fn is_empty(&self) -> bool { self.0.read().is_empty() }
103
104 pub fn contains(&self, peer: &H256) -> bool {
105 self.0.read().contains_key(peer)
106 }
107
108 pub fn remove(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
109 self.0.write().remove(peer)
110 }
111
112 pub fn all_peers_satisfying<F>(&self, mut predicate: F) -> Vec<H256>
113 where F: FnMut(&mut PeerState) -> bool {
114 self.0
115 .read()
116 .iter()
117 .filter_map(|(id, state)| {
118 if predicate(&mut *state.write()) {
119 Some(*id)
120 } else {
121 None
122 }
123 })
124 .collect()
125 }
126
127 pub fn fold<B, F>(&self, init: B, f: F) -> B
128 where F: FnMut(B, &Arc<RwLock<PeerState>>) -> B {
129 self.0.write().values().fold(init, f)
130 }
131}
132
/// Everything a message handler needs while processing one message from one
/// peer.
pub struct Context<'a> {
    /// Network context used to send messages and manage the connection.
    pub io: &'a dyn NetworkContext,
    /// Node id of the peer the message came from.
    pub peer: NodeId,
    /// Hash identifying the peer in the protocol's peer table (the local node
    /// hash when the message originates from this node itself).
    pub peer_hash: H256,
    /// The protocol handler that is dispatching the message.
    pub manager: &'a HotStuffSynchronizationProtocol,
}
139
140impl<'a> Context<'a> {
141 pub fn match_request(
142 &self, request_id: u64,
143 ) -> Result<RequestMessage, Error> {
144 self.manager
145 .request_manager
146 .match_request(self.io, &self.peer, request_id)
147 }
148
149 pub fn send_response(&self, response: &dyn Message) -> Result<(), Error> {
150 response.send(self.io, &self.peer)?;
151 Ok(())
152 }
153
154 pub fn get_peer_account_address(&self) -> Result<AccountAddress, Error> {
155 let k = self.get_pos_public_key().ok_or(Error::UnknownPeer)?;
156 Ok(from_consensus_public_key(&k.0, &k.1))
157 }
158
159 fn get_pos_public_key(
160 &self,
161 ) -> Option<(ConsensusPublicKey, ConsensusVRFPublicKey)> {
162 self.manager
163 .peers
164 .get(&self.peer_hash)
165 .as_ref()?
166 .read()
167 .pos_public_key
168 .clone()
169 }
170}
171
/// Network protocol handler for the HotStuff (PoS) synchronization protocol.
pub struct HotStuffSynchronizationProtocol {
    /// Protocol settings; `check_request_period` drives the RPC request timer.
    pub protocol_config: ProtocolConfiguration,
    /// Hash identifying the local node; used for messages from ourselves and
    /// for connection tie-breaking.
    pub own_node_hash: H256,
    /// Table of currently connected peers.
    pub peers: Arc<Peers>,
    /// Tracks outstanding requests, their timeouts and resends.
    pub request_manager: Arc<RequestManager>,
    /// Handle to the consensus network task (not used directly in this file's
    /// visible code — presumably used by message handlers; verify).
    pub consensus_network_task: ConsensusNetworkTask,
    /// Handle to the mempool network task; receives peer connect/disconnect
    /// events.
    pub mempool_network_task: MempoolNetworkTask,
    /// Maps a PoS account address (derived from a peer's public keys) to the
    /// hash of the peer that announced those keys.
    pub pos_peer_mapping: RwLock<HashMap<AccountAddress, H256>>,
}
181
182impl HotStuffSynchronizationProtocol {
183 pub fn new(
184 own_node_hash: H256, consensus_network_task: ConsensusNetworkTask,
185 mempool_network_task: MempoolNetworkTask,
186 protocol_config: ProtocolConfiguration,
187 ) -> Self {
188 let request_manager = Arc::new(RequestManager::new(&protocol_config));
189 HotStuffSynchronizationProtocol {
190 protocol_config,
191 own_node_hash,
192 peers: Arc::new(Peers::new()),
193 request_manager,
194 consensus_network_task,
195 mempool_network_task,
196 pos_peer_mapping: RwLock::new(Default::default()),
197 }
198 }
199
200 pub fn with_peers(
201 protocol_config: ProtocolConfiguration, own_node_hash: H256,
202 consensus_network_task: ConsensusNetworkTask,
203 mempool_network_task: MempoolNetworkTask, peers: Arc<Peers>,
204 ) -> Self {
205 let request_manager = Arc::new(RequestManager::new(&protocol_config));
206 HotStuffSynchronizationProtocol {
207 protocol_config,
208 own_node_hash,
209 peers,
210 request_manager,
211 consensus_network_task,
212 mempool_network_task,
213 pos_peer_mapping: RwLock::new(Default::default()),
214 }
215 }
216
217 pub fn register(
218 self: Arc<Self>, network: Arc<NetworkService>,
219 ) -> Result<(), String> {
220 network
221 .register_protocol(self, HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION)
222 .map_err(|e| {
223 format!(
224 "failed to register HotStuffSynchronizationProtocol: {:?}",
225 e
226 )
227 })
228 }
229
230 pub fn remove_expired_flying_request(&self, io: &dyn NetworkContext) {
231 self.request_manager.process_timeout_requests(io);
232 self.request_manager.resend_waiting_requests(io);
233 }
234
235 fn simultaneous_dial_tie_breaking(
245 own_peer_id: H256, remote_peer_id: H256, existing_origin: bool,
246 new_origin: bool,
247 ) -> bool {
248 match (existing_origin, new_origin) {
249 (false , false ) => true,
252 (false , true ) => {
253 remote_peer_id < own_peer_id
254 }
255 (true , false ) => {
256 own_peer_id < remote_peer_id
257 }
258 (true , true ) => false,
261 }
262 }
263
264 fn handle_error(
265 &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId, e: Error,
266 ) {
267 let mut disconnect = true;
268 let mut warn = false;
269 let reason = format!("{}", e);
270 let error_reason = format!("{:?}", e);
271 let mut op = None;
272
273 match e {
276 Error::InvalidBlock => op = Some(UpdateNodeOperation::Demotion),
277 Error::InvalidGetBlockTxn(_) => {
278 op = Some(UpdateNodeOperation::Demotion)
279 }
280 Error::InvalidStatus(_) => op = Some(UpdateNodeOperation::Failure),
281 Error::InvalidMessageFormat => {
282 op = Some(UpdateNodeOperation::Remove)
283 }
284 Error::UnknownPeer => {
285 warn = false;
286 op = Some(UpdateNodeOperation::Failure)
287 }
288 Error::UnexpectedResponse => disconnect = true,
291 Error::RequestNotFound => {
292 warn = false;
293 disconnect = false;
294 }
295 Error::InCatchUpMode(_) => {
296 disconnect = false;
297 warn = false;
298 }
299 Error::TooManyTrans => {}
300 Error::InvalidTimestamp => op = Some(UpdateNodeOperation::Demotion),
301 Error::InvalidSnapshotManifest(_) => {
302 op = Some(UpdateNodeOperation::Demotion)
303 }
304 Error::InvalidSnapshotChunk(_) => {
305 op = Some(UpdateNodeOperation::Demotion)
306 }
307 Error::AlreadyThrottled(_) => {
308 op = Some(UpdateNodeOperation::Remove)
309 }
310 Error::EmptySnapshotChunk => disconnect = false,
311 Error::Throttled(_, msg) => {
312 disconnect = false;
313
314 if let Err(e) = msg.send(io, peer) {
315 error!("failed to send throttled packet: {:?}", e);
316 disconnect = true;
317 }
318 }
319 Error::Decoder(_) => op = Some(UpdateNodeOperation::Remove),
320 Error::Io(_) => disconnect = false,
321 Error::Network(kind) => match kind {
322 network::Error::AddressParse => disconnect = false,
323 network::Error::AddressResolve(_) => disconnect = false,
324 network::Error::Auth => disconnect = false,
325 network::Error::BadProtocol => {
326 op = Some(UpdateNodeOperation::Remove)
327 }
328 network::Error::BadAddr => disconnect = false,
329 network::Error::Decoder(_) => {
330 op = Some(UpdateNodeOperation::Remove)
331 }
332 network::Error::Expired => disconnect = false,
333 network::Error::Disconnect(_) => disconnect = false,
334 network::Error::InvalidNodeId => disconnect = false,
335 network::Error::OversizedPacket => disconnect = false,
336 network::Error::Io(_) => disconnect = false,
337 network::Error::Throttling(_) => disconnect = false,
338 network::Error::SocketIo(_) => {
339 op = Some(UpdateNodeOperation::Failure)
340 }
341 network::Error::Msg(_) => {
342 op = Some(UpdateNodeOperation::Failure)
343 }
344 network::Error::MessageDeprecated { .. } => {
345 op = Some(UpdateNodeOperation::Failure)
346 }
347 network::Error::SendUnsupportedMessage { .. } => {
348 op = Some(UpdateNodeOperation::Failure)
349 }
350 },
351 Error::Storage(_) => {}
352 Error::Msg(_) => op = Some(UpdateNodeOperation::Failure),
353 Error::InternalError(_) => {}
357 Error::RpcTimeout => {}
358 Error::RpcCancelledByDisconnection => {}
359 Error::UnexpectedMessage(_) => {
360 op = Some(UpdateNodeOperation::Remove)
361 }
362 Error::NotSupported(_) => disconnect = false,
363 }
364
365 if warn {
366 warn!(
367 "Error while handling message, peer={}, msgid={:?}, error={}",
368 peer, msg_id, error_reason
369 );
370 } else {
371 debug!(
372 "Minor error while handling message, peer={}, msgid={:?}, error={}",
373 peer, msg_id, error_reason
374 );
375 }
376
377 if disconnect {
378 io.disconnect_peer(peer, op, reason.as_str());
379 }
380 }
381
382 fn dispatch_message(
383 &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId,
384 msg: &[u8],
385 ) -> Result<(), Error> {
386 trace!("Dispatching message: peer={:?}, msg_id={:?}", peer, msg_id);
387 let peer_hash = if !io.is_peer_self(peer) {
388 if *peer == NodeId::default() {
389 return Err(Error::UnknownPeer.into());
390 }
391 let peer_hash = keccak(peer);
392 if !self.peers.contains(&peer_hash) {
393 return Err(Error::UnknownPeer.into());
394 }
395 peer_hash
396 } else {
397 self.own_node_hash.clone()
398 };
399
400 let ctx = Context {
401 peer_hash,
402 peer: *peer,
403 io,
404 manager: self,
405 };
406
407 if !handle_serialized_message(msg_id, &ctx, msg)? {
408 warn!("Unknown message: peer={:?} msgid={:?}", peer, msg_id);
409 let reason =
410 format!("unknown sync protocol message id {:?}", msg_id);
411 io.disconnect_peer(
412 peer,
413 Some(UpdateNodeOperation::Remove),
414 reason.as_str(),
415 );
416 }
417
418 Ok(())
419 }
420}
421
422pub fn handle_serialized_message(
423 id: MsgId, ctx: &Context, msg: &[u8],
424) -> Result<bool, Error> {
425 match id {
426 msgid::PROPOSAL => handle_message::<ProposalMsg>(ctx, msg)?,
427 msgid::VOTE => handle_message::<VoteMsg>(ctx, msg)?,
428 msgid::SYNC_INFO => handle_message::<SyncInfo>(ctx, msg)?,
429 msgid::BLOCK_RETRIEVAL => {
430 handle_message::<BlockRetrievalRpcRequest>(ctx, msg)?
431 }
432 msgid::BLOCK_RETRIEVAL_RESPONSE => {
433 handle_message::<BlockRetrievalRpcResponse>(ctx, msg)?
434 }
435 msgid::EPOCH_RETRIEVAL => {
436 handle_message::<EpochRetrievalRequest>(ctx, msg)?
437 }
438 msgid::EPOCH_CHANGE => handle_message::<EpochChangeProof>(ctx, msg)?,
439 msgid::CONSENSUS_MSG => handle_message::<ConsensusMsg>(ctx, msg)?,
440 msgid::MEMPOOL_SYNC_MSG => handle_message::<MempoolSyncMsg>(ctx, msg)?,
441 _ => return Ok(false),
442 }
443 Ok(true)
444}
445
446fn handle_message<'a, M>(ctx: &Context, msg: &'a [u8]) -> Result<(), Error>
447where M: Deserialize<'a> + Handleable + Message {
448 let msg: M = bcs::from_bytes(msg)?;
449 let msg_id = msg.msg_id();
450 let msg_name = msg.msg_name();
451 let req_id = msg.get_request_id();
452
453 trace!(
454 "handle sync protocol message, peer = {:?}, id = {}, name = {}, request_id = {:?}",
455 ctx.peer_hash, msg_id, msg_name, req_id,
456 );
457
458 if let Err(e) = msg.handle(ctx) {
461 info!(
462 "failed to handle sync protocol message, peer = {}, id = {}, name = {}, request_id = {:?}, error_kind = {:?}",
463 ctx.peer, msg_id, msg_name, req_id, e,
464 );
465
466 return Err(e);
467 }
468
469 Ok(())
470}
471
472impl NetworkProtocolHandler for HotStuffSynchronizationProtocol {
    /// Lowest protocol version this handler accepts from peers: version 0.
    fn minimum_supported_version(&self) -> ProtocolVersion {
        ProtocolVersion(0)
    }
476
477 fn initialize(&self, io: &dyn NetworkContext) {
478 io.register_timer(
479 CHECK_RPC_REQUEST_TIMER,
480 self.protocol_config.check_request_period,
481 )
482 .expect("Error registering check rpc request timer");
483 }
484
485 fn on_message(&self, io: &dyn NetworkContext, peer: &NodeId, raw: &[u8]) {
486 let len = raw.len();
487 if len < 2 {
488 return self.handle_error(
490 io,
491 peer,
492 msgid::INVALID,
493 Error::InvalidMessageFormat.into(),
494 );
495 }
496
497 let msg_id = raw[len - 1];
498 debug!("on_message: peer={:?}, msgid={:?}", peer, msg_id);
499
500 let msg = &raw[0..raw.len() - 1];
501 self.dispatch_message(io, peer, msg_id.into(), msg)
502 .unwrap_or_else(|e| self.handle_error(io, peer, msg_id.into(), e));
503 }
504
505 fn on_peer_connected(
506 &self, io: &dyn NetworkContext, node_id: &NodeId,
507 _peer_protocol_version: ProtocolVersion,
508 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
509 ) {
510 let new_originated = io.get_peer_connection_origin(node_id);
512 if new_originated.is_none() {
513 debug!("Peer does not exist when just connected");
514 return;
515 }
516 let new_originated = new_originated.unwrap();
517 let peer_hash = keccak(node_id);
518
519 let add_new_peer = if let Some(old_peer) = self.peers.remove(&peer_hash)
520 {
521 let old_peer_id = &old_peer.read().id;
522 let old_originated = io.get_peer_connection_origin(old_peer_id);
523 if old_originated.is_none() {
524 debug!("Old session does not exist.");
525 true
526 } else {
527 let old_originated = old_originated.unwrap();
528 if Self::simultaneous_dial_tie_breaking(
529 self.own_node_hash.clone(),
530 peer_hash.clone(),
531 old_originated,
532 new_originated,
533 ) {
534 io.disconnect_peer(
537 old_peer_id,
538 Some(UpdateNodeOperation::Failure),
539 "remove old peer connection",
540 );
541 true
542 } else {
543 false
545 }
546 }
547 } else {
548 true
549 };
550
551 if add_new_peer {
552 self.peers.insert(peer_hash.clone(), *node_id, None);
553 if let Some(state) = self.peers.get(&peer_hash) {
554 let mut state = state.write();
555 state.id = *node_id;
556 state.peer_hash = peer_hash;
557 self.request_manager.on_peer_connected(node_id);
558 } else {
559 warn!(
560 "PeerState is missing for peer: peer_hash={:?}",
561 peer_hash
562 );
563 }
564 } else {
565 io.disconnect_peer(
566 node_id,
567 Some(UpdateNodeOperation::Failure),
568 "remove new peer connection",
569 );
570 }
571
572 if let Some(public_key) = pos_public_key {
573 self.pos_peer_mapping.write().insert(
574 from_consensus_public_key(&public_key.0, &public_key.1),
575 peer_hash,
576 );
577 if add_new_peer {
578 let event = NetworkEvent::PeerConnected;
579 if let Err(e) = self
580 .mempool_network_task
581 .network_events_tx
582 .push((*node_id, discriminant(&event)), (*node_id, event))
583 {
584 warn!("error sending PeerConnected: e={:?}", e);
585 }
586 }
587 if let Some(state) = self.peers.get(&peer_hash) {
588 state.write().set_pos_public_key(Some(public_key));
589 } else {
590 warn!(
591 "PeerState is missing for peer: peer_hash={:?}",
592 peer_hash
593 );
594 }
595 } else {
596 info!(
597 "pos public key is not provided for peer peer_hash={:?}",
598 peer_hash
599 );
600 }
601
602 debug!(
603 "hsb on_peer_connected: peer {:?}, peer_hash {:?}, peer count {}",
604 node_id,
605 peer_hash,
606 self.peers.len()
607 );
608 }
609
610 fn on_peer_disconnected(&self, io: &dyn NetworkContext, peer: &NodeId) {
611 let peer_hash = keccak(*peer);
612 if let Some(peer_state) = self.peers.remove(&peer_hash) {
613 if let Some(pos_public_key) = &peer_state.read().pos_public_key {
614 self.pos_peer_mapping.write().remove(
615 &from_consensus_public_key(
616 &pos_public_key.0,
617 &pos_public_key.1,
618 ),
619 );
620 }
621 }
622 let event = NetworkEvent::PeerDisconnected;
624 if let Err(e) = self
625 .mempool_network_task
626 .network_events_tx
627 .push((*peer, discriminant(&event)), (*peer, event))
628 {
629 warn!("error sending PeerDisconnected: e={:?}", e);
630 }
631
632 self.request_manager.on_peer_disconnected(io, peer);
633 debug!(
634 "hsb on_peer_disconnected: peer={}, peer count {}",
635 peer,
636 self.peers.len()
637 );
638 }
639
640 fn on_timeout(&self, io: &dyn NetworkContext, timer: TimerToken) {
641 trace!("hsb protocol timeout: timer={:?}", timer);
642 match timer {
643 CHECK_RPC_REQUEST_TIMER => {
644 self.remove_expired_flying_request(io);
645 }
646 _ => warn!("hsb protocol: unknown timer {} triggered.", timer),
647 }
648 }
649
    /// Not implemented for this protocol.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()` if called.
    fn send_local_message(&self, _io: &dyn NetworkContext, _message: Vec<u8>) {
        todo!()
    }
653
    /// Not implemented for this protocol.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()` if called.
    fn on_work_dispatch(&self, _io: &dyn NetworkContext, _work_type: u8) {
        todo!()
    }
657}
658
/// A decoded protocol message that knows how to process itself.
pub trait Handleable {
    /// Consumes the message and processes it in the given context (sender,
    /// network access and protocol handler).
    fn handle(self, ctx: &Context) -> Result<(), Error>;
}
662
/// Marker trait for RPC response types: they must be thread-safe, debuggable
/// and convertible through `AsAny`.
pub trait RpcResponse: Send + Sync + Debug + AsAny {}
664
665impl From<bcs::Error> for Error {
666 fn from(_: bcs::Error) -> Self { Error::InvalidMessageFormat.into() }
667}
668
669impl From<anyhow::Error> for Error {
670 fn from(error: anyhow::Error) -> Self {
671 Error::InternalError(format!("{}", error)).into()
672 }
673}