cfxcore/pos/protocol/
sync_protocol.rs

1// Copyright 2019-2020 Conflux Foundation. All rights reserved.
2// TreeGraph is free software and distributed under Apache License 2.0.
3// See https://www.apache.org/licenses/LICENSE-2.0
4
5use std::{collections::HashMap, fmt::Debug, mem::discriminant, sync::Arc};
6
7use keccak_hash::keccak;
8use parking_lot::RwLock;
9use serde::Deserialize;
10
11use cfx_types::H256;
12use consensus_types::{
13    epoch_retrieval::EpochRetrievalRequest, proposal_msg::ProposalMsg,
14    sync_info::SyncInfo, vote_msg::VoteMsg,
15};
16use diem_types::{
17    account_address::{from_consensus_public_key, AccountAddress},
18    epoch_change::EpochChangeProof,
19    validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey},
20};
21use network::{
22    node_table::NodeId, service::ProtocolVersion, NetworkContext,
23    NetworkProtocolHandler, NetworkService, UpdateNodeOperation,
24};
25
26use crate::{
27    message::{Message, MsgId},
28    pos::{
29        consensus::network::{
30            ConsensusMsg, NetworkTask as ConsensusNetworkTask,
31        },
32        mempool::network::{MempoolSyncMsg, NetworkTask as MempoolNetworkTask},
33        protocol::{
34            message::{
35                block_retrieval::BlockRetrievalRpcRequest,
36                block_retrieval_response::BlockRetrievalRpcResponse, msgid,
37            },
38            network_event::NetworkEvent,
39            request_manager::{
40                request_handler::AsAny, RequestManager, RequestMessage,
41            },
42        },
43    },
44    sync::{Error, ProtocolConfiguration, CHECK_RPC_REQUEST_TIMER},
45};
46
47use super::{HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION};
48
/// Identifier of a timer delivered to `on_timeout`.
type TimerToken = usize;
50
51#[derive(Default)]
52pub struct PeerState {
53    id: NodeId,
54    peer_hash: H256,
55    // TODO(lpl): Only keep AccountAddress?
56    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
57}
58
59impl PeerState {
60    pub fn new(
61        id: NodeId, peer_hash: H256,
62        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
63    ) -> Self {
64        Self {
65            id,
66            peer_hash,
67            pos_public_key,
68        }
69    }
70
71    pub fn set_pos_public_key(
72        &mut self,
73        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
74    ) {
75        self.pos_public_key = pos_public_key
76    }
77
78    pub fn get_id(&self) -> NodeId { self.id }
79}
80
81#[derive(Default)]
82pub struct Peers(RwLock<HashMap<H256, Arc<RwLock<PeerState>>>>);
83
84impl Peers {
85    pub fn new() -> Peers { Self::default() }
86
87    pub fn get(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
88        self.0.read().get(peer).cloned()
89    }
90
91    pub fn insert(
92        &self, peer: H256, id: NodeId,
93        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
94    ) {
95        self.0.write().entry(peer).or_insert(Arc::new(RwLock::new(
96            PeerState::new(id, peer, pos_public_key),
97        )));
98    }
99
100    pub fn len(&self) -> usize { self.0.read().len() }
101
102    pub fn is_empty(&self) -> bool { self.0.read().is_empty() }
103
104    pub fn contains(&self, peer: &H256) -> bool {
105        self.0.read().contains_key(peer)
106    }
107
108    pub fn remove(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
109        self.0.write().remove(peer)
110    }
111
112    pub fn all_peers_satisfying<F>(&self, mut predicate: F) -> Vec<H256>
113    where F: FnMut(&mut PeerState) -> bool {
114        self.0
115            .read()
116            .iter()
117            .filter_map(|(id, state)| {
118                if predicate(&mut *state.write()) {
119                    Some(*id)
120                } else {
121                    None
122                }
123            })
124            .collect()
125    }
126
127    pub fn fold<B, F>(&self, init: B, f: F) -> B
128    where F: FnMut(B, &Arc<RwLock<PeerState>>) -> B {
129        self.0.write().values().fold(init, f)
130    }
131}
132
133pub struct Context<'a> {
134    pub io: &'a dyn NetworkContext,
135    pub peer: NodeId,
136    pub peer_hash: H256,
137    pub manager: &'a HotStuffSynchronizationProtocol,
138}
139
140impl<'a> Context<'a> {
141    pub fn match_request(
142        &self, request_id: u64,
143    ) -> Result<RequestMessage, Error> {
144        self.manager
145            .request_manager
146            .match_request(self.io, &self.peer, request_id)
147    }
148
149    pub fn send_response(&self, response: &dyn Message) -> Result<(), Error> {
150        response.send(self.io, &self.peer)?;
151        Ok(())
152    }
153
154    pub fn get_peer_account_address(&self) -> Result<AccountAddress, Error> {
155        let k = self.get_pos_public_key().ok_or(Error::UnknownPeer)?;
156        Ok(from_consensus_public_key(&k.0, &k.1))
157    }
158
159    fn get_pos_public_key(
160        &self,
161    ) -> Option<(ConsensusPublicKey, ConsensusVRFPublicKey)> {
162        self.manager
163            .peers
164            .get(&self.peer_hash)
165            .as_ref()?
166            .read()
167            .pos_public_key
168            .clone()
169    }
170}
171
172pub struct HotStuffSynchronizationProtocol {
173    pub protocol_config: ProtocolConfiguration,
174    pub own_node_hash: H256,
175    pub peers: Arc<Peers>,
176    pub request_manager: Arc<RequestManager>,
177    pub consensus_network_task: ConsensusNetworkTask,
178    pub mempool_network_task: MempoolNetworkTask,
179    pub pos_peer_mapping: RwLock<HashMap<AccountAddress, H256>>,
180}
181
182impl HotStuffSynchronizationProtocol {
183    pub fn new(
184        own_node_hash: H256, consensus_network_task: ConsensusNetworkTask,
185        mempool_network_task: MempoolNetworkTask,
186        protocol_config: ProtocolConfiguration,
187    ) -> Self {
188        let request_manager = Arc::new(RequestManager::new(&protocol_config));
189        HotStuffSynchronizationProtocol {
190            protocol_config,
191            own_node_hash,
192            peers: Arc::new(Peers::new()),
193            request_manager,
194            consensus_network_task,
195            mempool_network_task,
196            pos_peer_mapping: RwLock::new(Default::default()),
197        }
198    }
199
200    pub fn with_peers(
201        protocol_config: ProtocolConfiguration, own_node_hash: H256,
202        consensus_network_task: ConsensusNetworkTask,
203        mempool_network_task: MempoolNetworkTask, peers: Arc<Peers>,
204    ) -> Self {
205        let request_manager = Arc::new(RequestManager::new(&protocol_config));
206        HotStuffSynchronizationProtocol {
207            protocol_config,
208            own_node_hash,
209            peers,
210            request_manager,
211            consensus_network_task,
212            mempool_network_task,
213            pos_peer_mapping: RwLock::new(Default::default()),
214        }
215    }
216
217    pub fn register(
218        self: Arc<Self>, network: Arc<NetworkService>,
219    ) -> Result<(), String> {
220        network
221            .register_protocol(self, HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION)
222            .map_err(|e| {
223                format!(
224                    "failed to register HotStuffSynchronizationProtocol: {:?}",
225                    e
226                )
227            })
228    }
229
230    pub fn remove_expired_flying_request(&self, io: &dyn NetworkContext) {
231        self.request_manager.process_timeout_requests(io);
232        self.request_manager.resend_waiting_requests(io);
233    }
234
235    /// In the event two peers simultaneously dial each other we need to be able
236    /// to do tie-breaking to determine which connection to keep and which
237    /// to drop in a deterministic way. One simple way is to compare our
238    /// local PeerId with that of the remote's PeerId and
239    /// keep the connection where the peer with the greater PeerId is the
240    /// dialer.
241    ///
242    /// Returns `true` if the existing connection should be dropped and `false`
243    /// if the new connection should be dropped.
244    fn simultaneous_dial_tie_breaking(
245        own_peer_id: H256, remote_peer_id: H256, existing_origin: bool,
246        new_origin: bool,
247    ) -> bool {
248        match (existing_origin, new_origin) {
249            // If the remote dials while an existing connection is open, the
250            // older connection is dropped.
251            (false /* in-bound */, false /* in-bound */) => true,
252            (false /* in-bound */, true /* out-bound */) => {
253                remote_peer_id < own_peer_id
254            }
255            (true /* out-bound */, false /* in-bound */) => {
256                own_peer_id < remote_peer_id
257            }
258            // We should never dial the same peer twice, but if we do drop the
259            // new connection
260            (true /* out-bound */, true /* out-bound */) => false,
261        }
262    }
263
264    fn handle_error(
265        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId, e: Error,
266    ) {
267        let mut disconnect = true;
268        let mut warn = false;
269        let reason = format!("{}", e);
270        let error_reason = format!("{:?}", e);
271        let mut op = None;
272
273        // NOTE, DO NOT USE WILDCARD IN THE FOLLOWING MATCH STATEMENT!
274        // COMPILER WILL HELP TO FIND UNHANDLED ERROR CASES.
275        match e {
276            Error::InvalidBlock => op = Some(UpdateNodeOperation::Demotion),
277            Error::InvalidGetBlockTxn(_) => {
278                op = Some(UpdateNodeOperation::Demotion)
279            }
280            Error::InvalidStatus(_) => op = Some(UpdateNodeOperation::Failure),
281            Error::InvalidMessageFormat => {
282                op = Some(UpdateNodeOperation::Remove)
283            }
284            Error::UnknownPeer => {
285                warn = false;
286                op = Some(UpdateNodeOperation::Failure)
287            }
288            // TODO handle the unexpected response case (timeout or real invalid
289            // message type)
290            Error::UnexpectedResponse => disconnect = true,
291            Error::RequestNotFound => {
292                warn = false;
293                disconnect = false;
294            }
295            Error::InCatchUpMode(_) => {
296                disconnect = false;
297                warn = false;
298            }
299            Error::TooManyTrans => {}
300            Error::InvalidTimestamp => op = Some(UpdateNodeOperation::Demotion),
301            Error::InvalidSnapshotManifest(_) => {
302                op = Some(UpdateNodeOperation::Demotion)
303            }
304            Error::InvalidSnapshotChunk(_) => {
305                op = Some(UpdateNodeOperation::Demotion)
306            }
307            Error::AlreadyThrottled(_) => {
308                op = Some(UpdateNodeOperation::Remove)
309            }
310            Error::EmptySnapshotChunk => disconnect = false,
311            Error::Throttled(_, msg) => {
312                disconnect = false;
313
314                if let Err(e) = msg.send(io, peer) {
315                    error!("failed to send throttled packet: {:?}", e);
316                    disconnect = true;
317                }
318            }
319            Error::Decoder(_) => op = Some(UpdateNodeOperation::Remove),
320            Error::Io(_) => disconnect = false,
321            Error::Network(kind) => match kind {
322                network::Error::AddressParse => disconnect = false,
323                network::Error::AddressResolve(_) => disconnect = false,
324                network::Error::Auth => disconnect = false,
325                network::Error::BadProtocol => {
326                    op = Some(UpdateNodeOperation::Remove)
327                }
328                network::Error::BadAddr => disconnect = false,
329                network::Error::Decoder(_) => {
330                    op = Some(UpdateNodeOperation::Remove)
331                }
332                network::Error::Expired => disconnect = false,
333                network::Error::Disconnect(_) => disconnect = false,
334                network::Error::InvalidNodeId => disconnect = false,
335                network::Error::OversizedPacket => disconnect = false,
336                network::Error::Io(_) => disconnect = false,
337                network::Error::Throttling(_) => disconnect = false,
338                network::Error::SocketIo(_) => {
339                    op = Some(UpdateNodeOperation::Failure)
340                }
341                network::Error::Msg(_) => {
342                    op = Some(UpdateNodeOperation::Failure)
343                }
344                network::Error::MessageDeprecated { .. } => {
345                    op = Some(UpdateNodeOperation::Failure)
346                }
347                network::Error::SendUnsupportedMessage { .. } => {
348                    op = Some(UpdateNodeOperation::Failure)
349                }
350            },
351            Error::Storage(_) => {}
352            Error::Msg(_) => op = Some(UpdateNodeOperation::Failure),
353            // Error::__Nonexhaustive {} => {
354            //     op = Some(UpdateNodeOperation::Failure)
355            // }
356            Error::InternalError(_) => {}
357            Error::RpcTimeout => {}
358            Error::RpcCancelledByDisconnection => {}
359            Error::UnexpectedMessage(_) => {
360                op = Some(UpdateNodeOperation::Remove)
361            }
362            Error::NotSupported(_) => disconnect = false,
363        }
364
365        if warn {
366            warn!(
367                "Error while handling message, peer={}, msgid={:?}, error={}",
368                peer, msg_id, error_reason
369            );
370        } else {
371            debug!(
372                "Minor error while handling message, peer={}, msgid={:?}, error={}",
373                peer, msg_id, error_reason
374            );
375        }
376
377        if disconnect {
378            io.disconnect_peer(peer, op, reason.as_str());
379        }
380    }
381
382    fn dispatch_message(
383        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId,
384        msg: &[u8],
385    ) -> Result<(), Error> {
386        trace!("Dispatching message: peer={:?}, msg_id={:?}", peer, msg_id);
387        let peer_hash = if !io.is_peer_self(peer) {
388            if *peer == NodeId::default() {
389                return Err(Error::UnknownPeer.into());
390            }
391            let peer_hash = keccak(peer);
392            if !self.peers.contains(&peer_hash) {
393                return Err(Error::UnknownPeer.into());
394            }
395            peer_hash
396        } else {
397            self.own_node_hash.clone()
398        };
399
400        let ctx = Context {
401            peer_hash,
402            peer: *peer,
403            io,
404            manager: self,
405        };
406
407        if !handle_serialized_message(msg_id, &ctx, msg)? {
408            warn!("Unknown message: peer={:?} msgid={:?}", peer, msg_id);
409            let reason =
410                format!("unknown sync protocol message id {:?}", msg_id);
411            io.disconnect_peer(
412                peer,
413                Some(UpdateNodeOperation::Remove),
414                reason.as_str(),
415            );
416        }
417
418        Ok(())
419    }
420}
421
422pub fn handle_serialized_message(
423    id: MsgId, ctx: &Context, msg: &[u8],
424) -> Result<bool, Error> {
425    match id {
426        msgid::PROPOSAL => handle_message::<ProposalMsg>(ctx, msg)?,
427        msgid::VOTE => handle_message::<VoteMsg>(ctx, msg)?,
428        msgid::SYNC_INFO => handle_message::<SyncInfo>(ctx, msg)?,
429        msgid::BLOCK_RETRIEVAL => {
430            handle_message::<BlockRetrievalRpcRequest>(ctx, msg)?
431        }
432        msgid::BLOCK_RETRIEVAL_RESPONSE => {
433            handle_message::<BlockRetrievalRpcResponse>(ctx, msg)?
434        }
435        msgid::EPOCH_RETRIEVAL => {
436            handle_message::<EpochRetrievalRequest>(ctx, msg)?
437        }
438        msgid::EPOCH_CHANGE => handle_message::<EpochChangeProof>(ctx, msg)?,
439        msgid::CONSENSUS_MSG => handle_message::<ConsensusMsg>(ctx, msg)?,
440        msgid::MEMPOOL_SYNC_MSG => handle_message::<MempoolSyncMsg>(ctx, msg)?,
441        _ => return Ok(false),
442    }
443    Ok(true)
444}
445
446fn handle_message<'a, M>(ctx: &Context, msg: &'a [u8]) -> Result<(), Error>
447where M: Deserialize<'a> + Handleable + Message {
448    let msg: M = bcs::from_bytes(msg)?;
449    let msg_id = msg.msg_id();
450    let msg_name = msg.msg_name();
451    let req_id = msg.get_request_id();
452
453    trace!(
454        "handle sync protocol message, peer = {:?}, id = {}, name = {}, request_id = {:?}",
455        ctx.peer_hash, msg_id, msg_name, req_id,
456    );
457
458    // FIXME: add throttling.
459
460    if let Err(e) = msg.handle(ctx) {
461        info!(
462            "failed to handle sync protocol message, peer = {}, id = {}, name = {}, request_id = {:?}, error_kind = {:?}",
463            ctx.peer, msg_id, msg_name, req_id, e,
464        );
465
466        return Err(e);
467    }
468
469    Ok(())
470}
471
472impl NetworkProtocolHandler for HotStuffSynchronizationProtocol {
473    fn minimum_supported_version(&self) -> ProtocolVersion {
474        ProtocolVersion(0)
475    }
476
477    fn initialize(&self, io: &dyn NetworkContext) {
478        io.register_timer(
479            CHECK_RPC_REQUEST_TIMER,
480            self.protocol_config.check_request_period,
481        )
482        .expect("Error registering check rpc request timer");
483    }
484
485    fn on_message(&self, io: &dyn NetworkContext, peer: &NodeId, raw: &[u8]) {
486        let len = raw.len();
487        if len < 2 {
488            // Empty message.
489            return self.handle_error(
490                io,
491                peer,
492                msgid::INVALID,
493                Error::InvalidMessageFormat.into(),
494            );
495        }
496
497        let msg_id = raw[len - 1];
498        debug!("on_message: peer={:?}, msgid={:?}", peer, msg_id);
499
500        let msg = &raw[0..raw.len() - 1];
501        self.dispatch_message(io, peer, msg_id.into(), msg)
502            .unwrap_or_else(|e| self.handle_error(io, peer, msg_id.into(), e));
503    }
504
505    fn on_peer_connected(
506        &self, io: &dyn NetworkContext, node_id: &NodeId,
507        _peer_protocol_version: ProtocolVersion,
508        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
509    ) {
510        // TODO(linxi): maintain peer protocol version
511        let new_originated = io.get_peer_connection_origin(node_id);
512        if new_originated.is_none() {
513            debug!("Peer does not exist when just connected");
514            return;
515        }
516        let new_originated = new_originated.unwrap();
517        let peer_hash = keccak(node_id);
518
519        let add_new_peer = if let Some(old_peer) = self.peers.remove(&peer_hash)
520        {
521            let old_peer_id = &old_peer.read().id;
522            let old_originated = io.get_peer_connection_origin(old_peer_id);
523            if old_originated.is_none() {
524                debug!("Old session does not exist.");
525                true
526            } else {
527                let old_originated = old_originated.unwrap();
528                if Self::simultaneous_dial_tie_breaking(
529                    self.own_node_hash.clone(),
530                    peer_hash.clone(),
531                    old_originated,
532                    new_originated,
533                ) {
534                    // Drop the existing connection and replace it with the new
535                    // connection.
536                    io.disconnect_peer(
537                        old_peer_id,
538                        Some(UpdateNodeOperation::Failure),
539                        "remove old peer connection",
540                    );
541                    true
542                } else {
543                    // Drop the new connection.
544                    false
545                }
546            }
547        } else {
548            true
549        };
550
551        if add_new_peer {
552            self.peers.insert(peer_hash.clone(), *node_id, None);
553            if let Some(state) = self.peers.get(&peer_hash) {
554                let mut state = state.write();
555                state.id = *node_id;
556                state.peer_hash = peer_hash;
557                self.request_manager.on_peer_connected(node_id);
558            } else {
559                warn!(
560                    "PeerState is missing for peer: peer_hash={:?}",
561                    peer_hash
562                );
563            }
564        } else {
565            io.disconnect_peer(
566                node_id,
567                Some(UpdateNodeOperation::Failure),
568                "remove new peer connection",
569            );
570        }
571
572        if let Some(public_key) = pos_public_key {
573            self.pos_peer_mapping.write().insert(
574                from_consensus_public_key(&public_key.0, &public_key.1),
575                peer_hash,
576            );
577            if add_new_peer {
578                let event = NetworkEvent::PeerConnected;
579                if let Err(e) = self
580                    .mempool_network_task
581                    .network_events_tx
582                    .push((*node_id, discriminant(&event)), (*node_id, event))
583                {
584                    warn!("error sending PeerConnected: e={:?}", e);
585                }
586            }
587            if let Some(state) = self.peers.get(&peer_hash) {
588                state.write().set_pos_public_key(Some(public_key));
589            } else {
590                warn!(
591                    "PeerState is missing for peer: peer_hash={:?}",
592                    peer_hash
593                );
594            }
595        } else {
596            info!(
597                "pos public key is not provided for peer peer_hash={:?}",
598                peer_hash
599            );
600        }
601
602        debug!(
603            "hsb on_peer_connected: peer {:?}, peer_hash {:?}, peer count {}",
604            node_id,
605            peer_hash,
606            self.peers.len()
607        );
608    }
609
610    fn on_peer_disconnected(&self, io: &dyn NetworkContext, peer: &NodeId) {
611        let peer_hash = keccak(*peer);
612        if let Some(peer_state) = self.peers.remove(&peer_hash) {
613            if let Some(pos_public_key) = &peer_state.read().pos_public_key {
614                self.pos_peer_mapping.write().remove(
615                    &from_consensus_public_key(
616                        &pos_public_key.0,
617                        &pos_public_key.1,
618                    ),
619                );
620            }
621        }
622        // notify pos mempool
623        let event = NetworkEvent::PeerDisconnected;
624        if let Err(e) = self
625            .mempool_network_task
626            .network_events_tx
627            .push((*peer, discriminant(&event)), (*peer, event))
628        {
629            warn!("error sending PeerDisconnected: e={:?}", e);
630        }
631
632        self.request_manager.on_peer_disconnected(io, peer);
633        debug!(
634            "hsb on_peer_disconnected: peer={}, peer count {}",
635            peer,
636            self.peers.len()
637        );
638    }
639
640    fn on_timeout(&self, io: &dyn NetworkContext, timer: TimerToken) {
641        trace!("hsb protocol timeout: timer={:?}", timer);
642        match timer {
643            CHECK_RPC_REQUEST_TIMER => {
644                self.remove_expired_flying_request(io);
645            }
646            _ => warn!("hsb protocol: unknown timer {} triggered.", timer),
647        }
648    }
649
650    fn send_local_message(&self, _io: &dyn NetworkContext, _message: Vec<u8>) {
651        todo!()
652    }
653
654    fn on_work_dispatch(&self, _io: &dyn NetworkContext, _work_type: u8) {
655        todo!()
656    }
657}
658
659pub trait Handleable {
660    fn handle(self, ctx: &Context) -> Result<(), Error>;
661}
662
663pub trait RpcResponse: Send + Sync + Debug + AsAny {}
664
665impl From<bcs::Error> for Error {
666    fn from(_: bcs::Error) -> Self { Error::InvalidMessageFormat.into() }
667}
668
669impl From<anyhow::Error> for Error {
670    fn from(error: anyhow::Error) -> Self {
671        Error::InternalError(format!("{}", error)).into()
672    }
673}