1use std::{collections::HashMap, fmt::Debug, mem::discriminant, sync::Arc};
6
7use keccak_hash::keccak;
8use parking_lot::RwLock;
9use serde::Deserialize;
10
11use cfx_types::H256;
12use consensus_types::{
13 epoch_retrieval::EpochRetrievalRequest, proposal_msg::ProposalMsg,
14 sync_info::SyncInfo, vote_msg::VoteMsg,
15};
16use diem_types::{
17 account_address::{from_consensus_public_key, AccountAddress},
18 epoch_change::EpochChangeProof,
19 validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey},
20};
21use network::{
22 node_table::NodeId, service::ProtocolVersion, NetworkContext,
23 NetworkProtocolHandler, NetworkService, UpdateNodeOperation,
24};
25
26use crate::{
27 message::{Message, MsgId},
28 pos::{
29 consensus::network::{
30 ConsensusMsg, NetworkTask as ConsensusNetworkTask,
31 },
32 mempool::network::{MempoolSyncMsg, NetworkTask as MempoolNetworkTask},
33 protocol::{
34 message::{
35 block_retrieval::BlockRetrievalRpcRequest,
36 block_retrieval_response::BlockRetrievalRpcResponse, msgid,
37 },
38 network_event::NetworkEvent,
39 request_manager::{
40 request_handler::AsAny, RequestManager, RequestMessage,
41 },
42 },
43 },
44 sync::{Error, ProtocolConfiguration, CHECK_RPC_REQUEST_TIMER},
45};
46
47use super::{HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION};
48
/// Identifier of a timer registered through `NetworkContext::register_timer`
/// and reported back to `on_timeout`.
type TimerToken = usize;
50
51#[derive(Default)]
52pub struct PeerState {
53 id: NodeId,
54 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
56}
57
58impl PeerState {
59 pub fn new(
60 id: NodeId,
61 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
62 ) -> Self {
63 Self { id, pos_public_key }
64 }
65
66 pub fn set_pos_public_key(
67 &mut self,
68 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
69 ) {
70 self.pos_public_key = pos_public_key
71 }
72
73 pub fn get_id(&self) -> NodeId { self.id }
74}
75
76#[derive(Default)]
77pub struct Peers(RwLock<HashMap<H256, Arc<RwLock<PeerState>>>>);
78
79impl Peers {
80 pub fn new() -> Peers { Self::default() }
81
82 pub fn get(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
83 self.0.read().get(peer).cloned()
84 }
85
86 pub fn insert(
87 &self, peer: H256, id: NodeId,
88 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
89 ) {
90 self.0.write().insert(
91 peer,
92 Arc::new(RwLock::new(PeerState::new(id, pos_public_key))),
93 );
94 }
95
96 pub fn len(&self) -> usize { self.0.read().len() }
97
98 pub fn is_empty(&self) -> bool { self.0.read().is_empty() }
99
100 pub fn contains(&self, peer: &H256) -> bool {
101 self.0.read().contains_key(peer)
102 }
103
104 pub fn remove(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
105 self.0.write().remove(peer)
106 }
107
108 pub fn all_peers_satisfying<F>(&self, mut predicate: F) -> Vec<H256>
109 where F: FnMut(&mut PeerState) -> bool {
110 self.0
111 .read()
112 .iter()
113 .filter_map(|(id, state)| {
114 if predicate(&mut *state.write()) {
115 Some(*id)
116 } else {
117 None
118 }
119 })
120 .collect()
121 }
122
123 pub fn fold<B, F>(&self, init: B, f: F) -> B
124 where F: FnMut(B, &Arc<RwLock<PeerState>>) -> B {
125 self.0.write().values().fold(init, f)
126 }
127}
128
129pub struct Context<'a> {
130 pub io: &'a dyn NetworkContext,
131 pub peer: NodeId,
132 pub peer_hash: H256,
133 pub manager: &'a HotStuffSynchronizationProtocol,
134}
135
136impl<'a> Context<'a> {
137 pub fn match_request(
138 &self, request_id: u64,
139 ) -> Result<RequestMessage, Error> {
140 self.manager
141 .request_manager
142 .match_request(self.io, &self.peer, request_id)
143 }
144
145 pub fn send_response(&self, response: &dyn Message) -> Result<(), Error> {
146 response.send(self.io, &self.peer)?;
147 Ok(())
148 }
149
150 pub fn get_peer_account_address(&self) -> Result<AccountAddress, Error> {
151 let k = self.get_pos_public_key().ok_or(Error::UnknownPeer)?;
152 Ok(from_consensus_public_key(&k.0, &k.1))
153 }
154
155 fn get_pos_public_key(
156 &self,
157 ) -> Option<(ConsensusPublicKey, ConsensusVRFPublicKey)> {
158 self.manager
159 .peers
160 .get(&self.peer_hash)
161 .as_ref()?
162 .read()
163 .pos_public_key
164 .clone()
165 }
166}
167
168pub struct HotStuffSynchronizationProtocol {
169 pub protocol_config: ProtocolConfiguration,
170 pub own_node_hash: H256,
171 pub peers: Arc<Peers>,
172 pub request_manager: Arc<RequestManager>,
173 pub consensus_network_task: ConsensusNetworkTask,
174 pub mempool_network_task: MempoolNetworkTask,
175 pub pos_peer_mapping: RwLock<HashMap<AccountAddress, H256>>,
176}
177
178impl HotStuffSynchronizationProtocol {
179 pub fn new(
180 own_node_hash: H256, consensus_network_task: ConsensusNetworkTask,
181 mempool_network_task: MempoolNetworkTask,
182 protocol_config: ProtocolConfiguration,
183 ) -> Self {
184 let request_manager = Arc::new(RequestManager::new(&protocol_config));
185 HotStuffSynchronizationProtocol {
186 protocol_config,
187 own_node_hash,
188 peers: Arc::new(Peers::new()),
189 request_manager,
190 consensus_network_task,
191 mempool_network_task,
192 pos_peer_mapping: RwLock::new(Default::default()),
193 }
194 }
195
196 pub fn with_peers(
197 protocol_config: ProtocolConfiguration, own_node_hash: H256,
198 consensus_network_task: ConsensusNetworkTask,
199 mempool_network_task: MempoolNetworkTask, peers: Arc<Peers>,
200 ) -> Self {
201 let request_manager = Arc::new(RequestManager::new(&protocol_config));
202 HotStuffSynchronizationProtocol {
203 protocol_config,
204 own_node_hash,
205 peers,
206 request_manager,
207 consensus_network_task,
208 mempool_network_task,
209 pos_peer_mapping: RwLock::new(Default::default()),
210 }
211 }
212
213 pub fn register(
214 self: Arc<Self>, network: Arc<NetworkService>,
215 ) -> Result<(), String> {
216 network
217 .register_protocol(self, HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION)
218 .map_err(|e| {
219 format!(
220 "failed to register HotStuffSynchronizationProtocol: {:?}",
221 e
222 )
223 })
224 }
225
226 pub fn remove_expired_flying_request(&self, io: &dyn NetworkContext) {
227 self.request_manager.process_timeout_requests(io);
228 self.request_manager.resend_waiting_requests(io);
229 }
230
231 fn handle_error(
232 &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId, e: Error,
233 ) {
234 let mut disconnect = true;
235 let mut warn = false;
236 let reason = format!("{}", e);
237 let error_reason = format!("{:?}", e);
238 let mut op = None;
239
240 match e {
243 Error::InvalidBlock => op = Some(UpdateNodeOperation::Demotion),
244 Error::InvalidGetBlockTxn(_) => {
245 op = Some(UpdateNodeOperation::Demotion)
246 }
247 Error::InvalidStatus(_) => op = Some(UpdateNodeOperation::Failure),
248 Error::InvalidMessageFormat => {
249 op = Some(UpdateNodeOperation::Remove)
250 }
251 Error::UnknownPeer => {
252 warn = false;
253 op = Some(UpdateNodeOperation::Failure)
254 }
255 Error::UnexpectedResponse => disconnect = true,
258 Error::RequestNotFound => {
259 warn = false;
260 disconnect = false;
261 }
262 Error::InCatchUpMode(_) => {
263 disconnect = false;
264 warn = false;
265 }
266 Error::TooManyTrans => {}
267 Error::InvalidTimestamp => op = Some(UpdateNodeOperation::Demotion),
268 Error::InvalidSnapshotManifest(_) => {
269 op = Some(UpdateNodeOperation::Demotion)
270 }
271 Error::InvalidSnapshotChunk(_) => {
272 op = Some(UpdateNodeOperation::Demotion)
273 }
274 Error::AlreadyThrottled(_) => {
275 op = Some(UpdateNodeOperation::Remove)
276 }
277 Error::EmptySnapshotChunk => disconnect = false,
278 Error::Throttled(_, msg) => {
279 disconnect = false;
280
281 if let Err(e) = msg.send(io, peer) {
282 error!("failed to send throttled packet: {:?}", e);
283 disconnect = true;
284 }
285 }
286 Error::Decoder(_) => op = Some(UpdateNodeOperation::Remove),
287 Error::Io(_) => disconnect = false,
288 Error::Network(kind) => match kind {
289 network::Error::AddressParse => disconnect = false,
290 network::Error::AddressResolve(_) => disconnect = false,
291 network::Error::Auth => disconnect = false,
292 network::Error::BadProtocol => {
293 op = Some(UpdateNodeOperation::Remove)
294 }
295 network::Error::BadAddr => disconnect = false,
296 network::Error::Decoder(_) => {
297 op = Some(UpdateNodeOperation::Remove)
298 }
299 network::Error::Expired => disconnect = false,
300 network::Error::Disconnect(_) => disconnect = false,
301 network::Error::InvalidNodeId => disconnect = false,
302 network::Error::OversizedPacket => disconnect = false,
303 network::Error::Io(_) => disconnect = false,
304 network::Error::Throttling(_) => disconnect = false,
305 network::Error::SocketIo(_) => {
306 op = Some(UpdateNodeOperation::Failure)
307 }
308 network::Error::Msg(_) => {
309 op = Some(UpdateNodeOperation::Failure)
310 }
311 network::Error::MessageDeprecated { .. } => {
312 op = Some(UpdateNodeOperation::Failure)
313 }
314 network::Error::SendUnsupportedMessage { .. } => {
315 op = Some(UpdateNodeOperation::Failure)
316 }
317 },
318 Error::Storage(_) => {}
319 Error::Msg(_) => op = Some(UpdateNodeOperation::Failure),
320 Error::InternalError(_) => {}
324 Error::RpcTimeout => {}
325 Error::RpcCancelledByDisconnection => {}
326 Error::UnexpectedMessage(_) => {
327 op = Some(UpdateNodeOperation::Remove)
328 }
329 Error::NotSupported(_) => disconnect = false,
330 }
331
332 if warn {
333 warn!(
334 "Error while handling message, peer={}, msgid={:?}, error={}",
335 peer, msg_id, error_reason
336 );
337 } else {
338 debug!(
339 "Minor error while handling message, peer={}, msgid={:?}, error={}",
340 peer, msg_id, error_reason
341 );
342 }
343
344 if disconnect {
345 io.disconnect_peer(peer, op, reason.as_str());
346 }
347 }
348
349 fn dispatch_message(
350 &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId,
351 msg: &[u8],
352 ) -> Result<(), Error> {
353 trace!("Dispatching message: peer={:?}, msg_id={:?}", peer, msg_id);
354 let peer_hash = if !io.is_peer_self(peer) {
355 if *peer == NodeId::default() {
356 return Err(Error::UnknownPeer.into());
357 }
358 let peer_hash = keccak(peer);
359 if !self.peers.contains(&peer_hash) {
360 return Err(Error::UnknownPeer.into());
361 }
362 peer_hash
363 } else {
364 self.own_node_hash.clone()
365 };
366
367 let ctx = Context {
368 peer_hash,
369 peer: *peer,
370 io,
371 manager: self,
372 };
373
374 if !handle_serialized_message(msg_id, &ctx, msg)? {
375 warn!("Unknown message: peer={:?} msgid={:?}", peer, msg_id);
376 let reason =
377 format!("unknown sync protocol message id {:?}", msg_id);
378 io.disconnect_peer(
379 peer,
380 Some(UpdateNodeOperation::Remove),
381 reason.as_str(),
382 );
383 }
384
385 Ok(())
386 }
387}
388
389pub fn handle_serialized_message(
390 id: MsgId, ctx: &Context, msg: &[u8],
391) -> Result<bool, Error> {
392 match id {
393 msgid::PROPOSAL => handle_message::<ProposalMsg>(ctx, msg)?,
394 msgid::VOTE => handle_message::<VoteMsg>(ctx, msg)?,
395 msgid::SYNC_INFO => handle_message::<SyncInfo>(ctx, msg)?,
396 msgid::BLOCK_RETRIEVAL => {
397 handle_message::<BlockRetrievalRpcRequest>(ctx, msg)?
398 }
399 msgid::BLOCK_RETRIEVAL_RESPONSE => {
400 handle_message::<BlockRetrievalRpcResponse>(ctx, msg)?
401 }
402 msgid::EPOCH_RETRIEVAL => {
403 handle_message::<EpochRetrievalRequest>(ctx, msg)?
404 }
405 msgid::EPOCH_CHANGE => handle_message::<EpochChangeProof>(ctx, msg)?,
406 msgid::CONSENSUS_MSG => handle_message::<ConsensusMsg>(ctx, msg)?,
407 msgid::MEMPOOL_SYNC_MSG => handle_message::<MempoolSyncMsg>(ctx, msg)?,
408 _ => return Ok(false),
409 }
410 Ok(true)
411}
412
413fn handle_message<'a, M>(ctx: &Context, msg: &'a [u8]) -> Result<(), Error>
414where M: Deserialize<'a> + Handleable + Message {
415 let msg: M = bcs::from_bytes(msg)?;
416 let msg_id = msg.msg_id();
417 let msg_name = msg.msg_name();
418 let req_id = msg.get_request_id();
419
420 trace!(
421 "handle sync protocol message, peer = {:?}, id = {}, name = {}, request_id = {:?}",
422 ctx.peer_hash, msg_id, msg_name, req_id,
423 );
424
425 if let Err(e) = msg.handle(ctx) {
428 info!(
429 "failed to handle sync protocol message, peer = {}, id = {}, name = {}, request_id = {:?}, error_kind = {:?}",
430 ctx.peer, msg_id, msg_name, req_id, e,
431 );
432
433 return Err(e);
434 }
435
436 Ok(())
437}
438
439impl NetworkProtocolHandler for HotStuffSynchronizationProtocol {
440 fn minimum_supported_version(&self) -> ProtocolVersion {
441 ProtocolVersion(0)
442 }
443
444 fn initialize(&self, io: &dyn NetworkContext) {
445 io.register_timer(
446 CHECK_RPC_REQUEST_TIMER,
447 self.protocol_config.check_request_period,
448 )
449 .expect("Error registering check rpc request timer");
450 }
451
452 fn on_message(&self, io: &dyn NetworkContext, peer: &NodeId, raw: &[u8]) {
453 let len = raw.len();
454 if len < 2 {
455 return self.handle_error(
457 io,
458 peer,
459 msgid::INVALID,
460 Error::InvalidMessageFormat.into(),
461 );
462 }
463
464 let msg_id = raw[len - 1];
465 debug!("on_message: peer={:?}, msgid={:?}", peer, msg_id);
466
467 let msg = &raw[0..raw.len() - 1];
468 self.dispatch_message(io, peer, msg_id.into(), msg)
469 .unwrap_or_else(|e| self.handle_error(io, peer, msg_id.into(), e));
470 }
471
472 fn on_peer_connected(
473 &self, io: &dyn NetworkContext, node_id: &NodeId,
474 _peer_protocol_version: ProtocolVersion,
475 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
476 ) {
477 if io.get_peer_connection_origin(node_id).is_none() {
478 debug!("Peer does not exist when just connected");
479 return;
480 }
481 let peer_hash = keccak(node_id);
482
483 self.peers
489 .insert(peer_hash, *node_id, pos_public_key.clone());
490 self.request_manager.on_peer_connected(node_id);
491
492 if let Some(public_key) = pos_public_key {
493 self.pos_peer_mapping.write().insert(
494 from_consensus_public_key(&public_key.0, &public_key.1),
495 peer_hash,
496 );
497 let event = NetworkEvent::PeerConnected;
498 if let Err(e) = self
499 .mempool_network_task
500 .network_events_tx
501 .push((*node_id, discriminant(&event)), (*node_id, event))
502 {
503 warn!("error sending PeerConnected: e={:?}", e);
504 }
505 } else {
506 info!(
507 "pos public key is not provided for peer peer_hash={:?}",
508 peer_hash
509 );
510 }
511
512 debug!(
513 "hsb on_peer_connected: peer {:?}, peer_hash {:?}, peer count {}",
514 node_id,
515 peer_hash,
516 self.peers.len()
517 );
518 }
519
520 fn on_peer_disconnected(&self, io: &dyn NetworkContext, peer: &NodeId) {
521 let peer_hash = keccak(*peer);
522 if let Some(peer_state) = self.peers.remove(&peer_hash) {
523 if let Some(pos_public_key) = &peer_state.read().pos_public_key {
524 self.pos_peer_mapping.write().remove(
525 &from_consensus_public_key(
526 &pos_public_key.0,
527 &pos_public_key.1,
528 ),
529 );
530 }
531 }
532 let event = NetworkEvent::PeerDisconnected;
534 if let Err(e) = self
535 .mempool_network_task
536 .network_events_tx
537 .push((*peer, discriminant(&event)), (*peer, event))
538 {
539 warn!("error sending PeerDisconnected: e={:?}", e);
540 }
541
542 self.request_manager.on_peer_disconnected(io, peer);
543 debug!(
544 "hsb on_peer_disconnected: peer={}, peer count {}",
545 peer,
546 self.peers.len()
547 );
548 }
549
550 fn on_timeout(&self, io: &dyn NetworkContext, timer: TimerToken) {
551 trace!("hsb protocol timeout: timer={:?}", timer);
552 match timer {
553 CHECK_RPC_REQUEST_TIMER => {
554 self.remove_expired_flying_request(io);
555 }
556 _ => warn!("hsb protocol: unknown timer {} triggered.", timer),
557 }
558 }
559
560 fn send_local_message(&self, _io: &dyn NetworkContext, _message: Vec<u8>) {
561 todo!()
562 }
563
564 fn on_work_dispatch(&self, _io: &dyn NetworkContext, _work_type: u8) {
565 todo!()
566 }
567}
568
569pub trait Handleable {
570 fn handle(self, ctx: &Context) -> Result<(), Error>;
571}
572
573pub trait RpcResponse: Send + Sync + Debug + AsAny {}
574
575impl From<bcs::Error> for Error {
576 fn from(_: bcs::Error) -> Self { Error::InvalidMessageFormat.into() }
577}
578
579impl From<anyhow::Error> for Error {
580 fn from(error: anyhow::Error) -> Self {
581 Error::InternalError(format!("{}", error)).into()
582 }
583}