1use anyhow::{ensure, format_err, Error, Result};
8use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
9use std::{convert::TryFrom, fmt, str::FromStr};
10
/// Chains that have a reserved, human-readable name.
///
/// Each variant's explicit discriminant is the numeric chain ID it stands for
/// (returned by `NamedChain::id`).
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum NamedChain {
    MAINNET = 1,
    TESTNET = 2,
    DEVNET = 3,
    TESTING = 4,
    PREMAINNET = 5,
}
31
32impl NamedChain {
33 fn str_to_chain_id(s: &str) -> Result<ChainId> {
34 let reserved_chain = match s {
37 "MAINNET" => NamedChain::MAINNET,
38 "TESTNET" => NamedChain::TESTNET,
39 "DEVNET" => NamedChain::DEVNET,
40 "TESTING" => NamedChain::TESTING,
41 "PREMAINNET" => NamedChain::PREMAINNET,
42 _ => {
43 return Err(format_err!("Not a reserved chain: {:?}", s));
44 }
45 };
46 Ok(ChainId::new(reserved_chain.id()))
47 }
48
49 pub fn id(&self) -> u64 { *self as u64 }
50
51 pub fn from_chain_id(chain_id: &ChainId) -> Result<NamedChain, String> {
52 match chain_id.id() {
53 1 => Ok(NamedChain::MAINNET),
54 2 => Ok(NamedChain::TESTNET),
55 3 => Ok(NamedChain::DEVNET),
56 4 => Ok(NamedChain::TESTING),
57 5 => Ok(NamedChain::PREMAINNET),
58 _ => Err(String::from("Not a named chain")),
59 }
60 }
61}
62
63#[derive(Clone, Copy, Deserialize, Eq, Hash, PartialEq, Serialize)]
64pub struct ChainId(u64);
65
66pub fn deserialize_config_chain_id<'de, D>(
67 deserializer: D,
68) -> std::result::Result<ChainId, D::Error>
69where D: Deserializer<'de> {
70 struct ChainIdVisitor;
71
72 impl<'de> Visitor<'de> for ChainIdVisitor {
73 type Value = ChainId;
74
75 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.write_str("ChainId as string or u8")
77 }
78
79 fn visit_str<E>(
80 self, value: &str,
81 ) -> std::result::Result<Self::Value, E>
82 where E: serde::de::Error {
83 ChainId::from_str(value).map_err(serde::de::Error::custom)
84 }
85
86 fn visit_u64<E>(
87 self, value: u64,
88 ) -> std::result::Result<Self::Value, E>
89 where E: serde::de::Error {
90 Ok(ChainId::new(
91 u64::try_from(value).map_err(serde::de::Error::custom)?,
92 ))
93 }
94 }
95
96 deserializer.deserialize_any(ChainIdVisitor)
97}
98
99impl fmt::Debug for ChainId {
100 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
101 write!(f, "{}", self)
102 }
103}
104
105impl fmt::Display for ChainId {
106 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
107 write!(
108 f,
109 "{}",
110 NamedChain::from_chain_id(&self)
111 .map_or_else(|_| self.0.to_string(), |chain| chain.to_string())
112 )
113 }
114}
115
116impl fmt::Display for NamedChain {
117 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118 write!(
119 f,
120 "{}",
121 match self {
122 NamedChain::DEVNET => "DEVNET",
123 NamedChain::TESTNET => "TESTNET",
124 NamedChain::MAINNET => "MAINNET",
125 NamedChain::TESTING => "TESTING",
126 NamedChain::PREMAINNET => "PREMAINNET",
127 }
128 )
129 }
130}
131
132impl Default for ChainId {
133 fn default() -> Self { Self::test() }
134}
135
136impl FromStr for ChainId {
137 type Err = Error;
138
139 fn from_str(s: &str) -> Result<Self> {
140 ensure!(!s.is_empty(), "Cannot create chain ID from empty string");
141 NamedChain::str_to_chain_id(s).or_else(|_err| {
142 let value = s.parse::<u64>()?;
143 ensure!(value > 0, "cannot have chain ID with 0");
144 Ok(ChainId::new(value))
145 })
146 }
147}
148
149impl ChainId {
150 pub fn new(id: u64) -> Self {
151 assert!(id > 0, "cannot have chain ID with 0");
152 Self(id)
153 }
154
155 pub fn id(&self) -> u64 { self.0 }
156
157 pub fn test() -> Self { ChainId::new(NamedChain::TESTING.id()) }
158}
159
#[cfg(test)]
mod test {
    use super::*;

    /// Parsing rejects empty, zero and out-of-range input, and accepts both
    /// reserved names and the full positive u64 range.
    #[test]
    fn test_chain_id_from_str() {
        assert!(ChainId::from_str("").is_err());
        assert!(ChainId::from_str("0").is_err());
        assert!(ChainId::from_str("-1").is_err());
        assert!(ChainId::from_str("18446744073709551616").is_err());
        assert!(ChainId::from_str("not-a-chain").is_err());
        assert_eq!(ChainId::from_str("TESTING").unwrap(), ChainId::test());
        assert_eq!(ChainId::from_str("255").unwrap(), ChainId::new(255));
        assert_eq!(
            ChainId::from_str("18446744073709551615").unwrap(),
            ChainId::new(u64::MAX)
        );
    }

    /// Every reserved name parses and then displays back to the same name.
    #[test]
    fn test_named_chain_round_trip() {
        for name in ["MAINNET", "TESTNET", "DEVNET", "TESTING", "PREMAINNET"].iter() {
            let chain_id = ChainId::from_str(name).unwrap();
            assert_eq!(chain_id.to_string(), *name);
            assert!(NamedChain::from_chain_id(&chain_id).is_ok());
        }
        assert!(matches!(
            NamedChain::from_chain_id(&ChainId::new(1)),
            Ok(NamedChain::MAINNET)
        ));
    }

    /// IDs without a reserved name display (and debug-print) as plain numbers.
    #[test]
    fn test_unnamed_chain_display() {
        let chain_id = ChainId::new(42);
        assert_eq!(chain_id.to_string(), "42");
        assert_eq!(format!("{:?}", chain_id), "42");
        assert!(NamedChain::from_chain_id(&chain_id).is_err());
    }

    /// `ChainId::new` enforces the non-zero invariant by panicking.
    #[test]
    #[should_panic(expected = "cannot have chain ID with 0")]
    fn test_new_rejects_zero() {
        let _ = ChainId::new(0);
    }
}