cfxcore/
message.rs

// Copyright 2019 Conflux Foundation. All rights reserved.
// Conflux is free software and distributed under GNU General Public License.
// See http://www.gnu.org/licenses/
4
/// Identifier carried by request-style messages (see `GetMaybeRequestId`).
pub type RequestId = u64;
/// Numeric identifier of a protocol message type.
pub type MsgId = u16;
/// Exclusive upper bound on message ids (`1 << 14`): an id must fit in at
/// most two 7-bit groups of the wire encoding.
const MSG_ID_MAX: MsgId = 0x4000;
8
9pub use cfx_bytes::Bytes;
10pub use priority_send_queue::SendQueuePriority;
11use rlp::{Decodable, Rlp};
12
13use crate::sync::msg_sender::metric_message;
14use network::{
15    node_table::NodeId, parse_msg_id_leb128_2_bytes_at_most,
16    service::ProtocolVersion, ProtocolId,
17};
18pub use network::{
19    throttling::THROTTLING_SERVICE, Error as NetworkError, NetworkContext,
20    PeerId,
21};
22
/// Declares a `msgid` module containing one public `MsgId` constant for
/// every `NAME = value` pair given to the macro.
macro_rules! build_msgid {
    ($($id:ident = $val:expr)*) => {
        #[allow(dead_code)]
        pub mod msgid {
            use super::MsgId;
            $(
                pub const $id: MsgId = $val;
            )*
        }
    };
}
32
33// TODO: GetMaybeRequestId is part of Message due to the implementation of
34// TODO: Throttled. Conceptually this class isn't part of Message.
35pub trait GetMaybeRequestId {
36    fn get_request_id(&self) -> Option<RequestId> { None }
37}
38
39pub trait SetRequestId: GetMaybeRequestId {
40    fn set_request_id(&mut self, _id: RequestId);
41}
42
43pub trait MessageProtocolVersionBound {
44    /// This message is introduced since this version.
45    fn version_introduced(&self) -> ProtocolVersion;
46    /// This message is valid until the specified version.
47    ///
48    /// The return type is NOT defined as Option intentionally,
49    /// because I'd like to make it impossible to keep a Message
50    /// forever by default.
51    ///
52    /// Whenever we bump a protocol version, always update the
53    /// version_valid_till for each message.
54    fn version_valid_till(&self) -> ProtocolVersion;
55}
56
57pub trait Message:
58    Send + Sync + GetMaybeRequestId + MessageProtocolVersionBound
59{
60    // If true, message may be throttled when sent to remote peer.
61    fn is_size_sensitive(&self) -> bool { false }
62    fn msg_id(&self) -> MsgId;
63    fn push_msg_id_leb128_encoding(&self, buffer: &mut Vec<u8>) {
64        let msg_id = self.msg_id();
65        assert!(msg_id < MSG_ID_MAX);
66        let msg_id_msb = (msg_id >> 7) as u8;
67        let mut msg_id_lsb = (msg_id as u8) & 0x7f;
68        if msg_id_msb != 0 {
69            buffer.push(msg_id_msb);
70            msg_id_lsb |= 0x80;
71        }
72        buffer.push(msg_id_lsb);
73    }
74    fn msg_name(&self) -> &'static str;
75    fn priority(&self) -> SendQueuePriority { SendQueuePriority::High }
76
77    fn encode(&self) -> Vec<u8>;
78
79    fn throttle_token_cost(&self) -> (u64, u64) { (1, 0) }
80
81    fn send(
82        &self, io: &dyn NetworkContext, node_id: &NodeId,
83    ) -> Result<(), NetworkError> {
84        self.send_with_throttling(io, node_id, false)
85    }
86
87    fn send_with_throttling(
88        &self, io: &dyn NetworkContext, node_id: &NodeId,
89        throttling_disabled: bool,
90    ) -> Result<(), NetworkError> {
91        if !throttling_disabled && self.is_size_sensitive() {
92            if let Err(e) = THROTTLING_SERVICE.read().check_throttling() {
93                debug!("Throttling failure: {:?}", e);
94                return Err(e);
95            }
96        }
97
98        let msg = self.encode();
99        let size = msg.len();
100
101        if let Err(e) = io.send(
102            node_id,
103            msg,
104            self.version_introduced(),
105            self.version_valid_till(),
106            self.priority(),
107        ) {
108            debug!("Error sending message: {:?}", e);
109            return Err(e);
110        };
111
112        debug!(
113            "Send message({}) to peer {}, protocol {:?}",
114            self.msg_name(),
115            node_id,
116            io.get_protocol(),
117        );
118
119        if !io.is_peer_self(node_id) {
120            metric_message(self.msg_id(), size);
121        }
122
123        Ok(())
124    }
125}
126
127/// Check if we received deprecated message.
128#[inline]
129pub fn decode_rlp_and_check_deprecation<T: Message + Decodable>(
130    rlp: &Rlp, min_supported_version: ProtocolVersion, protocol: ProtocolId,
131) -> Result<T, NetworkError> {
132    let msg: T = rlp.as_val()?;
133
134    if min_supported_version > msg.version_valid_till() {
135        bail!(NetworkError::MessageDeprecated {
136            protocol,
137            msg_id: msg.msg_id(),
138            min_supported_version,
139        });
140    }
141
142    Ok(msg)
143}
144
145pub fn decode_msg(mut msg: &[u8]) -> Option<(MsgId, Rlp<'_>)> {
146    let len = msg.len();
147    if len < 2 {
148        return None;
149    }
150
151    let msg_id = parse_msg_id_leb128_2_bytes_at_most(&mut msg);
152    if msg.is_empty() {
153        return None;
154    }
155    let rlp = Rlp::new(&msg);
156
157    Some((msg_id, rlp))
158}
159
/// Implements `MessageProtocolVersionBound` for `$name`.
///
/// Both methods return the fixed version expressions given to the macro.
macro_rules! mark_msg_version_bound {
    ($ty_name:ident, $since:expr, $until:expr) => {
        impl MessageProtocolVersionBound for $ty_name {
            fn version_introduced(&self) -> ProtocolVersion {
                $since
            }

            fn version_valid_till(&self) -> ProtocolVersion {
                $until
            }
        }
    };
}
171
/// Implements `MessageProtocolVersionBound` and `Message` for `$name`.
///
/// The generated `encode` writes `self.rlp_bytes()` and then appends the
/// message id via `push_msg_id_leb128_encoding`. `$name` therefore needs an
/// `rlp_bytes` method, presumably from RLP `Encodable`.
macro_rules! build_msg_basic {
    (
        $name:ident,
        $id:expr,
        $label:literal,
        $since:expr,
        $until:expr
    ) => {
        mark_msg_version_bound!($name, $since, $until);

        impl Message for $name {
            fn msg_id(&self) -> MsgId {
                $id
            }

            fn msg_name(&self) -> &'static str {
                $label
            }

            fn encode(&self) -> Vec<u8> {
                let mut wire = self.rlp_bytes();
                self.push_msg_id_leb128_encoding(&mut wire);
                wire
            }
        }
    };
}
195
/// Builds a message type that carries no request id.
///
/// `GetMaybeRequestId` is implemented with its default. Everything else is
/// delegated to `build_msg_basic!`.
macro_rules! build_msg_impl {
    ($name:ident, $id:expr, $label:literal, $since:expr, $until:expr) => {
        impl GetMaybeRequestId for $name {}

        build_msg_basic!($name, $id, $label, $since, $until);
    };
}
209
/// Implements `GetMaybeRequestId` and `SetRequestId` for a type that keeps
/// its id in a `request_id: RequestId` field.
macro_rules! impl_request_id_methods {
    ($msg_ty:ty) => {
        impl GetMaybeRequestId for $msg_ty {
            fn get_request_id(&self) -> Option<RequestId> {
                Some(self.request_id)
            }
        }

        impl SetRequestId for $msg_ty {
            fn set_request_id(&mut self, new_id: RequestId) {
                self.request_id = new_id;
            }
        }
    };
}
225
/// Builds a message type whose `request_id` field is exposed through
/// `GetMaybeRequestId` / `SetRequestId`.
///
/// Combines `build_msg_basic!` with `impl_request_id_methods!`.
macro_rules! build_msg_with_request_id_impl {
    ($name:ident, $id:expr, $label:literal, $since:expr, $until:expr) => {
        build_msg_basic!($name, $id, $label, $since, $until);
        impl_request_id_methods!($name);
    };
}
238
#[cfg(test)]
mod test {
    use super::Message;
    use crate::message::{
        decode_msg, GetMaybeRequestId, MessageProtocolVersionBound, MSG_ID_MAX,
    };
    use network::service::ProtocolVersion;
    use rlp::{Decodable, DecoderError, Encodable, Rlp, RlpStream};

    /// Minimal message with a configurable id.
    ///
    /// Its RLP payload is the encoding of `1u8`, which is the single byte
    /// `0x01`.
    struct TestMessage {
        msg_id: u16,
    }

    impl Encodable for TestMessage {
        fn rlp_append(&self, s: &mut RlpStream) { s.append(&1u8); }
    }

    impl Decodable for TestMessage {
        fn decode(_rlp: &Rlp) -> Result<Self, DecoderError> {
            Ok(Self { msg_id: 0 })
        }
    }

    // Version bounds are never consulted by these tests.
    impl MessageProtocolVersionBound for TestMessage {
        fn version_introduced(&self) -> ProtocolVersion { unreachable!() }

        fn version_valid_till(&self) -> ProtocolVersion { unreachable!() }
    }

    impl GetMaybeRequestId for TestMessage {}

    impl Message for TestMessage {
        fn msg_id(&self) -> u16 { self.msg_id }

        fn msg_name(&self) -> &'static str { "TestMessageIdEncodeDecode" }

        fn encode(&self) -> Vec<u8> { vec![] }
    }

    /// Every valid id must satisfy three checks:
    /// - it survives an encode/decode round trip;
    /// - it occupies one byte below 128 and two bytes otherwise;
    /// - it leaves the RLP payload intact.
    #[test]
    fn test_message_id_encode_decode() {
        for msg_id in 0..MSG_ID_MAX {
            let message = TestMessage { msg_id };
            let mut buf = vec![];
            buf.extend_from_slice(&message.rlp_bytes());
            let payload_len = buf.len();
            message.push_msg_id_leb128_encoding(&mut buf);

            let expected_id_len = if msg_id < 0x80 { 1 } else { 2 };
            assert_eq!(
                buf.len(),
                payload_len + expected_id_len,
                "unexpected encoded length for msg_id {}",
                msg_id
            );

            let (decoded_msg_id, rlp) = decode_msg(&buf).unwrap_or_else(|| {
                panic!("Can not decode message with msg_id {}", msg_id)
            });
            assert_eq!(decoded_msg_id, msg_id);
            assert_eq!(rlp.as_raw().len(), 1);
            assert_eq!(
                rlp.as_val::<u8>().expect("payload should decode as u8"),
                1
            );
        }
    }

    /// Inputs shorter than two bytes cannot hold both an id and a payload.
    #[test]
    fn test_decode_msg_rejects_short_input() {
        assert!(decode_msg(&[]).is_none());
        assert!(decode_msg(&[0x01]).is_none());
    }
}