1use crate::{
6 hash::keccak,
7 node_database::NodeDatabase,
8 node_table::{NodeId, *},
9 service::{UdpIoContext, MAX_DATAGRAM_SIZE, UDP_PROTOCOL_DISCOVERY},
10 DiscoveryConfiguration, Error, IpFilter, ThrottlingReason,
11 NODE_TAG_ARCHIVE, NODE_TAG_NODE_TYPE,
12};
13use cfx_bytes::Bytes;
14use cfx_types::{H256, H520};
15use cfx_util_macros::bail;
16use cfxkey::{recover, sign, KeyPair, Secret};
17use log::{debug, trace, warn};
18use rlp::{Encodable, Rlp, RlpStream};
19use rlp_derive::{RlpDecodable, RlpEncodable};
20use std::{
21 collections::{hash_map::Entry, HashMap, HashSet},
22 net::{IpAddr, SocketAddr},
23 time::{Instant, SystemTime, UNIX_EPOCH},
24};
25use throttling::time_window_bucket::TimeWindowBucket;
26
/// Discovery protocol version, written as the first field of every Ping.
const DISCOVER_PROTOCOL_VERSION: u32 = 1;

/// Number of FindNode rounds a single discovery run performs before stopping.
const DISCOVERY_MAX_STEPS: u16 = 4;

// Packet type identifiers: the byte that follows the 65-byte signature in a
// discovery datagram and selects the handler in `Discovery::on_packet`.
const PACKET_PING: u8 = 1;
const PACKET_PONG: u8 = 2;
const PACKET_FIND_NODE: u8 = 3;
const PACKET_NEIGHBOURS: u8 = 4;
35
36struct PingRequest {
37 sent_at: Instant,
39 node: NodeEntry,
41 echo_hash: H256,
43}
44
/// Bookkeeping for an outstanding FindNode request whose Neighbours reply may
/// arrive split across several chunks.
struct FindNodeRequest {
    /// Send time; compared against the configured `find_node_timeout`.
    sent_at: Instant,
    /// Total chunk count announced by the first received chunk; 0 until then.
    num_chunks: usize,
    /// Indices of the chunks received so far.
    received_chunks: HashSet<usize>,
}

impl Default for FindNodeRequest {
    /// A request stamped with the current time and no chunks received.
    fn default() -> Self {
        Self {
            sent_at: Instant::now(),
            num_chunks: 0,
            received_chunks: HashSet::default(),
        }
    }
}

impl FindNodeRequest {
    /// True once the chunk count is known and every announced chunk has
    /// arrived.
    fn is_completed(&self) -> bool {
        match self.num_chunks {
            // No chunk seen yet, so the total is still unknown.
            0 => false,
            total => self.received_chunks.len() == total,
        }
    }
}
69
70#[allow(dead_code)]
71pub struct Discovery {
72 id: NodeId,
73 id_hash: H256,
74 secret: Secret,
75 public_endpoint: NodeEndpoint,
76 discovery_initiated: bool,
77 discovery_round: Option<u16>,
78 discovery_nodes: HashSet<NodeId>,
79 in_flight_pings: HashMap<NodeId, PingRequest>,
80 in_flight_find_nodes: HashMap<NodeId, FindNodeRequest>,
81 check_timestamps: bool,
82 adding_nodes: Vec<NodeEntry>,
83 ip_filter: IpFilter,
84 pub disc_option: DiscoveryOption,
85
86 ping_throttling: TimeWindowBucket<IpAddr>,
88 find_nodes_throttling: TimeWindowBucket<IpAddr>,
89
90 config: DiscoveryConfiguration,
91}
92
93impl Discovery {
94 pub fn new(
95 key: &KeyPair, public: NodeEndpoint, ip_filter: IpFilter,
96 config: DiscoveryConfiguration,
97 ) -> Discovery {
98 Discovery {
99 id: key.public().clone(),
100 id_hash: keccak(key.public()),
101 secret: key.secret().clone(),
102 public_endpoint: public,
103 discovery_initiated: false,
104 discovery_round: None,
105 discovery_nodes: HashSet::new(),
106 in_flight_pings: HashMap::new(),
107 in_flight_find_nodes: HashMap::new(),
108 check_timestamps: true,
109 adding_nodes: Vec::new(),
110 ip_filter,
111 disc_option: DiscoveryOption {
112 general: true,
113 archive: false,
114 },
115 ping_throttling: TimeWindowBucket::new(
116 config.throttling_interval,
117 config.throttling_limit_ping,
118 ),
119 find_nodes_throttling: TimeWindowBucket::new(
120 config.throttling_interval,
121 config.throttling_limit_find_nodes,
122 ),
123 config,
124 }
125 }
126
127 fn is_allowed(&self, entry: &NodeEntry) -> bool {
128 entry.endpoint.is_allowed(&self.ip_filter) && entry.id != self.id
129 }
130
131 pub fn try_ping_nodes(
132 &mut self, uio: &UdpIoContext, nodes: Vec<NodeEntry>,
133 ) {
134 for node in nodes {
135 self.try_ping(uio, node);
136 }
137 }
138
139 fn try_ping(&mut self, uio: &UdpIoContext, node: NodeEntry) {
140 if !self.is_allowed(&node) {
141 trace!("Node {:?} not allowed", node);
142 return;
143 }
144 if self.in_flight_pings.contains_key(&node.id)
145 || self.in_flight_find_nodes.contains_key(&node.id)
146 {
147 trace!("Node {:?} in flight requests", node);
148 return;
149 }
150 if self.adding_nodes.iter().any(|n| n.id == node.id) {
151 trace!("Node {:?} in adding nodes", node);
152 return;
153 }
154
155 if self.in_flight_pings.len() < self.config.max_nodes_ping {
156 self.ping(uio, &node).unwrap_or_else(|e| {
157 warn!("Error sending Ping packet: {:?}", e);
158 });
159 } else {
160 self.adding_nodes.push(node);
161 }
162 }
163
164 fn ping(
165 &mut self, uio: &UdpIoContext, node: &NodeEntry,
166 ) -> Result<(), Error> {
167 let mut rlp = RlpStream::new_list(4);
168 rlp.append(&DISCOVER_PROTOCOL_VERSION);
169 self.public_endpoint.to_rlp_list(&mut rlp);
170 node.endpoint.to_rlp_list(&mut rlp);
171 rlp.append(&self.config.expire_timestamp());
172 let hash = self.send_packet(
173 uio,
174 PACKET_PING,
175 &node.endpoint.udp_address(),
176 &rlp.drain(),
177 )?;
178
179 self.in_flight_pings.insert(
180 node.id.clone(),
181 PingRequest {
182 sent_at: Instant::now(),
183 node: node.clone(),
184 echo_hash: hash,
185 },
186 );
187
188 trace!("Sent Ping to {:?} ; node_id={:#x}", &node.endpoint, node.id);
189 Ok(())
190 }
191
192 fn send_packet(
193 &mut self, uio: &UdpIoContext, packet_id: u8, address: &SocketAddr,
194 payload: &[u8],
195 ) -> Result<H256, Error> {
196 let packet = assemble_packet(packet_id, payload, &self.secret)?;
197 let hash = H256::from_slice(&packet[1..=32]);
198 self.send_to(uio, packet, address.clone());
199 Ok(hash)
200 }
201
202 fn send_to(
203 &mut self, uio: &UdpIoContext, payload: Bytes, address: SocketAddr,
204 ) {
205 uio.send(payload, address);
206 }
207
208 pub fn on_packet(
209 &mut self, uio: &UdpIoContext, packet: &[u8], from: SocketAddr,
210 ) -> Result<(), Error> {
211 if packet.len() < 32 + 65 + 4 + 1 {
213 return Err(Error::BadProtocol.into());
214 }
215
216 let hash_signed = keccak(&packet[32..]);
217 if hash_signed[..] != packet[0..32] {
218 return Err(Error::BadProtocol.into());
219 }
220
221 let signed = &packet[(32 + 65)..];
222 let signature = H520::from_slice(&packet[32..(32 + 65)]);
223 let node_id = recover(&signature.into(), &keccak(signed))?;
224
225 let packet_id = signed[0];
226 let rlp = Rlp::new(&signed[1..]);
227 match packet_id {
228 PACKET_PING => {
229 self.on_ping(uio, &rlp, &node_id, &from, hash_signed.as_bytes())
230 }
231 PACKET_PONG => self.on_pong(uio, &rlp, &node_id, &from),
232 PACKET_FIND_NODE => self.on_find_node(uio, &rlp, &node_id, &from),
233 PACKET_NEIGHBOURS => self.on_neighbours(uio, &rlp, &node_id, &from),
234 _ => {
235 debug!("Unknown UDP packet: {}", packet_id);
236 Ok(())
237 }
238 }
239 }
240
241 fn check_timestamp(&self, timestamp: u64) -> Result<(), Error> {
244 let secs_since_epoch = SystemTime::now()
245 .duration_since(UNIX_EPOCH)
246 .unwrap_or_default()
247 .as_secs();
248 if self.check_timestamps && timestamp < secs_since_epoch {
249 debug!("Expired packet");
250 return Err(Error::Expired.into());
251 }
252 Ok(())
253 }
254
255 fn on_ping(
256 &mut self, uio: &UdpIoContext, rlp: &Rlp, node_id: &NodeId,
257 from: &SocketAddr, echo_hash: &[u8],
258 ) -> Result<(), Error> {
259 trace!("Got Ping from {:?}", &from);
260
261 if !self.ping_throttling.try_acquire(from.ip()) {
262 return Err(Error::Throttling(ThrottlingReason::PacketThrottled(
263 "PING",
264 ))
265 .into());
266 }
267
268 let ping_from = NodeEndpoint::from_rlp(&rlp.at(1)?)?;
269 let ping_to = NodeEndpoint::from_rlp(&rlp.at(2)?)?;
270 let timestamp: u64 = rlp.val_at(3)?;
271 self.check_timestamp(timestamp)?;
272
273 let mut response = RlpStream::new_list(3);
274 let pong_to = NodeEndpoint {
275 address: from.clone(),
276 udp_port: ping_from.udp_port,
277 };
278 ping_to.to_rlp_list(&mut response);
285 response.append(&echo_hash);
288 response.append(&self.config.expire_timestamp());
289 self.send_packet(uio, PACKET_PONG, from, &response.drain())?;
290
291 let entry = NodeEntry {
292 id: node_id.clone(),
293 endpoint: pong_to,
294 };
295 if !entry.endpoint.is_valid() {
297 debug!("Got bad address: {:?}", entry);
298 } else if !self.is_allowed(&entry) {
299 debug!("Address not allowed: {:?}", entry);
300 } else {
301 uio.node_db
302 .write()
303 .note_success(node_id, None, false );
304 }
305 Ok(())
306 }
307
308 fn on_pong(
309 &mut self, uio: &UdpIoContext, rlp: &Rlp, node_id: &NodeId,
310 from: &SocketAddr,
311 ) -> Result<(), Error> {
312 trace!("Got Pong from {:?} ; node_id={:#x}", &from, node_id);
313 let _pong_to = NodeEndpoint::from_rlp(&rlp.at(0)?)?;
314 let echo_hash: H256 = rlp.val_at(1)?;
315 let timestamp: u64 = rlp.val_at(2)?;
316 self.check_timestamp(timestamp)?;
317
318 let expected_node = match self.in_flight_pings.entry(*node_id) {
319 Entry::Occupied(entry) => {
320 let expected_node = {
321 let request = entry.get();
322 if request.echo_hash != echo_hash {
323 debug!("Got unexpected Pong from {:?} ; packet_hash={:#x} ; expected_hash={:#x}", &from, request.echo_hash, echo_hash);
324 None
325 } else {
326 Some(request.node.clone())
327 }
328 };
329
330 if expected_node.is_some() {
331 entry.remove();
332 }
333 expected_node
334 }
335 Entry::Vacant(_) => None,
336 };
337
338 if let Some(node) = expected_node {
339 uio.node_db.write().insert_with_conditional_promotion(node);
340 Ok(())
341 } else {
342 debug!("Got unexpected Pong from {:?} ; request not found", &from);
343 Ok(())
344 }
345 }
346
347 fn on_find_node(
348 &mut self, uio: &UdpIoContext, rlp: &Rlp, _node: &NodeId,
349 from: &SocketAddr,
350 ) -> Result<(), Error> {
351 trace!("Got FindNode from {:?}", &from);
352
353 if !self.find_nodes_throttling.try_acquire(from.ip()) {
354 return Err(Error::Throttling(ThrottlingReason::PacketThrottled(
355 "FIND_NODES",
356 ))
357 .into());
358 }
359
360 let msg: FindNodeMessage = rlp.as_val()?;
361 self.check_timestamp(msg.expire_timestamp)?;
362 let neighbors = msg.sample(
363 &*uio.node_db.read(),
364 &self.ip_filter,
365 self.config.discover_node_count,
366 )?;
367
368 trace!("Sample {} Neighbours for {:?}", neighbors.len(), &from);
369
370 let chunk_size = (MAX_DATAGRAM_SIZE - (1 + 109)) / 90;
371 let chunks = NeighborsChunkMessage::chunks(neighbors, chunk_size);
372
373 for chunk in &chunks {
374 self.send_packet(uio, PACKET_NEIGHBOURS, from, &chunk.rlp_bytes())?;
375 }
376
377 trace!("Sent {} Neighbours chunks to {:?}", chunks.len(), &from);
378 Ok(())
379 }
380
381 fn on_neighbours(
382 &mut self, uio: &UdpIoContext, rlp: &Rlp, node_id: &NodeId,
383 from: &SocketAddr,
384 ) -> Result<(), Error> {
385 let mut entry = match self.in_flight_find_nodes.entry(*node_id) {
386 Entry::Occupied(entry) => entry,
387 Entry::Vacant(_) => {
388 debug!("Got unexpected Neighbors from {:?} ; couldn't find node_id={:#x}", &from, node_id);
389 return Ok(());
390 }
391 };
392
393 let msg: NeighborsChunkMessage = rlp.as_val()?;
394 let request = entry.get_mut();
395
396 if !msg.update(request)? {
397 return Ok(());
398 }
399
400 if request.is_completed() {
401 entry.remove();
402 }
403
404 trace!("Got {} Neighbours from {:?}", msg.neighbors.len(), &from);
405
406 for node in msg.neighbors {
407 if !node.endpoint.is_valid() {
408 debug!("Bad address: {:?}", node.endpoint);
409 continue;
410 }
411 if node.id == self.id {
412 continue;
413 }
414 if !self.is_allowed(&node) {
415 debug!("Address not allowed: {:?}", node);
416 continue;
417 }
418 self.try_ping(uio, node);
419 }
420
421 Ok(())
422 }
423
424 fn start(&mut self) {
426 trace!("Starting discovery");
427 self.discovery_round = Some(0);
428 self.discovery_nodes.clear();
429 }
430
431 fn stop(&mut self) {
433 trace!("Completing discovery");
434 self.discovery_round = None;
435 self.discovery_nodes.clear();
436 }
437
438 fn check_expired(&mut self, uio: &UdpIoContext, time: Instant) {
439 let mut nodes_to_expire = Vec::new();
440 let ping_timeout = &self.config.ping_timeout;
441 self.in_flight_pings.retain(|node_id, ping_request| {
442 if time.duration_since(ping_request.sent_at) > *ping_timeout {
443 debug!(
444 "Removing expired PING request for node_id={:#x}",
445 node_id
446 );
447 nodes_to_expire.push(*node_id);
448 false
449 } else {
450 true
451 }
452 });
453 let find_node_timeout = &self.config.find_node_timeout;
454 self.in_flight_find_nodes.retain(|node_id, find_node_request| {
455 if time.duration_since(find_node_request.sent_at) > *find_node_timeout {
456 if !find_node_request.is_completed() {
457 debug!("Removing expired FIND NODE request for node_id={:#x}", node_id);
458 nodes_to_expire.push(*node_id);
459 }
460 false
461 } else {
462 true
463 }
464 });
465 for node_id in nodes_to_expire {
466 self.expire_node_request(uio, node_id);
467 }
468 }
469
470 fn expire_node_request(&mut self, uio: &UdpIoContext, node_id: NodeId) {
471 uio.node_db.write().note_failure(
472 &node_id, false, true, );
475 }
476
477 fn update_new_nodes(&mut self, uio: &UdpIoContext) {
478 while self.in_flight_pings.len() < self.config.max_nodes_ping {
479 match self.adding_nodes.pop() {
480 Some(next) => self.try_ping(uio, next),
481 None => break,
482 }
483 }
484 }
485
486 fn discover(&mut self, uio: &UdpIoContext) {
487 let discovery_round = match self.discovery_round {
488 Some(r) => r,
489 None => return,
490 };
491 if discovery_round == DISCOVERY_MAX_STEPS {
492 trace!("Discover stop due to beyond max round count.");
493 self.stop();
494 return;
495 }
496 trace!("Starting round {:?}", self.discovery_round);
497 let mut tried_count = 0;
498
499 if self.disc_option.general {
500 tried_count += self.discover_without_tag(uio);
501 }
502
503 if self.disc_option.archive {
504 let key: String = NODE_TAG_NODE_TYPE.into();
505 let value: String = NODE_TAG_ARCHIVE.into();
506 tried_count += self.discover_with_tag(uio, &key, &value);
507 }
508
509 if tried_count == 0 {
510 trace!("Discovery stop due to 0 tried_count");
511 self.stop();
512 return;
513 }
514 self.discovery_round = Some(discovery_round + 1);
515 }
516
517 fn send_find_node(
518 &mut self, uio: &UdpIoContext, node: &NodeEntry,
519 tag_key: Option<String>, tag_value: Option<String>,
520 ) -> Result<(), Error> {
521 let msg = FindNodeMessage::new(
522 tag_key,
523 tag_value,
524 self.config.expire_timestamp(),
525 );
526
527 self.send_packet(
528 uio,
529 PACKET_FIND_NODE,
530 &node.endpoint.udp_address(),
531 &msg.rlp_bytes(),
532 )?;
533
534 self.in_flight_find_nodes
535 .insert(node.id.clone(), FindNodeRequest::default());
536
537 trace!("Sent FindNode to {:?}", node);
538 Ok(())
539 }
540
541 pub fn round(&mut self, uio: &UdpIoContext) {
542 self.check_expired(uio, Instant::now());
543 self.update_new_nodes(uio);
544
545 if self.discovery_round.is_some() {
546 self.discover(uio);
547 } else if self.in_flight_pings.is_empty() && !self.discovery_initiated {
548 self.discovery_initiated = true;
551 self.refresh();
552 }
553 }
554
555 pub fn refresh(&mut self) {
556 if self.discovery_round.is_none() {
557 self.start();
558 }
559 }
560
561 fn discover_without_tag(&mut self, uio: &UdpIoContext) -> usize {
562 let sampled: Vec<NodeEntry> = uio
563 .node_db
564 .read()
565 .sample_trusted_nodes(
566 self.config.discover_node_count,
567 &self.ip_filter,
568 )
569 .into_iter()
570 .filter(|n| !self.discovery_nodes.contains(&n.id))
571 .collect();
572
573 self.discover_with_nodes(uio, sampled, None, None)
574 }
575
576 fn discover_with_nodes(
577 &mut self, uio: &UdpIoContext, nodes: Vec<NodeEntry>,
578 tag_key: Option<String>, tag_value: Option<String>,
579 ) -> usize {
580 let mut sent = 0;
581
582 for node in nodes {
583 match self.send_find_node(
584 uio,
585 &node,
586 tag_key.clone(),
587 tag_value.clone(),
588 ) {
589 Ok(_) => {
590 self.discovery_nodes.insert(node.id);
591 sent += 1;
592 }
593 Err(e) => {
594 warn!(
595 "Error sending node discovery packet for {:?}: {:?}",
596 node.endpoint, e
597 );
598 }
599 }
600 }
601
602 sent
603 }
604
605 fn discover_with_tag(
606 &mut self, uio: &UdpIoContext, key: &String, value: &String,
607 ) -> usize {
608 let tagged_nodes = uio.node_db.read().sample_trusted_node_ids_with_tag(
609 self.config.discover_node_count / 2,
610 key,
611 value,
612 );
613
614 let count = self.config.discover_node_count - tagged_nodes.len() as u32;
615 let random_nodes = uio
616 .node_db
617 .read()
618 .sample_trusted_node_ids(count, &self.ip_filter);
619
620 let sampled: HashSet<NodeId> = tagged_nodes
621 .into_iter()
622 .chain(random_nodes)
623 .filter(|id| !self.discovery_nodes.contains(id))
624 .collect();
625
626 let sampled_nodes = uio
627 .node_db
628 .read()
629 .get_nodes(sampled, true );
630
631 self.discover_with_nodes(
632 uio,
633 sampled_nodes,
634 Some(key.clone()),
635 Some(value.clone()),
636 )
637 }
638}
639
640fn assemble_packet(
641 packet_id: u8, bytes: &[u8], secret: &Secret,
642) -> Result<Bytes, Error> {
643 let mut packet = Bytes::with_capacity(bytes.len() + 32 + 65 + 1 + 1);
644 packet.push(UDP_PROTOCOL_DISCOVERY);
645 packet.resize(1 + 32 + 65, 0); packet.push(packet_id);
647 packet.extend_from_slice(bytes);
648
649 let hash = keccak(&packet[(1 + 32 + 65)..]);
650 let signature = match sign(secret, &hash) {
651 Ok(s) => s,
652 Err(e) => {
653 warn!("Error signing UDP packet");
654 return Err(Error::from(e));
655 }
656 };
657 packet[(1 + 32)..(1 + 32 + 65)].copy_from_slice(&signature[..]);
658 let signed_hash = keccak(&packet[(1 + 32)..]);
659 packet[1..=32].copy_from_slice(signed_hash.as_bytes());
660 Ok(packet)
661}
662
/// Selects which kinds of discovery each discovery round performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DiscoveryOption {
    /// Send untagged FindNode requests to sampled trusted nodes.
    pub general: bool,
    /// Send FindNode requests tagged for archive nodes.
    pub archive: bool,
}
669
670#[derive(RlpEncodable, RlpDecodable)]
671struct FindNodeMessage {
672 pub tag_key: Option<String>,
673 pub tag_value: Option<String>,
674 pub expire_timestamp: u64,
675}
676
677impl FindNodeMessage {
678 fn new(
679 tag_key: Option<String>, tag_value: Option<String>,
680 expire_timestamp: u64,
681 ) -> Self {
682 FindNodeMessage {
683 tag_key,
684 tag_value,
685 expire_timestamp,
686 }
687 }
688
689 fn sample(
690 &self, node_db: &NodeDatabase, ip_filter: &IpFilter,
691 discover_node_count: u32,
692 ) -> Result<Vec<NodeEntry>, Error> {
693 let key = match self.tag_key {
694 Some(ref key) => key,
695 None => {
696 return Ok(node_db
697 .sample_trusted_nodes(discover_node_count, ip_filter))
698 }
699 };
700
701 let value = match self.tag_value {
702 Some(ref value) => value,
703 None => return Err(Error::BadProtocol.into()),
704 };
705
706 let ids = node_db.sample_trusted_node_ids_with_tag(
707 discover_node_count,
708 key,
709 value,
710 );
711
712 Ok(node_db.get_nodes(ids, true ))
713 }
714}
715
716#[derive(RlpEncodable, RlpDecodable)]
717struct NeighborsChunkMessage {
718 neighbors: Vec<NodeEntry>,
719 num_chunks: usize,
720 chunk_index: usize,
721}
722
723impl NeighborsChunkMessage {
724 fn chunks(
725 neighbors: Vec<NodeEntry>, chunk_size: usize,
726 ) -> Vec<NeighborsChunkMessage> {
727 let chunks = neighbors.chunks(chunk_size);
728 let num_chunks = chunks.len();
729 chunks
730 .enumerate()
731 .map(|(chunk_index, chunk)| NeighborsChunkMessage {
732 neighbors: chunk.to_vec(),
733 num_chunks,
734 chunk_index,
735 })
736 .collect()
737 }
738
739 fn validate(&self) -> Result<(), Error> {
740 if self.neighbors.is_empty() {
741 debug!("invalid NeighborsChunkMessage, neighbors is empty");
742 bail!(Error::BadProtocol);
743 }
744
745 if self.num_chunks == 0 {
746 debug!("invalid NeighborsChunkMessage, num_chunks is zero");
747 bail!(Error::BadProtocol);
748 }
749
750 if self.chunk_index >= self.num_chunks {
751 debug!(
752 "invalid NeighborsChunkMessage, chunk index is invalid, len = {}, index = {}",
753 self.num_chunks, self.chunk_index
754 );
755 bail!(Error::BadProtocol);
756 }
757
758 Ok(())
759 }
760
761 fn update(&self, request: &mut FindNodeRequest) -> Result<bool, Error> {
766 self.validate()?;
767
768 if request.num_chunks == 0 {
769 request.num_chunks = self.num_chunks;
770 } else if request.num_chunks != self.num_chunks {
771 debug!("invalid NeighborsChunkMessage, chunk number mismatch, requested = {}, responded = {}", request.num_chunks, self.num_chunks);
772 bail!(Error::BadProtocol);
773 }
774
775 if !request.received_chunks.insert(self.chunk_index) {
776 debug!("duplicated NeighborsChunkMessage");
777 return Ok(false);
778 }
779
780 Ok(true)
781 }
782}