diem_types/
chain_id.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4// Copyright 2021 Conflux Foundation. All rights reserved.
5// Conflux is free software and distributed under GNU General Public License.
6// See http://www.gnu.org/licenses/
7use anyhow::{ensure, format_err, Error, Result};
8use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
9use std::{convert::TryFrom, fmt, str::FromStr};
10
/// A registry of named chain IDs.
///
/// Its main purpose is to improve the human readability of reserved chain IDs
/// in config files and the CLI. When signing transactions for such chains, the
/// numerical chain ID should still be used (e.g. MAINNET has numeric chain ID
/// 1, TESTNET has chain ID 2, etc).
///
/// Chain ID 0 is intentionally left unassigned: users might accidentally
/// initialize the chain ID field to 0, so 0 is reserved to catch that mistake.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum NamedChain {
    /// The Diem mainnet production chain, reserved as chain ID 1.
    MAINNET = 1,
    // Even though these chain IDs do not correspond to MAINNET, changing them
    // should be avoided since they can break test environments for
    // various organisations.
    TESTNET = 2,
    DEVNET = 3,
    TESTING = 4,
    PREMAINNET = 5,
}
31
32impl NamedChain {
33    fn str_to_chain_id(s: &str) -> Result<ChainId> {
34        // TODO implement custom macro that derives FromStr impl for enum
35        // (similar to diem/common/num-variants)
36        let reserved_chain = match s {
37            "MAINNET" => NamedChain::MAINNET,
38            "TESTNET" => NamedChain::TESTNET,
39            "DEVNET" => NamedChain::DEVNET,
40            "TESTING" => NamedChain::TESTING,
41            "PREMAINNET" => NamedChain::PREMAINNET,
42            _ => {
43                return Err(format_err!("Not a reserved chain: {:?}", s));
44            }
45        };
46        Ok(ChainId::new(reserved_chain.id()))
47    }
48
49    pub fn id(&self) -> u64 { *self as u64 }
50
51    pub fn from_chain_id(chain_id: &ChainId) -> Result<NamedChain, String> {
52        match chain_id.id() {
53            1 => Ok(NamedChain::MAINNET),
54            2 => Ok(NamedChain::TESTNET),
55            3 => Ok(NamedChain::DEVNET),
56            4 => Ok(NamedChain::TESTING),
57            5 => Ok(NamedChain::PREMAINNET),
58            _ => Err(String::from("Not a named chain")),
59        }
60    }
61}
62
63#[derive(Clone, Copy, Deserialize, Eq, Hash, PartialEq, Serialize)]
64pub struct ChainId(u64);
65
66pub fn deserialize_config_chain_id<'de, D>(
67    deserializer: D,
68) -> std::result::Result<ChainId, D::Error>
69where D: Deserializer<'de> {
70    struct ChainIdVisitor;
71
72    impl<'de> Visitor<'de> for ChainIdVisitor {
73        type Value = ChainId;
74
75        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76            f.write_str("ChainId as string or u8")
77        }
78
79        fn visit_str<E>(
80            self, value: &str,
81        ) -> std::result::Result<Self::Value, E>
82        where E: serde::de::Error {
83            ChainId::from_str(value).map_err(serde::de::Error::custom)
84        }
85
86        fn visit_u64<E>(
87            self, value: u64,
88        ) -> std::result::Result<Self::Value, E>
89        where E: serde::de::Error {
90            Ok(ChainId::new(
91                u64::try_from(value).map_err(serde::de::Error::custom)?,
92            ))
93        }
94    }
95
96    deserializer.deserialize_any(ChainIdVisitor)
97}
98
99impl fmt::Debug for ChainId {
100    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
101        write!(f, "{}", self)
102    }
103}
104
105impl fmt::Display for ChainId {
106    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107        write!(
108            f,
109            "{}",
110            NamedChain::from_chain_id(&self)
111                .map_or_else(|_| self.0.to_string(), |chain| chain.to_string())
112        )
113    }
114}
115
116impl fmt::Display for NamedChain {
117    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118        write!(
119            f,
120            "{}",
121            match self {
122                NamedChain::DEVNET => "DEVNET",
123                NamedChain::TESTNET => "TESTNET",
124                NamedChain::MAINNET => "MAINNET",
125                NamedChain::TESTING => "TESTING",
126                NamedChain::PREMAINNET => "PREMAINNET",
127            }
128        )
129    }
130}
131
132impl Default for ChainId {
133    fn default() -> Self { Self::test() }
134}
135
136impl FromStr for ChainId {
137    type Err = Error;
138
139    fn from_str(s: &str) -> Result<Self> {
140        ensure!(!s.is_empty(), "Cannot create chain ID from empty string");
141        NamedChain::str_to_chain_id(s).or_else(|_err| {
142            let value = s.parse::<u64>()?;
143            ensure!(value > 0, "cannot have chain ID with 0");
144            Ok(ChainId::new(value))
145        })
146    }
147}
148
149impl ChainId {
150    pub fn new(id: u64) -> Self {
151        assert!(id > 0, "cannot have chain ID with 0");
152        Self(id)
153    }
154
155    pub fn id(&self) -> u64 { self.0 }
156
157    pub fn test() -> Self { ChainId::new(NamedChain::TESTING.id()) }
158}
159
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_chain_id_from_str() {
        assert!(ChainId::from_str("").is_err());
        assert!(ChainId::from_str("0").is_err());
        // 2^64 overflows.
        assert!(ChainId::from_str("18446744073709551616").is_err());
        // Neither a reserved chain name nor a number.
        assert!(ChainId::from_str("NOT_A_CHAIN").is_err());
        assert_eq!(ChainId::from_str("TESTING").unwrap(), ChainId::test());
        assert_eq!(ChainId::from_str("MAINNET").unwrap(), ChainId::new(1));
        assert_eq!(ChainId::from_str("255").unwrap(), ChainId::new(255));
    }

    #[test]
    fn test_named_chain_round_trip() {
        for name in &["MAINNET", "TESTNET", "DEVNET", "TESTING", "PREMAINNET"] {
            let chain_id = ChainId::from_str(name).unwrap();
            // Reserved IDs display as their names.
            assert_eq!(chain_id.to_string(), *name);
            let named = NamedChain::from_chain_id(&chain_id).unwrap();
            assert_eq!(named.id(), chain_id.id());
        }
        // Unreserved IDs display as plain numbers and have no name.
        assert_eq!(ChainId::new(255).to_string(), "255");
        assert!(NamedChain::from_chain_id(&ChainId::new(255)).is_err());
    }
}