network/
session.rs

1// Copyright 2019 Conflux Foundation. All rights reserved.
2// Conflux is free software and distributed under GNU General Public License.
3// See http://www.gnu.org/licenses/
4
5use crate::{
6    connection::{Connection, ConnectionDetails, SendQueueStatus, WriteStatus},
7    handshake::Handshake,
8    node_table::{NodeEndpoint, NodeEntry, NodeId},
9    parse_msg_id_leb128_2_bytes_at_most,
10    service::{NetworkServiceInner, ProtocolVersion},
11    DisconnectReason, Error, ProtocolId, ProtocolInfo, SessionMetadata,
12    UpdateNodeOperation, PROTOCOL_ID_SIZE,
13};
14use bytes::Bytes;
15use cfx_util_macros::bail;
16use diem_crypto::{bls::BLS_PUBLIC_KEY_LENGTH, ValidCryptoMaterial};
17use diem_types::validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey};
18use io::{IoContext, StreamToken};
19use log::{debug, trace};
20use mio::{net::TcpStream, Registry, Token};
21use priority_send_queue::SendQueuePriority;
22use rlp::{Rlp, RlpStream};
23use serde::Deserialize;
24use serde_derive::Serialize;
25use std::{
26    convert::TryFrom,
27    fmt,
28    net::SocketAddr,
29    str,
30    time::{Duration, Instant},
31};
32
/// Peer session over TCP connection, including outgoing and incoming sessions.
///
/// When a session is created, the 2 peers handshake with each other to
/// exchange the node id based on asymmetric cryptography. After handshake,
/// peers send HELLO packet to exchange the supported protocols. Then, session
/// is ready to send and receive protocol packets.
///
/// Conflux does not use AES based encrypted connection to send protocol
/// packets. This is because Conflux has high TPS, and the
/// encryption/decryption workloads are very heavy (about 20% CPU time in 3000
/// TPS).
pub struct Session {
    /// Session information
    pub metadata: SessionMetadata,
    /// Socket address of remote peer
    address: SocketAddr,
    /// Session state
    state: State,
    /// Timestamp taken when the session is created, which is used to measure
    /// the timeout while waiting for the Hello packet of the remote peer.
    sent_hello: Instant,
    /// Session ready flag: the time at which a valid Hello packet was
    /// received, or `None` if no Hello has been accepted yet.
    had_hello: Option<Instant>,
    /// Session is no longer active flag: the time at which the session was
    /// marked as expired, or `None` while the session is still active.
    expired: Option<Instant>,

    // statistics for read/write
    /// Time of the most recent call to the readable IO handler.
    last_read: Instant,
    /// Time and resulting status of the most recent call to the writable IO
    /// handler.
    last_write: (Instant, WriteStatus),
    /// Optional pair of PoS consensus public keys (BLS key, VRF key) that is
    /// sent to the remote peer in the Hello packet.
    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
}
62
/// Session state.
///
/// Both variants own the underlying `Connection`; see
/// `Session::connection` for uniform access.
enum State {
    /// Handshake to exchange node id.
    /// When handshake completed, the underlying TCP connection instance of
    /// handshake will also be moved to the state `State::Session`.
    Handshake(MovableWrapper<Handshake>),
    /// Ready to send Hello or protocol packets.
    Session(Connection),
}
72
/// Session data represents the various kinds of results produced by reading
/// from the socket.
pub enum SessionData {
    /// No packet read from socket.
    None,
    /// Session is ready to send or receive protocol packets, i.e. a valid
    /// Hello packet has been received. Carries the PoS public keys of the
    /// remote peer if they were included in the Hello packet.
    Ready {
        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
    },
    /// A protocol packet has been received, and should be delegated to the
    /// corresponding protocol handler to handle the packet.
    Message { data: Vec<u8>, protocol: ProtocolId },
    /// Session has more data to be read.
    Continue,
}
87
/// Result of the readable IO handler: the data read from the socket, plus an
/// optional request to disconnect another session.
pub struct SessionDataWithDisconnectInfo {
    /// Data read from the socket.
    pub session_data: SessionData,
    /// Token of a session that should be disconnected, together with a human
    /// readable reason. It is set when handling a Hello packet of an ingress
    /// session makes an older session of the same node obsolete.
    pub token_to_disconnect: Option<(StreamToken, String)>,
}
92
// id for Hello packet
const PACKET_HELLO: u8 = 0x80;
// id for Disconnect packet
const PACKET_DISCONNECT: u8 = 0x01;
// id for protocol packet; this is the only packet id that carries a
// protocol id in its header
pub const PACKET_USER: u8 = 0x10;
/// header_version for protocol packet.
/// Change the version only when there is a major change to the protocol packet.
pub const PACKET_HEADER_VERSION: u8 = 0;
/// The header version where extension is introduced.
/// `SessionPacket::parse` rejects packets whose header version is greater
/// than this value.
const HEADER_VERSION_WITH_EXTENSION: u8 = 0;
104
105impl Session {
106    /// Create a new instance of `Session`, which starts to handshake with
107    /// remote peer.
108    pub fn new<Message: Send + Sync + Clone + 'static>(
109        io: &IoContext<Message>, socket: TcpStream, address: SocketAddr,
110        id: Option<&NodeId>, peer_header_version: u8, token: StreamToken,
111        host: &NetworkServiceInner,
112        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
113    ) -> Result<Session, Error> {
114        let originated = id.is_some();
115
116        let mut handshake = Handshake::new(token, id, socket);
117        handshake.start(io, &host.metadata)?;
118
119        Ok(Session {
120            metadata: SessionMetadata {
121                id: id.cloned(),
122                peer_protocols: Vec::new(),
123                originated,
124                peer_header_version,
125            },
126            address,
127            state: State::Handshake(MovableWrapper::new(handshake)),
128            sent_hello: Instant::now(),
129            had_hello: None,
130            expired: None,
131            last_read: Instant::now(),
132            last_write: (Instant::now(), WriteStatus::Complete),
133            pos_public_key,
134        })
135    }
136
137    pub fn have_capability(&self, protocol: ProtocolId) -> bool {
138        self.metadata
139            .peer_protocols
140            .iter()
141            .any(|c| c.protocol == protocol)
142    }
143
    /// Get id of the remote peer.
    /// Returns `None` if the node id is not known yet, which is the case for
    /// an incoming session before the handshake completes.
    pub fn id(&self) -> Option<&NodeId> { self.metadata.id.as_ref() }
146
    /// Whether the session was initiated by this node (egress session).
    pub fn originated(&self) -> bool { self.metadata.originated }
148
    /// Whether a valid Hello packet has been received from the remote peer.
    pub fn is_ready(&self) -> bool { self.had_hello.is_some() }
150
    /// Whether the session has been marked as expired.
    pub fn expired(&self) -> bool { self.expired.is_some() }
152
    /// Mark the session as expired, recording the current time.
    /// Calling it again overwrites the recorded time.
    pub fn set_expired(&mut self) { self.expired = Some(Instant::now()); }
154
155    pub fn done(&self) -> bool {
156        self.expired() && !self.connection().is_sending()
157    }
158
159    fn connection(&self) -> &Connection {
160        match self.state {
161            State::Handshake(ref h) => &h.get().connection,
162            State::Session(ref c) => c,
163        }
164    }
165
166    fn connection_mut(&mut self) -> &mut Connection {
167        match self.state {
168            State::Handshake(ref mut h) => &mut h.get_mut().connection,
169            State::Session(ref mut c) => c,
170        }
171    }
172
    /// Stream token of the underlying connection.
    pub fn token(&self) -> StreamToken { self.connection().token() }
174
    /// Socket address of the remote peer.
    pub fn address(&self) -> SocketAddr { self.address }
176
177    /// Register event loop for the underlying connection.
178    /// If session expired, no effect taken.
179    pub fn register_socket(
180        &mut self, reg: Token, poll_registry: &Registry,
181    ) -> Result<(), Error> {
182        if !self.expired() {
183            self.connection_mut().register_socket(reg, poll_registry)?;
184        }
185
186        Ok(())
187    }
188
    /// Update the event loop registration for the underlying connection.
    /// Unlike `register_socket`, this is done even if the session is expired.
    pub fn update_socket(
        &mut self, reg: Token, poll_registry: &Registry,
    ) -> Result<(), Error> {
        self.connection_mut().update_socket(reg, poll_registry)?;
        Ok(())
    }
196
    /// Deregister the underlying connection from the event loop.
    pub fn deregister_socket(
        &mut self, poll_registry: &Registry,
    ) -> Result<(), Error> {
        self.connection_mut().deregister_socket(poll_registry)?;
        Ok(())
    }
204
    /// Complete the handshake process:
    /// 1. For incoming session, check if the remote peer is blacklisted.
    /// 2. Change the session state to `State::Session`.
    /// 3. Send Hello packet to remote peer.
    ///
    /// # Errors
    /// Returns the disconnect error produced by `send_disconnect` if an
    /// incoming peer is blacklisted, or any error from `write_hello`.
    ///
    /// # Panics
    /// Panics if the session is not in the `State::Handshake` state.
    fn complete_handshake<Message>(
        &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
    ) -> Result<(), Error>
    where Message: Send + Sync + Clone {
        let wrapper = match self.state {
            State::Handshake(ref mut h) => h,
            State::Session(_) => panic!("Unexpected session state"),
        };

        // update node id for ingress session
        // (egress sessions already know the node id from construction)
        if self.metadata.id.is_none() {
            let id = wrapper.get().id.clone();

            // refuse incoming session if the node is blacklisted
            if host.node_db.write().evaluate_blacklisted(&id) {
                return Err(self.send_disconnect(DisconnectReason::Blacklisted));
            }

            self.metadata.id = Some(id);
        }

        // write HELLO packet to remote peer
        // The connection is moved out of the handshake wrapper first, so the
        // wrapper must not be accessed afterwards.
        self.state = State::Session(wrapper.take().connection);
        self.write_hello(io, host)?;

        Ok(())
    }
236
237    /// Readable IO handler. Returns packet data if available.
238    pub fn readable<Message: Send + Sync + Clone>(
239        &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
240    ) -> Result<SessionDataWithDisconnectInfo, Error> {
241        // update the last read timestamp for statistics
242        self.last_read = Instant::now();
243
244        if self.expired() {
245            debug!("cannot read data due to expired, session = {:?}", self);
246            return Ok(SessionDataWithDisconnectInfo {
247                session_data: SessionData::None,
248                token_to_disconnect: None,
249            });
250        }
251
252        match self.state {
253            State::Handshake(ref mut h) => {
254                let h = h.get_mut();
255
256                if !h.readable(io, &host.metadata)? {
257                    return Ok(SessionDataWithDisconnectInfo {
258                        session_data: SessionData::None,
259                        token_to_disconnect: None,
260                    });
261                }
262
263                if h.done() {
264                    self.complete_handshake(io, host)?;
265                    io.update_registration(self.token()).unwrap_or_else(|e| {
266                        debug!("Token registration error: {:?}", e)
267                    });
268                }
269
270                Ok(SessionDataWithDisconnectInfo {
271                    session_data: SessionData::Continue,
272                    token_to_disconnect: None,
273                })
274            }
275            State::Session(ref mut c) => match c.readable()? {
276                Some(data) => Ok(self.read_packet(data, host)?),
277                None => Ok(SessionDataWithDisconnectInfo {
278                    session_data: SessionData::None,
279                    token_to_disconnect: None,
280                }),
281            },
282        }
283    }
284
285    /// Handle the packet from underlying connection.
286    fn read_packet(
287        &mut self, data: Bytes, host: &NetworkServiceInner,
288    ) -> Result<SessionDataWithDisconnectInfo, Error> {
289        let packet = SessionPacket::parse(data)?;
290
291        // For protocol packet, the Hello packet should already been received.
292        // So that dispatch it to the corresponding protocol handler.
293        if packet.id != PACKET_HELLO
294            && packet.id != PACKET_DISCONNECT
295            && self.had_hello.is_none()
296        {
297            return Err(Error::BadProtocol.into());
298        }
299
300        match packet.id {
301            PACKET_HELLO => {
302                debug!("Read HELLO in session {:?}", self);
303                self.metadata.peer_header_version = packet.header_version;
304                // For ingress session, update the node id in `SessionManager`
305                let token_to_disconnect = self.update_ingress_node_id(host)?;
306
307                let token_to_disconnect = match token_to_disconnect {
308                    Some(token) => Some((
309                        token,
310                        String::from("Remove old session from the same node"),
311                    )),
312                    None => None,
313                };
314
315                // Handle Hello packet to exchange protocols
316                let rlp = Rlp::new(&packet.data);
317                let pos_public_key = self.read_hello(&rlp, host)?;
318                Ok(SessionDataWithDisconnectInfo {
319                    session_data: SessionData::Ready { pos_public_key },
320                    token_to_disconnect,
321                })
322            }
323            PACKET_DISCONNECT => {
324                let rlp = Rlp::new(&packet.data);
325                let reason: DisconnectReason = rlp.as_val()?;
326                debug!(
327                    "read packet DISCONNECT, reason = {}, session = {:?}",
328                    reason, self
329                );
330                Err(Error::Disconnect(reason).into())
331            }
332            PACKET_USER => Ok(SessionDataWithDisconnectInfo {
333                session_data: SessionData::Message {
334                    data: packet.data.to_vec(),
335                    protocol: packet
336                        .protocol
337                        .expect("protocol should available for USER packet"),
338                },
339                token_to_disconnect: None,
340            }),
341            _ => {
342                debug!(
343                    "read packet UNKNOWN, packet_id = {:?}, session = {:?}",
344                    packet.id, self
345                );
346                Err(Error::BadProtocol.into())
347            }
348        }
349    }
350
351    /// Update node Id in `SessionManager` for ingress session.
352    fn update_ingress_node_id(
353        &mut self, host: &NetworkServiceInner,
354    ) -> Result<Option<usize>, Error> {
355        // ignore egress session
356        if self.metadata.originated {
357            return Ok(None);
358        }
359
360        let token = self.token();
361        let node_id = self
362            .metadata
363            .id
364            .expect("should have node id after handshake");
365
366        host.sessions.update_ingress_node_id(token, &node_id)
367            .map_err(|reason| {
368                debug!(
369                    "failed to update node id of ingress session, reason = {:?}, session = {:?}",
370                    reason, self
371                );
372
373                self.send_disconnect(DisconnectReason::UpdateNodeIdFailed)
374            })
375    }
376
    /// Read Hello packet to exchange the supported protocols, and set the
    /// `had_hello` flag to indicate that session is ready to send/receive
    /// protocol packets.
    ///
    /// Besides, the node endpoint of remote peer will be added or updated in
    /// node database, which is used to establish outgoing connections.
    ///
    /// The Hello payload is an RLP list of 3 or 4 items:
    /// `[network_id, protocols, endpoint, (optional) PoS public key bytes]`.
    ///
    /// Returns the PoS public keys of the remote peer if the 4th item is
    /// present.
    ///
    /// # Errors
    /// A Disconnect packet is sent and the disconnect error returned when the
    /// network id mismatches, the protocol list contains duplicates, no
    /// protocol is acceptable, or the endpoint is invalid or not allowed.
    /// Decoding failures are returned without sending a Disconnect packet.
    ///
    /// # Panics
    /// Panics if the node id is still unknown, i.e. the handshake has not
    /// completed.
    fn read_hello(
        &mut self, rlp: &Rlp, host: &NetworkServiceInner,
    ) -> Result<Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>, Error>
    {
        // item 0: network id, must be the same as ours
        let remote_network_id: u64 = rlp.val_at(0)?;
        if remote_network_id != host.metadata.network_id {
            debug!(
                "failed to read hello, network id mismatch, self = {}, remote = {}",
                host.metadata.network_id, remote_network_id);
            return Err(self.send_disconnect(DisconnectReason::Custom(
                "network id mismatch".into(),
            )));
        }

        // item 1: protocols supported by the remote peer; each protocol may
        // be listed only once
        let mut peer_caps: Vec<ProtocolInfo> = rlp.list_at(1)?;
        for i in 1..peer_caps.len() {
            for j in 0..i {
                if peer_caps[j].protocol == peer_caps[i].protocol {
                    debug!(
                        "Invalid protocol list from hello. Duplication: {:?},\
                         remote = {}",
                        peer_caps[i].protocol, remote_network_id
                    );
                    bail!(self.send_disconnect(DisconnectReason::Custom(
                        "Invalid protocol list: duplication.".into()
                    )))
                }
            }
        }

        // keep only the protocols whose version is at least the minimum
        // version this node requires from peers
        peer_caps.retain(|c| {
            host.metadata
                .minimum_peer_protocol_version
                .read()
                .iter()
                .any(|hc| hc.protocol == c.protocol && hc.version <= c.version)
        });

        self.metadata.peer_protocols = peer_caps;
        if self.metadata.peer_protocols.is_empty() {
            debug!("No common capabilities with remote peer, peer_node_id = {:?}, session = {:?}", self.metadata.id, self);
            return Err(self.send_disconnect(DisconnectReason::UselessPeer));
        }

        // item 2: public endpoint announced by the remote peer
        let mut hello_from = NodeEndpoint::from_rlp(&rlp.at(2)?)?;
        // Use the ip of the socket as endpoint ip directly.
        // We do not allow peers to specify the ip to avoid being used to DDoS
        // the target ip.
        hello_from.address.set_ip(self.address.ip());

        let ping_to = NodeEndpoint {
            address: hello_from.address,
            udp_port: hello_from.udp_port,
        };

        let entry = NodeEntry {
            id: self
                .metadata
                .id
                .expect("should have node ID after handshake"),
            endpoint: ping_to,
        };
        if !entry.endpoint.is_valid() {
            debug!("Got invalid endpoint {:?}, session = {:?}", entry, self);
            return Err(
                self.send_disconnect(DisconnectReason::WrongEndpointInfo)
            );
        } else if !(entry.endpoint.is_allowed(host.get_ip_filter())
            && entry.id != *host.metadata.id())
        {
            // rejected by the ip filter, or the peer claims our own node id
            debug!(
                "Address not allowed, endpoint = {:?}, session = {:?}",
                entry, self
            );
            return Err(self.send_disconnect(DisconnectReason::IpLimited));
        } else {
            debug!("Received valid endpoint {:?}, session = {:?}", entry, self);
            host.node_db.write().insert_with_token(entry, self.token());
        }

        // NOTE(review): the session is marked as ready before the optional
        // PoS keys below are decoded, so a decoding failure returns an error
        // while `had_hello` is already set - confirm this is intended.
        self.had_hello = Some(Instant::now());
        match rlp.item_count()? {
            3 => Ok(None),
            4 => {
                // item 3: BLS public key bytes immediately followed by the
                // VRF public key bytes
                // FIXME(lpl): Verify keys.
                let pos_public_key_bytes: Vec<u8> = rlp.val_at(3)?;
                trace!("pos_public_key_bytes: {:?}", pos_public_key_bytes);
                // guards the slicing at `BLS_PUBLIC_KEY_LENGTH` below
                if pos_public_key_bytes.len() < BLS_PUBLIC_KEY_LENGTH {
                    bail!("pos public key bytes is too short!");
                }
                let bls_pub_key = ConsensusPublicKey::try_from(
                    &pos_public_key_bytes[..BLS_PUBLIC_KEY_LENGTH],
                )
                .map_err(|e| Error::Decoder(format!("{:?}", e)))?;
                let vrf_pub_key = ConsensusVRFPublicKey::try_from(
                    &pos_public_key_bytes[BLS_PUBLIC_KEY_LENGTH..],
                )
                .map_err(|e| Error::Decoder(format!("{:?}", e)))?;

                Ok(Some((bls_pub_key, vrf_pub_key)))
            }
            length => Err(Error::Decoder(format!(
                "Hello has incorrect rlp length: {:?}",
                length
            ))
            .into()),
        }
    }
491
492    /// Assemble a packet with specified protocol id, packet id and data.
493    /// Return concrete error if session is expired or the protocol id is
494    /// invalid.
495    fn prepare_packet(
496        &self, protocol: Option<ProtocolId>, packet_id: u8, data: Vec<u8>,
497    ) -> Result<Vec<u8>, Error> {
498        if protocol.is_some() && self.had_hello.is_none() {
499            debug!(
500                "Sending to unconfirmed session {}, protocol: {:?}, packet: {}",
501                self.token(),
502                protocol
503                    .as_ref()
504                    .map(|p| str::from_utf8(&p[..]).unwrap_or("???")),
505                packet_id
506            );
507            bail!(Error::Expired);
508        }
509
510        if self.expired() {
511            return Err(Error::Expired.into());
512        }
513
514        Ok(SessionPacket::assemble(
515            packet_id,
516            self.metadata.peer_header_version,
517            protocol,
518            data,
519        ))
520    }
521
522    #[inline]
523    pub fn check_message_protocol_version(
524        &self, protocol: Option<ProtocolId>,
525        min_protocol_version: ProtocolVersion, mut msg: &[u8],
526    ) -> Result<(), Error> {
527        // min_protocol_version is the version when the Message is introduced.
528        // peer protocol version must be higher.
529        if let Some(protocol) = protocol {
530            for peer_protocol in &self.metadata.peer_protocols {
531                if protocol.eq(&peer_protocol.protocol) {
532                    if min_protocol_version <= peer_protocol.version {
533                        break;
534                    } else {
535                        bail!(Error::SendUnsupportedMessage {
536                            protocol,
537                            msg_id: parse_msg_id_leb128_2_bytes_at_most(
538                                &mut msg
539                            ),
540                            peer_protocol_version: Some(peer_protocol.version),
541                            min_supported_version: None,
542                        });
543                    }
544                }
545            }
546        }
547
548        Ok(())
549    }
550
    /// Send a packet to remote peer asynchronously.
    ///
    /// The message is checked against the protocol version of the peer,
    /// assembled into a session packet and then handed to the send queue of
    /// the underlying connection with the given `priority`.
    ///
    /// # Errors
    /// Fails if the peer does not support the message, if the session is
    /// expired or not ready for protocol packets, or if the connection
    /// rejects the packet.
    pub fn send_packet<Message: Send + Sync + Clone>(
        &mut self, io: &IoContext<Message>, protocol: Option<ProtocolId>,
        min_proto_version: ProtocolVersion, packet_id: u8, data: Vec<u8>,
        priority: SendQueuePriority,
    ) -> Result<SendQueueStatus, Error> {
        self.check_message_protocol_version(
            protocol.clone(),
            min_proto_version,
            &data,
        )?;
        let packet = self.prepare_packet(protocol, packet_id, data)?;
        self.connection_mut().send(io, packet, priority)
    }
565
    /// Send a packet to remote peer immediately.
    ///
    /// Performs the same checks as `send_packet`, but writes the assembled
    /// packet directly to the underlying connection instead of going through
    /// the send queue.
    /// Returns the value reported by `Connection::write_raw_data`
    /// (presumably the number of bytes written - TODO confirm).
    pub fn send_packet_immediately(
        &mut self, protocol: Option<ProtocolId>,
        min_proto_version: ProtocolVersion, packet_id: u8, data: Vec<u8>,
    ) -> Result<usize, Error> {
        self.check_message_protocol_version(
            protocol.clone(),
            min_proto_version,
            &data,
        )?;
        let packet = self.prepare_packet(protocol, packet_id, data)?;
        self.connection_mut().write_raw_data(packet)
    }
579
    /// Send a Disconnect packet immediately to the remote peer.
    ///
    /// Sending is best effort: a failure to send the packet is deliberately
    /// ignored. The returned error always wraps `reason`, so that callers can
    /// write `return Err(self.send_disconnect(reason))`.
    pub fn send_disconnect(&mut self, reason: DisconnectReason) -> Error {
        let packet = rlp::encode(&reason);
        let _ = self.send_packet_immediately(
            None,
            ProtocolVersion::default(),
            PACKET_DISCONNECT,
            packet,
        );
        Error::Disconnect(reason).into()
    }
591
592    /// Send Hello packet to remote peer.
593    fn write_hello<Message: Send + Sync + Clone>(
594        &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
595    ) -> Result<(), Error> {
596        debug!("Sending Hello, session = {:?}", self);
597        let mut rlp = RlpStream::new_list(4);
598        rlp.append(&host.metadata.network_id);
599        rlp.append_list(&*host.metadata.protocols.read());
600        host.metadata.public_endpoint.to_rlp_list(&mut rlp);
601        let mut key_bytes =
602            self.pos_public_key.as_ref().unwrap().0.to_bytes().to_vec();
603        key_bytes.append(
604            &mut self.pos_public_key.as_ref().unwrap().1.to_bytes().to_vec(),
605        );
606        rlp.append(&key_bytes);
607        self.send_packet(
608            io,
609            None,
610            ProtocolVersion::default(),
611            PACKET_HELLO,
612            rlp.drain(),
613            SendQueuePriority::High,
614        )
615        .map(|_| ())
616    }
617
    /// Writable IO handler. Sends pending packets.
    ///
    /// On success, the time and the write status reported by the connection
    /// are recorded for statistics. On error, the statistics are left
    /// unchanged.
    pub fn writable<Message: Send + Sync + Clone>(
        &mut self, io: &IoContext<Message>,
    ) -> Result<(), Error> {
        let status = self.connection_mut().writable(io)?;
        self.last_write = (Instant::now(), status);
        Ok(())
    }
626
627    /// Get the user friendly information of session.
628    /// This is specially for Debug RPC.
629    pub fn details(&self) -> SessionDetails {
630        SessionDetails {
631            originated: self.metadata.originated,
632            node_id: self.metadata.id,
633            address: self.address,
634            connection: self.connection().details(),
635            status: if let Some(time) = self.expired {
636                format!("expired ({:?})", time.elapsed())
637            } else if let Some(time) = self.had_hello {
638                format!("communicating ({:?})", time.elapsed())
639            } else {
640                format!("handshaking ({:?})", self.sent_hello.elapsed())
641            },
642            last_read: format!("{:?}", self.last_read.elapsed()),
643            last_write: format!("{:?}", self.last_write.0.elapsed()),
644            last_write_status: format!("{:?}", self.last_write.1),
645        }
646    }
647
648    /// Check if the session is timeout.
649    /// Once a session is timeout during handshake or exchanging Hello packet,
650    /// the TCP connection should be disconnected timely.
651    ///
652    /// Note, there is no periodical Ping/Pong mechanism to check if the session
653    /// is inactive for a long time. The synchronization protocol handler has
654    /// heartbeat mechanism to exchange peer status. As a result, Inactive
655    /// sessions (e.g. network issue) will be disconnected timely.
656    pub fn check_timeout(&self) -> (bool, Option<UpdateNodeOperation>) {
657        if let Some(time) = self.expired {
658            // should disconnected timely once expired
659            if time.elapsed() > Duration::from_secs(5) {
660                return (true, None);
661            }
662        } else if self.had_hello.is_none() {
663            // should receive HELLO packet timely after session created
664            if self.sent_hello.elapsed() > Duration::from_secs(300) {
665                return (true, Some(UpdateNodeOperation::Failure));
666            }
667        }
668
669        (false, None)
670    }
671}
672
673impl fmt::Debug for Session {
674    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
675        write!(f, "Session {{ token: {}, id: {:?}, originated: {}, address: {:?}, had_hello: {}, expired: {} }}",
676               self.token(), self.id(), self.metadata.originated, self.address, self.had_hello.is_some(), self.expired.is_some())
677    }
678}
679
/// User friendly session information that is used for Debug RPC.
/// Field names are serialized in camelCase.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionDetails {
    /// Whether the session was initiated by this node.
    pub originated: bool,
    /// Node id of the remote peer, if already known.
    pub node_id: Option<NodeId>,
    /// Socket address of the remote peer.
    pub address: SocketAddr,
    /// Details of the underlying connection.
    pub connection: ConnectionDetails,
    /// `expired`, `communicating` or `handshaking`, followed by the time
    /// elapsed in that state.
    pub status: String,
    /// Time elapsed since the last read.
    pub last_read: String,
    /// Time elapsed since the last write.
    pub last_write: String,
    /// Debug representation of the status of the last write.
    pub last_write_status: String,
}
693
/// Holder for a value that can be moved out exactly once while the holder
/// itself stays in place.
/// It is used to move the `Connection` instance when session state changed.
///
/// Every accessor panics once the value has been taken.
struct MovableWrapper<T> {
    item: Option<T>,
}

impl<T> MovableWrapper<T> {
    /// Wrap `item` so that it can be taken out later.
    fn new(item: T) -> Self { Self { item: Some(item) } }

    /// Borrow the wrapped value.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn get(&self) -> &T {
        self.item.as_ref().expect("cannot get moved item")
    }

    /// Mutably borrow the wrapped value.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn get_mut(&mut self) -> &mut T {
        self.item.as_mut().expect("cannot get_mut moved item")
    }

    /// Move the wrapped value out, leaving the wrapper empty.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn take(&mut self) -> T {
        self.item.take().expect("cannot take moved item")
    }
}
725
/// Session packet is composed of packet id, optional protocol id and data.
/// To avoid memory copy, especially when the data size is very big (e.g. 4MB),
/// packet id and protocol id are appended at the end of data.
///
/// The packet format is:
///     [data (0 to more bytes) || header]
///
/// The header format is:
/// [  extensions (0 to more bytes) || protocol (0 or 3 bytes if protocol_flag)
///   || reserved (3 bit), has_extension (1 bit), header_version (3 bit),
///      protocol_flag (1 bit)
///   || packet_id]
///
/// The protocol format is:
///     [ protocol_id (3 bytes)]
///
/// The extensions format is:
/// [ extension data (0 to more bytes)
///   || extension data length (7 bit) | has_next_extension (1 bit)
/// ]
///
/// Since everything is appended, a packet is parsed from its last byte
/// backwards.
#[derive(Eq, PartialEq)]
struct SessionPacket {
    /// Packet id, e.g. `PACKET_HELLO`.
    pub id: u8,
    /// Protocol id; only present for `PACKET_USER` packets.
    pub protocol: Option<ProtocolId>,
    /// Payload with the header and extensions stripped.
    pub data: Bytes,
    /// Header version announced in the header byte.
    pub header_version: u8,
    /// Extension payloads, in the order they were parsed (i.e. starting from
    /// the one closest to the end of the packet).
    pub extensions: Vec<Vec<u8>>,
}
754
755impl SessionPacket {
756    // data + Option<protocol> + protocol_flag + packet_id
757    fn assemble(
758        id: u8, header_version: u8, protocol: Option<ProtocolId>,
759        mut data: Vec<u8>,
760    ) -> Vec<u8> {
761        let mut protocol_flag = 0;
762        if let Some(protocol) = protocol {
763            data.extend_from_slice(&protocol);
764            protocol_flag = 1;
765        }
766
767        let header_byte = (header_version << 1) + protocol_flag;
768        data.push(header_byte);
769        data.push(id);
770
771        data
772    }
773
774    fn parse(mut data: Bytes) -> Result<Self, Error> {
775        // packet id
776        if data.is_empty() {
777            debug!("failed to parse session packet, packet id missed");
778            return Err(Error::BadProtocol.into());
779        }
780
781        let packet_id = data.split_off(data.len() - 1)[0];
782
783        // protocol flag
784        if data.is_empty() {
785            debug!("failed to parse session packet, protocol flag missed");
786            return Err(Error::BadProtocol.into());
787        }
788
789        let header_byte = data.split_off(data.len() - 1)[0];
790        let protocol_flag = header_byte & 1;
791        let header_version = (header_byte & 0x0f) >> 1;
792        if header_version > HEADER_VERSION_WITH_EXTENSION {
793            debug!("unsupported header_version {}", header_version);
794            return Err(Error::BadProtocol.into());
795        }
796        let has_extension = (header_byte & 0x10) >> 4;
797
798        // without protocol
799        if protocol_flag == 0 {
800            if packet_id == PACKET_USER {
801                debug!("failed to parse session packet, no protocol for user packet");
802                return Err(Error::BadProtocol.into());
803            }
804
805            let (data, extensions) =
806                Self::parse_extensions(data, has_extension != 0)?;
807
808            return Ok(SessionPacket {
809                id: packet_id,
810                header_version,
811                protocol: None,
812                data,
813                extensions,
814            });
815        }
816
817        if packet_id != PACKET_USER {
818            debug!("failed to parse session packet, invalid packet id");
819            return Err(Error::BadProtocol.into());
820        }
821
822        // protocol
823        if data.len() < PROTOCOL_ID_SIZE {
824            debug!("failed to parse session packet, protocol missed");
825            return Err(Error::BadProtocol.into());
826        }
827
828        let protocol_bytes = data.split_off(data.len() - PROTOCOL_ID_SIZE);
829        let mut protocol = ProtocolId::default();
830        protocol.copy_from_slice(&protocol_bytes);
831
832        // extensions
833        let (data, extensions) =
834            Self::parse_extensions(data, has_extension != 0)?;
835
836        Ok(SessionPacket {
837            id: packet_id,
838            protocol: Some(protocol),
839            header_version,
840            data,
841            extensions,
842        })
843    }
844
845    fn parse_extensions(
846        mut data: Bytes, mut has_extension: bool,
847    ) -> Result<(Bytes, Vec<Vec<u8>>), Error> {
848        let mut extensions = Vec::new();
849        while has_extension {
850            let extension_byte = data.split_off(data.len() - 1)[0];
851            let extension_len = (extension_byte >> 1) as usize;
852            has_extension = (extension_byte & 1) != 0;
853            if data.len() < extension_len {
854                debug!("failed to parse session packet, not enough bytes for extension.");
855                bail!(Error::BadProtocol);
856            }
857            extensions
858                .push(data.split_off(data.len() - extension_len).to_vec());
859        }
860
861        Ok((data, extensions))
862    }
863}
864
865impl fmt::Debug for SessionPacket {
866    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
867        write!(
868            f,
869            "SessionPacket {{ id: {}, protocol: {:?}, date_len: {} }}",
870            self.id,
871            self.protocol,
872            self.data.len()
873        )
874    }
875}
876
#[cfg(test)]
mod tests {
    use super::*;

    /// The header is appended to the payload as
    /// `[protocol] || header byte || packet id`.
    #[test]
    fn test_packet_assemble() {
        let without_protocol =
            SessionPacket::assemble(5, PACKET_HEADER_VERSION, None, vec![1, 3]);
        assert_eq!(without_protocol, vec![1, 3, 0, 5]);

        let with_protocol = SessionPacket::assemble(
            6,
            PACKET_HEADER_VERSION,
            Some([8; 3]),
            vec![2, 4],
        );
        assert_eq!(with_protocol, vec![2, 4, 8, 8, 8, 1, 6]);
    }

    #[test]
    fn test_packet_parse() {
        // malformed packets must be rejected
        let rejected: Vec<(&str, Vec<u8>)> = vec![
            ("packet id missed", vec![]),
            ("protocol flag missed", vec![1]),
            ("header byte announces an unsupported version", vec![2, 1]),
            ("user packet without protocol", vec![0, PACKET_USER]),
            ("non user packet with protocol", vec![6, 6, 6, 1, 7]),
            (
                "user packet, but protocol length is not enough",
                vec![6, 6, 1, PACKET_USER],
            ),
        ];
        for (case, raw) in rejected {
            assert!(SessionPacket::parse(raw.into()).is_err(), "{}", case);
        }

        // packet without protocol
        let parsed = SessionPacket::parse(vec![1, 2, 0, 20].into()).unwrap();
        let expected = SessionPacket {
            id: 20,
            header_version: 0,
            protocol: None,
            data: vec![1, 2].into(),
            extensions: vec![],
        };
        assert_eq!(parsed, expected);

        // user packet with protocol
        let parsed =
            SessionPacket::parse(vec![1, 9, 3, 3, 3, 1, PACKET_USER].into())
                .unwrap();
        let expected = SessionPacket {
            id: PACKET_USER,
            header_version: 0,
            protocol: Some([3; 3]),
            data: vec![1, 9].into(),
            extensions: vec![],
        };
        assert_eq!(parsed, expected);
    }
}