1use crate::{
6 connection::{Connection, ConnectionDetails, SendQueueStatus, WriteStatus},
7 handshake::Handshake,
8 node_table::{NodeEndpoint, NodeEntry, NodeId},
9 parse_msg_id_leb128_2_bytes_at_most,
10 service::{NetworkServiceInner, ProtocolVersion},
11 DisconnectReason, Error, ProtocolId, ProtocolInfo, SessionMetadata,
12 UpdateNodeOperation, PROTOCOL_ID_SIZE,
13};
14use bytes::Bytes;
15use cfx_util_macros::bail;
16use diem_crypto::{bls::BLS_PUBLIC_KEY_LENGTH, ValidCryptoMaterial};
17use diem_types::validator_config::{ConsensusPublicKey, ConsensusVRFPublicKey};
18use io::{IoContext, StreamToken};
19use log::{debug, trace};
20use mio::{net::TcpStream, Registry, Token};
21use priority_send_queue::SendQueuePriority;
22use rlp::{Rlp, RlpStream};
23use serde::Deserialize;
24use serde_derive::Serialize;
25use std::{
26 collections::HashSet,
27 convert::TryFrom,
28 fmt,
29 net::SocketAddr,
30 str,
31 time::{Duration, Instant},
32};
33
/// A peer session running over a single connection.
///
/// A session is created in the handshake state, switches to the session
/// state in `complete_handshake`, and is reported as ready (`is_ready`) once
/// the peer's Hello packet has been read.
pub struct Session {
    /// Peer metadata: node id, accepted peer protocols, whether this side
    /// originated the connection, and the peer's packet header version.
    pub metadata: SessionMetadata,
    /// Socket address of the peer. Its IP replaces the IP of the endpoint
    /// announced in the peer's Hello.
    address: SocketAddr,
    /// Either the in-progress handshake or the post-handshake connection.
    state: State,
    /// Timestamp taken when the session is constructed; `check_timeout`
    /// measures the Hello deadline from it.
    sent_hello: Instant,
    /// When the peer's Hello was read; `None` until then.
    had_hello: Option<Instant>,
    /// When the session was marked expired via `set_expired`; `None` while
    /// the session is live.
    expired: Option<Instant>,

    /// Time of the most recent `readable` call.
    last_read: Instant,
    /// Time and resulting status of the most recent `writable` call.
    last_write: (Instant, WriteStatus),
    /// Local PoS public keys (BLS and VRF) appended to the outgoing Hello.
    pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
}
63
/// Connection phase of a [`Session`].
enum State {
    /// Handshake in progress. The handshake is wrapped so that its
    /// connection can be moved out when the handshake completes.
    Handshake(MovableWrapper<Handshake>),
    /// Handshake finished; packets are exchanged over this connection.
    Session(Connection),
}
73
/// Outcome of reading from a session.
pub enum SessionData {
    /// Nothing to deliver (no complete packet, expired session, or the
    /// session itself is about to be disconnected).
    None,
    /// The peer's Hello was read successfully.
    Ready {
        /// PoS public keys carried by the peer's Hello, if it included them.
        pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
    },
    /// A user packet: its payload and the protocol it belongs to.
    Message { data: Vec<u8>, protocol: ProtocolId },
    /// Returned while in the handshake state after the handshake consumed
    /// input. NOTE(review): presumably tells the caller to keep reading —
    /// confirm against the caller.
    Continue,
}
88
/// Result of `Session::readable`: the data read plus an optional request to
/// disconnect some session.
pub struct SessionDataWithDisconnectInfo {
    /// What was read from the session.
    pub session_data: SessionData,
    /// Token of a session that should be disconnected together with a
    /// human-readable reason; produced while handling an ingress Hello.
    pub token_to_disconnect: Option<(StreamToken, String)>,
}
93
/// Packet id of the Hello packet.
const PACKET_HELLO: u8 = 0x80;
/// Upper bound on the number of protocols accepted in a peer's Hello; a
/// longer list causes a disconnect.
const MAX_PEER_PROTOCOLS_IN_HELLO: usize = 64;
/// Packet id of the Disconnect packet.
const PACKET_DISCONNECT: u8 = 0x01;
/// Packet id of user (protocol) packets.
pub const PACKET_USER: u8 = 0x10;
/// Packet header version constant.
/// NOTE(review): presumably the header version this node emits — confirm.
pub const PACKET_HEADER_VERSION: u8 = 0;
/// Highest header version `SessionPacket::parse` accepts; packets with a
/// greater header version are rejected.
const HEADER_VERSION_WITH_EXTENSION: u8 = 0;
108
109impl Session {
110 #[allow(clippy::too_many_arguments)]
113 pub fn new<Message: Send + Sync + Clone + 'static>(
114 io: &IoContext<Message>, socket: TcpStream, address: SocketAddr,
115 id: Option<&NodeId>, peer_header_version: u8, token: StreamToken,
116 host: &NetworkServiceInner,
117 pos_public_key: Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>,
118 ) -> Result<Session, Error> {
119 let originated = id.is_some();
120
121 let mut handshake = Handshake::new(token, id, socket);
122 handshake.start(io, &host.metadata)?;
123
124 Ok(Session {
125 metadata: SessionMetadata {
126 id: id.cloned(),
127 peer_protocols: Vec::new(),
128 originated,
129 peer_header_version,
130 },
131 address,
132 state: State::Handshake(MovableWrapper::new(handshake)),
133 sent_hello: Instant::now(),
134 had_hello: None,
135 expired: None,
136 last_read: Instant::now(),
137 last_write: (Instant::now(), WriteStatus::Complete),
138 pos_public_key,
139 })
140 }
141
142 pub fn have_capability(&self, protocol: ProtocolId) -> bool {
143 self.metadata
144 .peer_protocols
145 .iter()
146 .any(|c| c.protocol == protocol)
147 }
148
149 pub fn id(&self) -> Option<&NodeId> { self.metadata.id.as_ref() }
151
152 pub fn originated(&self) -> bool { self.metadata.originated }
153
154 pub fn is_ready(&self) -> bool { self.had_hello.is_some() }
155
156 pub fn expired(&self) -> bool { self.expired.is_some() }
157
158 pub fn set_expired(&mut self) { self.expired = Some(Instant::now()); }
159
160 pub fn done(&self) -> bool {
161 self.expired() && !self.connection().is_sending()
162 }
163
164 fn connection(&self) -> &Connection {
165 match self.state {
166 State::Handshake(ref h) => &h.get().connection,
167 State::Session(ref c) => c,
168 }
169 }
170
171 fn connection_mut(&mut self) -> &mut Connection {
172 match self.state {
173 State::Handshake(ref mut h) => &mut h.get_mut().connection,
174 State::Session(ref mut c) => c,
175 }
176 }
177
178 pub fn token(&self) -> StreamToken { self.connection().token() }
179
180 pub fn address(&self) -> SocketAddr { self.address }
181
182 pub fn register_socket(
185 &mut self, reg: Token, poll_registry: &Registry,
186 ) -> Result<(), Error> {
187 if !self.expired() {
188 self.connection_mut().register_socket(reg, poll_registry)?;
189 }
190
191 Ok(())
192 }
193
194 pub fn update_socket(
196 &mut self, reg: Token, poll_registry: &Registry,
197 ) -> Result<(), Error> {
198 self.connection_mut().update_socket(reg, poll_registry)?;
199 Ok(())
200 }
201
202 pub fn deregister_socket(
204 &mut self, poll_registry: &Registry,
205 ) -> Result<(), Error> {
206 self.connection_mut().deregister_socket(poll_registry)?;
207 Ok(())
208 }
209
210 fn complete_handshake<Message>(
215 &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
216 ) -> Result<(), Error>
217 where Message: Send + Sync + Clone {
218 let wrapper = match self.state {
219 State::Handshake(ref mut h) => h,
220 State::Session(_) => panic!("Unexpected session state"),
221 };
222
223 if self.metadata.id.is_none() {
225 let id = wrapper.get().id;
226
227 if host.node_db.write().evaluate_blacklisted(&id) {
229 return Err(self.send_disconnect(DisconnectReason::Blacklisted));
230 }
231
232 self.metadata.id = Some(id);
233 }
234
235 self.state = State::Session(wrapper.take().connection);
237 self.write_hello(io, host)?;
238
239 Ok(())
240 }
241
242 pub fn readable<Message: Send + Sync + Clone>(
244 &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
245 ) -> Result<SessionDataWithDisconnectInfo, Error> {
246 self.last_read = Instant::now();
248
249 if self.expired() {
250 debug!("cannot read data due to expired, session = {:?}", self);
251 return Ok(SessionDataWithDisconnectInfo {
252 session_data: SessionData::None,
253 token_to_disconnect: None,
254 });
255 }
256
257 match self.state {
258 State::Handshake(ref mut h) => {
259 let h = h.get_mut();
260
261 if !h.readable(io, &host.metadata)? {
262 return Ok(SessionDataWithDisconnectInfo {
263 session_data: SessionData::None,
264 token_to_disconnect: None,
265 });
266 }
267
268 if h.done() {
269 self.complete_handshake(io, host)?;
270 io.update_registration(self.token()).unwrap_or_else(|e| {
271 debug!("Token registration error: {:?}", e)
272 });
273 }
274
275 Ok(SessionDataWithDisconnectInfo {
276 session_data: SessionData::Continue,
277 token_to_disconnect: None,
278 })
279 }
280 State::Session(ref mut c) => match c.readable()? {
281 Some(data) => Ok(self.read_packet(data, host)?),
282 None => Ok(SessionDataWithDisconnectInfo {
283 session_data: SessionData::None,
284 token_to_disconnect: None,
285 }),
286 },
287 }
288 }
289
290 fn read_packet(
292 &mut self, data: Bytes, host: &NetworkServiceInner,
293 ) -> Result<SessionDataWithDisconnectInfo, Error> {
294 let packet = SessionPacket::parse(data)?;
295
296 if packet.id != PACKET_HELLO
299 && packet.id != PACKET_DISCONNECT
300 && self.had_hello.is_none()
301 {
302 return Err(Error::BadProtocol);
303 }
304
305 match packet.id {
306 PACKET_HELLO => {
307 debug!("Read HELLO in session {:?}", self);
308 self.metadata.peer_header_version = packet.header_version;
309 let token_to_disconnect = self.update_ingress_node_id(host)?;
311
312 if token_to_disconnect.as_ref().map(|(t, _)| *t)
315 == Some(self.token())
316 {
317 return Ok(SessionDataWithDisconnectInfo {
318 session_data: SessionData::None,
319 token_to_disconnect,
320 });
321 }
322
323 let rlp = Rlp::new(&packet.data);
325 let pos_public_key = self.read_hello(&rlp, host)?;
326 Ok(SessionDataWithDisconnectInfo {
327 session_data: SessionData::Ready { pos_public_key },
328 token_to_disconnect,
329 })
330 }
331 PACKET_DISCONNECT => {
332 let rlp = Rlp::new(&packet.data);
333 let reason: DisconnectReason = rlp.as_val()?;
334 debug!(
335 "read packet DISCONNECT, reason = {}, session = {:?}",
336 reason, self
337 );
338 Err(Error::Disconnect(reason))
339 }
340 PACKET_USER => Ok(SessionDataWithDisconnectInfo {
341 session_data: SessionData::Message {
342 data: packet.data.to_vec(),
343 protocol: packet
344 .protocol
345 .expect("protocol should available for USER packet"),
346 },
347 token_to_disconnect: None,
348 }),
349 _ => {
350 debug!(
351 "read packet UNKNOWN, packet_id = {:?}, session = {:?}",
352 packet.id, self
353 );
354 Err(Error::BadProtocol)
355 }
356 }
357 }
358
359 fn update_ingress_node_id(
363 &mut self, host: &NetworkServiceInner,
364 ) -> Result<Option<(usize, String)>, Error> {
365 if self.metadata.originated {
367 return Ok(None);
368 }
369
370 let token = self.token();
371 let node_id = self
372 .metadata
373 .id
374 .expect("should have node id after handshake");
375
376 let result = host
377 .sessions
378 .update_ingress_node_id(token, &node_id)
379 .map_err(|reason| {
380 debug!(
381 "failed to update node id of ingress session, reason = {:?}, session = {:?}",
382 reason, self
383 );
384 self.send_disconnect(DisconnectReason::UpdateNodeIdFailed)
385 })?;
386
387 match result {
388 crate::session_manager::UpdateIngressResult::Inserted => Ok(None),
389 crate::session_manager::UpdateIngressResult::Replaced(old) => {
390 Ok(Some((
391 old,
392 String::from("Remove old session from the same node"),
393 )))
394 }
395 crate::session_manager::UpdateIngressResult::DropNew => {
396 debug!(
399 "lost simultaneous-dial tie-break, dropping new ingress session = {:?}",
400 self
401 );
402 let _ = self.send_disconnect(DisconnectReason::Custom(
403 "simultaneous dial: drop new connection".into(),
404 ));
405 Ok(Some((
406 token,
407 String::from("simultaneous dial: drop new connection"),
408 )))
409 }
410 }
411 }
412
413 fn read_hello(
420 &mut self, rlp: &Rlp, host: &NetworkServiceInner,
421 ) -> Result<Option<(ConsensusPublicKey, ConsensusVRFPublicKey)>, Error>
422 {
423 let remote_network_id: u64 = rlp.val_at(0)?;
424 if remote_network_id != host.metadata.network_id {
425 debug!(
426 "failed to read hello, network id mismatch, self = {}, remote = {}",
427 host.metadata.network_id, remote_network_id);
428 return Err(self.send_disconnect(DisconnectReason::Custom(
429 "network id mismatch".into(),
430 )));
431 }
432
433 let peer_caps_count = rlp
437 .at(1)?
438 .iter()
439 .take(MAX_PEER_PROTOCOLS_IN_HELLO + 1)
440 .count();
441 if peer_caps_count > MAX_PEER_PROTOCOLS_IN_HELLO {
442 debug!(
443 "Too many protocols in hello: {}, remote = {}",
444 peer_caps_count, remote_network_id
445 );
446 return Err(self.send_disconnect(DisconnectReason::Custom(
447 "Invalid protocol list: too many protocols.".into(),
448 )));
449 }
450
451 let mut peer_caps: Vec<ProtocolInfo> = rlp.list_at(1)?;
452 let mut seen_protocols = HashSet::with_capacity(peer_caps.len());
453 for cap in &peer_caps {
454 if !seen_protocols.insert(cap.protocol) {
455 debug!(
456 "Invalid protocol list from hello. Duplication: {:?}, \
457 remote = {}",
458 cap.protocol, remote_network_id
459 );
460 return Err(self.send_disconnect(DisconnectReason::Custom(
461 "Invalid protocol list: duplication.".into(),
462 )));
463 }
464 }
465
466 peer_caps.retain(|c| {
467 host.metadata
468 .minimum_peer_protocol_version
469 .read()
470 .iter()
471 .any(|hc| hc.protocol == c.protocol && hc.version <= c.version)
472 });
473
474 self.metadata.peer_protocols = peer_caps;
475 if self.metadata.peer_protocols.is_empty() {
476 debug!("No common capabilities with remote peer, peer_node_id = {:?}, session = {:?}", self.metadata.id, self);
477 return Err(self.send_disconnect(DisconnectReason::UselessPeer));
478 }
479
480 let mut hello_from = NodeEndpoint::from_rlp(&rlp.at(2)?)?;
481 hello_from.address.set_ip(self.address.ip());
485
486 let ping_to = NodeEndpoint {
487 address: hello_from.address,
488 udp_port: hello_from.udp_port,
489 };
490
491 let entry = NodeEntry {
492 id: self
493 .metadata
494 .id
495 .expect("should have node ID after handshake"),
496 endpoint: ping_to,
497 };
498 if !entry.endpoint.is_valid() {
499 debug!("Got invalid endpoint {:?}, session = {:?}", entry, self);
500 return Err(
501 self.send_disconnect(DisconnectReason::WrongEndpointInfo)
502 );
503 } else if !(entry.endpoint.is_allowed(host.get_ip_filter())
504 && entry.id != *host.metadata.id())
505 {
506 debug!(
507 "Address not allowed, endpoint = {:?}, session = {:?}",
508 entry, self
509 );
510 return Err(self.send_disconnect(DisconnectReason::IpLimited));
511 } else {
512 debug!("Received valid endpoint {:?}, session = {:?}", entry, self);
513 host.node_db.write().insert_with_token(entry, self.token());
514 }
515
516 self.had_hello = Some(Instant::now());
517 match rlp.item_count()? {
518 3 => Ok(None),
519 4 => {
520 let pos_public_key_bytes: Vec<u8> = rlp.val_at(3)?;
522 trace!("pos_public_key_bytes: {:?}", pos_public_key_bytes);
523 if pos_public_key_bytes.len() < BLS_PUBLIC_KEY_LENGTH {
524 bail!("pos public key bytes is too short!");
525 }
526 let bls_pub_key = ConsensusPublicKey::try_from(
527 &pos_public_key_bytes[..BLS_PUBLIC_KEY_LENGTH],
528 )
529 .map_err(|e| Error::Decoder(format!("{:?}", e)))?;
530 let vrf_pub_key = ConsensusVRFPublicKey::try_from(
531 &pos_public_key_bytes[BLS_PUBLIC_KEY_LENGTH..],
532 )
533 .map_err(|e| Error::Decoder(format!("{:?}", e)))?;
534
535 Ok(Some((bls_pub_key, vrf_pub_key)))
536 }
537 length => Err(Error::Decoder(format!(
538 "Hello has incorrect rlp length: {:?}",
539 length
540 ))),
541 }
542 }
543
544 fn prepare_packet(
548 &self, protocol: Option<ProtocolId>, packet_id: u8, data: Vec<u8>,
549 ) -> Result<Vec<u8>, Error> {
550 if protocol.is_some() && self.had_hello.is_none() {
551 debug!(
552 "Sending to unconfirmed session {}, protocol: {:?}, packet: {}",
553 self.token(),
554 protocol
555 .as_ref()
556 .map(|p| str::from_utf8(&p[..]).unwrap_or("???")),
557 packet_id
558 );
559 bail!(Error::Expired);
560 }
561
562 if self.expired() {
563 return Err(Error::Expired);
564 }
565
566 Ok(SessionPacket::assemble(
567 packet_id,
568 self.metadata.peer_header_version,
569 protocol,
570 data,
571 ))
572 }
573
574 #[inline]
575 pub fn check_message_protocol_version(
576 &self, protocol: Option<ProtocolId>,
577 min_protocol_version: ProtocolVersion, mut msg: &[u8],
578 ) -> Result<(), Error> {
579 if let Some(protocol) = protocol {
582 for peer_protocol in &self.metadata.peer_protocols {
583 if protocol.eq(&peer_protocol.protocol) {
584 if min_protocol_version <= peer_protocol.version {
585 break;
586 } else {
587 bail!(Error::SendUnsupportedMessage {
588 protocol,
589 msg_id: parse_msg_id_leb128_2_bytes_at_most(
590 &mut msg
591 )
592 .map_err(|_| Error::Msg(
593 "msg_id parse failed when checking protocol version".into()
594 ))?,
595 peer_protocol_version: Some(peer_protocol.version),
596 min_supported_version: None,
597 });
598 }
599 }
600 }
601 }
602
603 Ok(())
604 }
605
606 pub fn send_packet<Message: Send + Sync + Clone>(
608 &mut self, io: &IoContext<Message>, protocol: Option<ProtocolId>,
609 min_proto_version: ProtocolVersion, packet_id: u8, data: Vec<u8>,
610 priority: SendQueuePriority,
611 ) -> Result<SendQueueStatus, Error> {
612 self.check_message_protocol_version(
613 protocol,
614 min_proto_version,
615 &data,
616 )?;
617 let packet = self.prepare_packet(protocol, packet_id, data)?;
618 self.connection_mut().send(io, packet, priority)
619 }
620
621 pub fn send_packet_immediately(
623 &mut self, protocol: Option<ProtocolId>,
624 min_proto_version: ProtocolVersion, packet_id: u8, data: Vec<u8>,
625 ) -> Result<usize, Error> {
626 self.check_message_protocol_version(
627 protocol,
628 min_proto_version,
629 &data,
630 )?;
631 let packet = self.prepare_packet(protocol, packet_id, data)?;
632 self.connection_mut().write_raw_data(packet)
633 }
634
635 pub fn send_disconnect(&mut self, reason: DisconnectReason) -> Error {
637 let packet = rlp::encode(&reason).to_vec();
638 let _ = self.send_packet_immediately(
639 None,
640 ProtocolVersion::default(),
641 PACKET_DISCONNECT,
642 packet,
643 );
644 Error::Disconnect(reason)
645 }
646
647 fn write_hello<Message: Send + Sync + Clone>(
649 &mut self, io: &IoContext<Message>, host: &NetworkServiceInner,
650 ) -> Result<(), Error> {
651 debug!("Sending Hello, session = {:?}", self);
652 let mut rlp = RlpStream::new_list(4);
653 rlp.append(&host.metadata.network_id);
654 rlp.append_list(&host.metadata.protocols.read());
655 host.metadata.public_endpoint.to_rlp_list(&mut rlp);
656 let mut key_bytes =
657 self.pos_public_key.as_ref().unwrap().0.to_bytes().to_vec();
658 key_bytes.append(
659 &mut self.pos_public_key.as_ref().unwrap().1.to_bytes().to_vec(),
660 );
661 rlp.append(&key_bytes);
662 self.send_packet(
663 io,
664 None,
665 ProtocolVersion::default(),
666 PACKET_HELLO,
667 rlp.out().to_vec(),
668 SendQueuePriority::High,
669 )
670 .map(|_| ())
671 }
672
673 pub fn writable<Message: Send + Sync + Clone>(
675 &mut self, io: &IoContext<Message>,
676 ) -> Result<(), Error> {
677 let status = self.connection_mut().writable(io)?;
678 self.last_write = (Instant::now(), status);
679 Ok(())
680 }
681
682 pub fn details(&self) -> SessionDetails {
685 SessionDetails {
686 originated: self.metadata.originated,
687 node_id: self.metadata.id,
688 address: self.address,
689 connection: self.connection().details(),
690 status: if let Some(time) = self.expired {
691 format!("expired ({:?})", time.elapsed())
692 } else if let Some(time) = self.had_hello {
693 format!("communicating ({:?})", time.elapsed())
694 } else {
695 format!("handshaking ({:?})", self.sent_hello.elapsed())
696 },
697 last_read: format!("{:?}", self.last_read.elapsed()),
698 last_write: format!("{:?}", self.last_write.0.elapsed()),
699 last_write_status: format!("{:?}", self.last_write.1),
700 }
701 }
702
703 pub fn check_timeout(&self) -> (bool, Option<UpdateNodeOperation>) {
712 if let Some(time) = self.expired {
713 if time.elapsed() > Duration::from_secs(5) {
715 return (true, None);
716 }
717 } else if self.had_hello.is_none() {
718 if self.sent_hello.elapsed() > Duration::from_secs(300) {
720 return (true, Some(UpdateNodeOperation::Failure));
721 }
722 }
723
724 (false, None)
725 }
726}
727
728impl fmt::Debug for Session {
729 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
730 write!(f, "Session {{ token: {}, id: {:?}, originated: {}, address: {:?}, had_hello: {}, expired: {} }}",
731 self.token(), self.id(), self.metadata.originated, self.address, self.had_hello.is_some(), self.expired())
732 }
733}
734
/// Serialisable snapshot of a session, produced by `Session::details`.
/// Field names are serialised in camelCase.
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SessionDetails {
    /// Whether this side originated the connection.
    pub originated: bool,
    /// Node id of the peer, if known.
    pub node_id: Option<NodeId>,
    /// Socket address of the peer.
    pub address: SocketAddr,
    /// Details reported by the underlying connection.
    pub connection: ConnectionDetails,
    /// Phase of the session ("expired", "communicating" or "handshaking")
    /// with the time elapsed in that phase.
    pub status: String,
    /// Debug-formatted time elapsed since the last read.
    pub last_read: String,
    /// Debug-formatted time elapsed since the last write.
    pub last_write: String,
    /// Debug-formatted status of the last write.
    pub last_write_status: String,
}
748
/// Holds a value that can be moved out exactly once through a mutable
/// reference. Any access after the value has been taken panics.
struct MovableWrapper<T> {
    /// `Some` until `take` is called, `None` afterwards.
    item: Option<T>,
}

impl<T> MovableWrapper<T> {
    /// Wraps `item`.
    fn new(item: T) -> Self {
        Self { item: Some(item) }
    }

    /// Borrows the wrapped value.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn get(&self) -> &T {
        self.item.as_ref().expect("cannot get moved item")
    }

    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn get_mut(&mut self) -> &mut T {
        self.item.as_mut().expect("cannot get_mut moved item")
    }

    /// Moves the wrapped value out, leaving the wrapper empty.
    ///
    /// # Panics
    /// Panics if the value has already been taken.
    fn take(&mut self) -> T {
        self.item.take().expect("cannot take moved item")
    }
}
780
/// A parsed session-level packet.
///
/// On the wire the trailer is read from the end of the buffer: the last
/// byte is the packet id, preceded by a header byte (bit 0: protocol flag,
/// bits 1-3: header version, bit 4: extension flag), preceded by the
/// protocol id when the protocol flag is set, preceded by any extensions.
/// Whatever remains at the front is the payload.
#[derive(Eq, PartialEq)]
struct SessionPacket {
    /// Packet id (last byte on the wire).
    pub id: u8,
    /// Protocol id; present only when the header's protocol flag is set.
    pub protocol: Option<ProtocolId>,
    /// Payload remaining after the trailer has been stripped.
    pub data: Bytes,
    /// Header version decoded from the header byte.
    pub header_version: u8,
    /// Extension blobs in the order they were parsed (from the end of the
    /// buffer towards the front).
    pub extensions: Vec<Vec<u8>>,
}
809
810impl SessionPacket {
811 fn assemble(
813 id: u8, header_version: u8, protocol: Option<ProtocolId>,
814 mut data: Vec<u8>,
815 ) -> Vec<u8> {
816 let mut protocol_flag = 0;
817 if let Some(protocol) = protocol {
818 data.extend_from_slice(&protocol);
819 protocol_flag = 1;
820 }
821
822 let header_byte = (header_version << 1) + protocol_flag;
823 data.push(header_byte);
824 data.push(id);
825
826 data
827 }
828
829 fn parse(mut data: Bytes) -> Result<Self, Error> {
830 if data.is_empty() {
832 debug!("failed to parse session packet, packet id missed");
833 return Err(Error::BadProtocol);
834 }
835
836 let packet_id = data.split_off(data.len() - 1)[0];
837
838 if data.is_empty() {
840 debug!("failed to parse session packet, protocol flag missed");
841 return Err(Error::BadProtocol);
842 }
843
844 let header_byte = data.split_off(data.len() - 1)[0];
845 let protocol_flag = header_byte & 1;
846 let header_version = (header_byte & 0x0f) >> 1;
847 if header_version > HEADER_VERSION_WITH_EXTENSION {
848 debug!("unsupported header_version {}", header_version);
849 return Err(Error::BadProtocol);
850 }
851 let has_extension = (header_byte & 0x10) >> 4;
852
853 if protocol_flag == 0 {
855 if packet_id == PACKET_USER {
856 debug!("failed to parse session packet, no protocol for user packet");
857 return Err(Error::BadProtocol);
858 }
859
860 let (data, extensions) =
861 Self::parse_extensions(data, has_extension != 0)?;
862
863 return Ok(SessionPacket {
864 id: packet_id,
865 header_version,
866 protocol: None,
867 data,
868 extensions,
869 });
870 }
871
872 if packet_id != PACKET_USER {
873 debug!("failed to parse session packet, invalid packet id");
874 return Err(Error::BadProtocol);
875 }
876
877 if data.len() < PROTOCOL_ID_SIZE {
879 debug!("failed to parse session packet, protocol missed");
880 return Err(Error::BadProtocol);
881 }
882
883 let protocol_bytes = data.split_off(data.len() - PROTOCOL_ID_SIZE);
884 let mut protocol = ProtocolId::default();
885 protocol.copy_from_slice(&protocol_bytes);
886
887 let (data, extensions) =
889 Self::parse_extensions(data, has_extension != 0)?;
890
891 Ok(SessionPacket {
892 id: packet_id,
893 protocol: Some(protocol),
894 header_version,
895 data,
896 extensions,
897 })
898 }
899
900 fn parse_extensions(
901 mut data: Bytes, mut has_extension: bool,
902 ) -> Result<(Bytes, Vec<Vec<u8>>), Error> {
903 let mut extensions = Vec::new();
904 while has_extension {
905 if data.is_empty() {
906 debug!(
907 "failed to parse session packet, extension data exhausted"
908 );
909 bail!(Error::BadProtocol);
910 }
911 let extension_byte = data.split_off(data.len() - 1)[0];
912 let extension_len = (extension_byte >> 1) as usize;
913 has_extension = (extension_byte & 1) != 0;
914 if data.len() < extension_len {
915 debug!("failed to parse session packet, not enough bytes for extension.");
916 bail!(Error::BadProtocol);
917 }
918 extensions
919 .push(data.split_off(data.len() - extension_len).to_vec());
920 }
921
922 Ok((data, extensions))
923 }
924}
925
926impl fmt::Debug for SessionPacket {
927 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
928 write!(
929 f,
930 "SessionPacket {{ id: {}, protocol: {:?}, date_len: {} }}",
931 self.id,
932 self.protocol,
933 self.data.len()
934 )
935 }
936}
937
/// Unit tests for the wire format handled by `SessionPacket`.
#[cfg(test)]
mod tests {
    use super::*;

    /// `assemble` appends the protocol id (if any), the header byte and the
    /// packet id after the payload.
    #[test]
    fn test_packet_assemble() {
        // No protocol: payload, header byte 0, packet id.
        let packet =
            SessionPacket::assemble(5, PACKET_HEADER_VERSION, None, vec![1, 3]);
        assert_eq!(packet, vec![1, 3, 0, 5]);

        // With protocol: payload, protocol id, header byte 1, packet id.
        let packet = SessionPacket::assemble(
            6,
            PACKET_HEADER_VERSION,
            Some([8; 3]),
            vec![2, 4],
        );
        assert_eq!(packet, vec![2, 4, 8, 8, 8, 1, 6]);
    }

    /// `parse` rejects malformed trailers and round-trips valid packets.
    #[test]
    fn test_packet_parse() {
        // Empty input: packet id missing.
        assert!(SessionPacket::parse(vec![].into()).is_err());

        // Only a packet id: header byte missing.
        assert!(SessionPacket::parse(vec![1].into()).is_err());

        // Header byte 2 encodes header version 1, which is unsupported.
        assert!(SessionPacket::parse(vec![2, 1].into()).is_err());

        // User packet without the protocol flag.
        assert!(SessionPacket::parse(vec![0, PACKET_USER].into()).is_err());

        // Valid packet without protocol.
        let packet = SessionPacket::parse(vec![1, 2, 0, 20].into()).unwrap();
        assert_eq!(
            packet,
            SessionPacket {
                id: 20,
                header_version: 0,
                protocol: None,
                data: vec![1, 2].into(),
                extensions: vec![],
            }
        );

        // Protocol flag set on a non-user packet id.
        assert!(SessionPacket::parse(vec![6, 6, 6, 1, 7].into()).is_err());

        // Protocol flag set but too few bytes left for the protocol id
        // (assumes PROTOCOL_ID_SIZE is 3, as the `[3; 3]` ids here suggest).
        assert!(
            SessionPacket::parse(vec![6, 6, 1, PACKET_USER].into()).is_err()
        );

        // Valid user packet with protocol.
        let packet =
            SessionPacket::parse(vec![1, 9, 3, 3, 3, 1, PACKET_USER].into())
                .unwrap();
        assert_eq!(
            packet,
            SessionPacket {
                id: PACKET_USER,
                header_version: 0,
                protocol: Some([3; 3]),
                data: vec![1, 9].into(),
                extensions: vec![],
            }
        );
    }
}