cfxcore/pos/protocol/
sync_protocol.rs

1// Copyright 2019-2020 Conflux Foundation. All rights reserved.
2// TreeGraph is free software and distributed under Apache License 2.0.
3// See https://www.apache.org/licenses/LICENSE-2.0
4
5use std::{collections::HashMap, fmt::Debug, mem::discriminant, sync::Arc};
6
7use keccak_hash::keccak;
8use parking_lot::RwLock;
9use serde::Deserialize;
10
11use cfx_types::H256;
12use consensus_types::{
13    epoch_retrieval::EpochRetrievalRequest, proposal_msg::ProposalMsg,
14    sync_info::SyncInfo, vote_msg::VoteMsg,
15};
16use diem_types::{
17    account_address::{from_consensus_public_key, AccountAddress},
18    epoch_change::EpochChangeProof,
19    validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey},
20};
21use network::{
22    node_table::NodeId, service::ProtocolVersion, NetworkContext,
23    NetworkProtocolHandler, NetworkService, UpdateNodeOperation,
24};
25
26use crate::{
27    message::{Message, MsgId},
28    pos::{
29        consensus::network::{
30            ConsensusMsg, NetworkTask as ConsensusNetworkTask,
31        },
32        mempool::network::{MempoolSyncMsg, NetworkTask as MempoolNetworkTask},
33        protocol::{
34            message::{
35                block_retrieval::BlockRetrievalRpcRequest,
36                block_retrieval_response::BlockRetrievalRpcResponse, msgid,
37            },
38            network_event::NetworkEvent,
39            request_manager::{
40                request_handler::AsAny, RequestManager, RequestMessage,
41            },
42        },
43    },
44    sync::{Error, ProtocolConfiguration, CHECK_RPC_REQUEST_TIMER},
45};
46
47use super::{HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION};
48
/// Identifier of a timer registered through `NetworkContext::register_timer`
/// and passed back to `on_timeout`.
type TimerToken = usize;
50
51#[derive(Default)]
52pub struct PeerState {
53    id: NodeId,
54    // TODO(lpl): Only keep AccountAddress?
55    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
56}
57
58impl PeerState {
59    pub fn new(
60        id: NodeId,
61        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
62    ) -> Self {
63        Self { id, pos_public_key }
64    }
65
66    pub fn set_pos_public_key(
67        &mut self,
68        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
69    ) {
70        self.pos_public_key = pos_public_key
71    }
72
73    pub fn get_id(&self) -> NodeId { self.id }
74}
75
76#[derive(Default)]
77pub struct Peers(RwLock<HashMap<H256, Arc<RwLock<PeerState>>>>);
78
79impl Peers {
80    pub fn new() -> Peers { Self::default() }
81
82    pub fn get(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
83        self.0.read().get(peer).cloned()
84    }
85
86    pub fn insert(
87        &self, peer: H256, id: NodeId,
88        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
89    ) {
90        self.0.write().insert(
91            peer,
92            Arc::new(RwLock::new(PeerState::new(id, pos_public_key))),
93        );
94    }
95
96    pub fn len(&self) -> usize { self.0.read().len() }
97
98    pub fn is_empty(&self) -> bool { self.0.read().is_empty() }
99
100    pub fn contains(&self, peer: &H256) -> bool {
101        self.0.read().contains_key(peer)
102    }
103
104    pub fn remove(&self, peer: &H256) -> Option<Arc<RwLock<PeerState>>> {
105        self.0.write().remove(peer)
106    }
107
108    pub fn all_peers_satisfying<F>(&self, mut predicate: F) -> Vec<H256>
109    where F: FnMut(&mut PeerState) -> bool {
110        self.0
111            .read()
112            .iter()
113            .filter_map(|(id, state)| {
114                if predicate(&mut *state.write()) {
115                    Some(*id)
116                } else {
117                    None
118                }
119            })
120            .collect()
121    }
122
123    pub fn fold<B, F>(&self, init: B, f: F) -> B
124    where F: FnMut(B, &Arc<RwLock<PeerState>>) -> B {
125        self.0.write().values().fold(init, f)
126    }
127}
128
129pub struct Context<'a> {
130    pub io: &'a dyn NetworkContext,
131    pub peer: NodeId,
132    pub peer_hash: H256,
133    pub manager: &'a HotStuffSynchronizationProtocol,
134}
135
136impl<'a> Context<'a> {
137    pub fn match_request(
138        &self, request_id: u64,
139    ) -> Result<RequestMessage, Error> {
140        self.manager
141            .request_manager
142            .match_request(self.io, &self.peer, request_id)
143    }
144
145    pub fn send_response(&self, response: &dyn Message) -> Result<(), Error> {
146        response.send(self.io, &self.peer)?;
147        Ok(())
148    }
149
150    pub fn get_peer_account_address(&self) -> Result<AccountAddress, Error> {
151        let k = self.get_pos_public_key().ok_or(Error::UnknownPeer)?;
152        Ok(from_consensus_public_key(&k.0, &k.1))
153    }
154
155    fn get_pos_public_key(
156        &self,
157    ) -> Option<(ConsensusPublicKey, ConsensusVRFPublicKey)> {
158        self.manager
159            .peers
160            .get(&self.peer_hash)
161            .as_ref()?
162            .read()
163            .pos_public_key
164            .clone()
165    }
166}
167
168pub struct HotStuffSynchronizationProtocol {
169    pub protocol_config: ProtocolConfiguration,
170    pub own_node_hash: H256,
171    pub peers: Arc<Peers>,
172    pub request_manager: Arc<RequestManager>,
173    pub consensus_network_task: ConsensusNetworkTask,
174    pub mempool_network_task: MempoolNetworkTask,
175    pub pos_peer_mapping: RwLock<HashMap<AccountAddress, H256>>,
176}
177
178impl HotStuffSynchronizationProtocol {
179    pub fn new(
180        own_node_hash: H256, consensus_network_task: ConsensusNetworkTask,
181        mempool_network_task: MempoolNetworkTask,
182        protocol_config: ProtocolConfiguration,
183    ) -> Self {
184        let request_manager = Arc::new(RequestManager::new(&protocol_config));
185        HotStuffSynchronizationProtocol {
186            protocol_config,
187            own_node_hash,
188            peers: Arc::new(Peers::new()),
189            request_manager,
190            consensus_network_task,
191            mempool_network_task,
192            pos_peer_mapping: RwLock::new(Default::default()),
193        }
194    }
195
196    pub fn with_peers(
197        protocol_config: ProtocolConfiguration, own_node_hash: H256,
198        consensus_network_task: ConsensusNetworkTask,
199        mempool_network_task: MempoolNetworkTask, peers: Arc<Peers>,
200    ) -> Self {
201        let request_manager = Arc::new(RequestManager::new(&protocol_config));
202        HotStuffSynchronizationProtocol {
203            protocol_config,
204            own_node_hash,
205            peers,
206            request_manager,
207            consensus_network_task,
208            mempool_network_task,
209            pos_peer_mapping: RwLock::new(Default::default()),
210        }
211    }
212
213    pub fn register(
214        self: Arc<Self>, network: Arc<NetworkService>,
215    ) -> Result<(), String> {
216        network
217            .register_protocol(self, HSB_PROTOCOL_ID, HSB_PROTOCOL_VERSION)
218            .map_err(|e| {
219                format!(
220                    "failed to register HotStuffSynchronizationProtocol: {:?}",
221                    e
222                )
223            })
224    }
225
226    pub fn remove_expired_flying_request(&self, io: &dyn NetworkContext) {
227        self.request_manager.process_timeout_requests(io);
228        self.request_manager.resend_waiting_requests(io);
229    }
230
231    fn handle_error(
232        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId, e: Error,
233    ) {
234        let mut disconnect = true;
235        let mut warn = false;
236        let reason = format!("{}", e);
237        let error_reason = format!("{:?}", e);
238        let mut op = None;
239
240        // NOTE, DO NOT USE WILDCARD IN THE FOLLOWING MATCH STATEMENT!
241        // COMPILER WILL HELP TO FIND UNHANDLED ERROR CASES.
242        match e {
243            Error::InvalidBlock => op = Some(UpdateNodeOperation::Demotion),
244            Error::InvalidGetBlockTxn(_) => {
245                op = Some(UpdateNodeOperation::Demotion)
246            }
247            Error::InvalidStatus(_) => op = Some(UpdateNodeOperation::Failure),
248            Error::InvalidMessageFormat => {
249                op = Some(UpdateNodeOperation::Remove)
250            }
251            Error::UnknownPeer => {
252                warn = false;
253                op = Some(UpdateNodeOperation::Failure)
254            }
255            // TODO handle the unexpected response case (timeout or real invalid
256            // message type)
257            Error::UnexpectedResponse => disconnect = true,
258            Error::RequestNotFound => {
259                warn = false;
260                disconnect = false;
261            }
262            Error::InCatchUpMode(_) => {
263                disconnect = false;
264                warn = false;
265            }
266            Error::TooManyTrans => {}
267            Error::InvalidTimestamp => op = Some(UpdateNodeOperation::Demotion),
268            Error::InvalidSnapshotManifest(_) => {
269                op = Some(UpdateNodeOperation::Demotion)
270            }
271            Error::InvalidSnapshotChunk(_) => {
272                op = Some(UpdateNodeOperation::Demotion)
273            }
274            Error::AlreadyThrottled(_) => {
275                op = Some(UpdateNodeOperation::Remove)
276            }
277            Error::EmptySnapshotChunk => disconnect = false,
278            Error::Throttled(_, msg) => {
279                disconnect = false;
280
281                if let Err(e) = msg.send(io, peer) {
282                    error!("failed to send throttled packet: {:?}", e);
283                    disconnect = true;
284                }
285            }
286            Error::Decoder(_) => op = Some(UpdateNodeOperation::Remove),
287            Error::Io(_) => disconnect = false,
288            Error::Network(kind) => match kind {
289                network::Error::AddressParse => disconnect = false,
290                network::Error::AddressResolve(_) => disconnect = false,
291                network::Error::Auth => disconnect = false,
292                network::Error::BadProtocol => {
293                    op = Some(UpdateNodeOperation::Remove)
294                }
295                network::Error::BadAddr => disconnect = false,
296                network::Error::Decoder(_) => {
297                    op = Some(UpdateNodeOperation::Remove)
298                }
299                network::Error::Expired => disconnect = false,
300                network::Error::Disconnect(_) => disconnect = false,
301                network::Error::InvalidNodeId => disconnect = false,
302                network::Error::OversizedPacket => disconnect = false,
303                network::Error::Io(_) => disconnect = false,
304                network::Error::Throttling(_) => disconnect = false,
305                network::Error::SocketIo(_) => {
306                    op = Some(UpdateNodeOperation::Failure)
307                }
308                network::Error::Msg(_) => {
309                    op = Some(UpdateNodeOperation::Failure)
310                }
311                network::Error::MessageDeprecated { .. } => {
312                    op = Some(UpdateNodeOperation::Failure)
313                }
314                network::Error::SendUnsupportedMessage { .. } => {
315                    op = Some(UpdateNodeOperation::Failure)
316                }
317            },
318            Error::Storage(_) => {}
319            Error::Msg(_) => op = Some(UpdateNodeOperation::Failure),
320            // Error::__Nonexhaustive {} => {
321            //     op = Some(UpdateNodeOperation::Failure)
322            // }
323            Error::InternalError(_) => {}
324            Error::RpcTimeout => {}
325            Error::RpcCancelledByDisconnection => {}
326            Error::UnexpectedMessage(_) => {
327                op = Some(UpdateNodeOperation::Remove)
328            }
329            Error::NotSupported(_) => disconnect = false,
330        }
331
332        if warn {
333            warn!(
334                "Error while handling message, peer={}, msgid={:?}, error={}",
335                peer, msg_id, error_reason
336            );
337        } else {
338            debug!(
339                "Minor error while handling message, peer={}, msgid={:?}, error={}",
340                peer, msg_id, error_reason
341            );
342        }
343
344        if disconnect {
345            io.disconnect_peer(peer, op, reason.as_str());
346        }
347    }
348
349    fn dispatch_message(
350        &self, io: &dyn NetworkContext, peer: &NodeId, msg_id: MsgId,
351        msg: &[u8],
352    ) -> Result<(), Error> {
353        trace!("Dispatching message: peer={:?}, msg_id={:?}", peer, msg_id);
354        let peer_hash = if !io.is_peer_self(peer) {
355            if *peer == NodeId::default() {
356                return Err(Error::UnknownPeer.into());
357            }
358            let peer_hash = keccak(peer);
359            if !self.peers.contains(&peer_hash) {
360                return Err(Error::UnknownPeer.into());
361            }
362            peer_hash
363        } else {
364            self.own_node_hash.clone()
365        };
366
367        let ctx = Context {
368            peer_hash,
369            peer: *peer,
370            io,
371            manager: self,
372        };
373
374        if !handle_serialized_message(msg_id, &ctx, msg)? {
375            warn!("Unknown message: peer={:?} msgid={:?}", peer, msg_id);
376            let reason =
377                format!("unknown sync protocol message id {:?}", msg_id);
378            io.disconnect_peer(
379                peer,
380                Some(UpdateNodeOperation::Remove),
381                reason.as_str(),
382            );
383        }
384
385        Ok(())
386    }
387}
388
389pub fn handle_serialized_message(
390    id: MsgId, ctx: &Context, msg: &[u8],
391) -> Result<bool, Error> {
392    match id {
393        msgid::PROPOSAL => handle_message::<ProposalMsg>(ctx, msg)?,
394        msgid::VOTE => handle_message::<VoteMsg>(ctx, msg)?,
395        msgid::SYNC_INFO => handle_message::<SyncInfo>(ctx, msg)?,
396        msgid::BLOCK_RETRIEVAL => {
397            handle_message::<BlockRetrievalRpcRequest>(ctx, msg)?
398        }
399        msgid::BLOCK_RETRIEVAL_RESPONSE => {
400            handle_message::<BlockRetrievalRpcResponse>(ctx, msg)?
401        }
402        msgid::EPOCH_RETRIEVAL => {
403            handle_message::<EpochRetrievalRequest>(ctx, msg)?
404        }
405        msgid::EPOCH_CHANGE => handle_message::<EpochChangeProof>(ctx, msg)?,
406        msgid::CONSENSUS_MSG => handle_message::<ConsensusMsg>(ctx, msg)?,
407        msgid::MEMPOOL_SYNC_MSG => handle_message::<MempoolSyncMsg>(ctx, msg)?,
408        _ => return Ok(false),
409    }
410    Ok(true)
411}
412
413fn handle_message<'a, M>(ctx: &Context, msg: &'a [u8]) -> Result<(), Error>
414where M: Deserialize<'a> + Handleable + Message {
415    let msg: M = bcs::from_bytes(msg)?;
416    let msg_id = msg.msg_id();
417    let msg_name = msg.msg_name();
418    let req_id = msg.get_request_id();
419
420    trace!(
421        "handle sync protocol message, peer = {:?}, id = {}, name = {}, request_id = {:?}",
422        ctx.peer_hash, msg_id, msg_name, req_id,
423    );
424
425    // FIXME: add throttling.
426
427    if let Err(e) = msg.handle(ctx) {
428        info!(
429            "failed to handle sync protocol message, peer = {}, id = {}, name = {}, request_id = {:?}, error_kind = {:?}",
430            ctx.peer, msg_id, msg_name, req_id, e,
431        );
432
433        return Err(e);
434    }
435
436    Ok(())
437}
438
439impl NetworkProtocolHandler for HotStuffSynchronizationProtocol {
440    fn minimum_supported_version(&self) -> ProtocolVersion {
441        ProtocolVersion(0)
442    }
443
444    fn initialize(&self, io: &dyn NetworkContext) {
445        io.register_timer(
446            CHECK_RPC_REQUEST_TIMER,
447            self.protocol_config.check_request_period,
448        )
449        .expect("Error registering check rpc request timer");
450    }
451
452    fn on_message(&self, io: &dyn NetworkContext, peer: &NodeId, raw: &[u8]) {
453        let len = raw.len();
454        if len < 2 {
455            // Empty message.
456            return self.handle_error(
457                io,
458                peer,
459                msgid::INVALID,
460                Error::InvalidMessageFormat.into(),
461            );
462        }
463
464        let msg_id = raw[len - 1];
465        debug!("on_message: peer={:?}, msgid={:?}", peer, msg_id);
466
467        let msg = &raw[0..raw.len() - 1];
468        self.dispatch_message(io, peer, msg_id.into(), msg)
469            .unwrap_or_else(|e| self.handle_error(io, peer, msg_id.into(), e));
470    }
471
472    fn on_peer_connected(
473        &self, io: &dyn NetworkContext, node_id: &NodeId,
474        _peer_protocol_version: ProtocolVersion,
475        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
476    ) {
477        if io.get_peer_connection_origin(node_id).is_none() {
478            debug!("Peer does not exist when just connected");
479            return;
480        }
481        let peer_hash = keccak(node_id);
482
483        // Unconditionally replace any existing entry. The network layer
484        // already serializes session conflicts via update_ingress_node_id
485        // and calls on_peer_disconnected before on_peer_connected; any
486        // entry observed here is a stale leftover from a thread race
487        // between simultaneous connections to the same peer.
488        self.peers
489            .insert(peer_hash, *node_id, pos_public_key.clone());
490        self.request_manager.on_peer_connected(node_id);
491
492        if let Some(public_key) = pos_public_key {
493            self.pos_peer_mapping.write().insert(
494                from_consensus_public_key(&public_key.0, &public_key.1),
495                peer_hash,
496            );
497            let event = NetworkEvent::PeerConnected;
498            if let Err(e) = self
499                .mempool_network_task
500                .network_events_tx
501                .push((*node_id, discriminant(&event)), (*node_id, event))
502            {
503                warn!("error sending PeerConnected: e={:?}", e);
504            }
505        } else {
506            info!(
507                "pos public key is not provided for peer peer_hash={:?}",
508                peer_hash
509            );
510        }
511
512        debug!(
513            "hsb on_peer_connected: peer {:?}, peer_hash {:?}, peer count {}",
514            node_id,
515            peer_hash,
516            self.peers.len()
517        );
518    }
519
520    fn on_peer_disconnected(&self, io: &dyn NetworkContext, peer: &NodeId) {
521        let peer_hash = keccak(*peer);
522        if let Some(peer_state) = self.peers.remove(&peer_hash) {
523            if let Some(pos_public_key) = &peer_state.read().pos_public_key {
524                self.pos_peer_mapping.write().remove(
525                    &from_consensus_public_key(
526                        &pos_public_key.0,
527                        &pos_public_key.1,
528                    ),
529                );
530            }
531        }
532        // notify pos mempool
533        let event = NetworkEvent::PeerDisconnected;
534        if let Err(e) = self
535            .mempool_network_task
536            .network_events_tx
537            .push((*peer, discriminant(&event)), (*peer, event))
538        {
539            warn!("error sending PeerDisconnected: e={:?}", e);
540        }
541
542        self.request_manager.on_peer_disconnected(io, peer);
543        debug!(
544            "hsb on_peer_disconnected: peer={}, peer count {}",
545            peer,
546            self.peers.len()
547        );
548    }
549
550    fn on_timeout(&self, io: &dyn NetworkContext, timer: TimerToken) {
551        trace!("hsb protocol timeout: timer={:?}", timer);
552        match timer {
553            CHECK_RPC_REQUEST_TIMER => {
554                self.remove_expired_flying_request(io);
555            }
556            _ => warn!("hsb protocol: unknown timer {} triggered.", timer),
557        }
558    }
559
560    fn send_local_message(&self, _io: &dyn NetworkContext, _message: Vec<u8>) {
561        todo!()
562    }
563
564    fn on_work_dispatch(&self, _io: &dyn NetworkContext, _work_type: u8) {
565        todo!()
566    }
567}
568
569pub trait Handleable {
570    fn handle(self, ctx: &Context) -> Result<(), Error>;
571}
572
573pub trait RpcResponse: Send + Sync + Debug + AsAny {}
574
575impl From<bcs::Error> for Error {
576    fn from(_: bcs::Error) -> Self { Error::InvalidMessageFormat.into() }
577}
578
579impl From<anyhow::Error> for Error {
580    fn from(error: anyhow::Error) -> Self {
581        Error::InternalError(format!("{}", error)).into()
582    }
583}