1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// Copyright (c) The Diem Core Contributors
// SPDX-License-Identifier: Apache-2.0

// Copyright 2021 Conflux Foundation. All rights reserved.
// Conflux is free software and distributed under GNU General Public License.
// See http://www.gnu.org/licenses/

use std::sync::RwLock as StdRwLock;

pub use std::sync::{RwLockReadGuard, RwLockWriteGuard};

/// A simple wrapper around the lock() function of a std::sync::RwLock
/// The only difference is that you don't need to call unwrap() on it.
#[derive(Debug, Default)]
pub struct RwLock<T>(StdRwLock<T>);

impl<T> RwLock<T> {
    /// creates a read-write lock
    pub fn new(t: T) -> Self { Self(StdRwLock::new(t)) }

    /// lock the rwlock in read mode
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        self.0
            .read()
            .expect("diem cannot currently handle a poisoned lock")
    }

    /// lock the rwlock in write mode
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        self.0
            .write()
            .expect("diem cannot currently handle a poisoned lock")
    }

    /// return the owned type consuming the lock
    pub fn into_inner(self) -> T {
        self.0
            .into_inner()
            .expect("diem cannot currently handle a poisoned lock")
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use std::{sync::Arc, thread};

    /// Two threads each take the write lock once and store a different
    /// value; afterwards the main thread must observe one of those values.
    #[test]
    fn test_diem_rwlock() {
        let a = 7u8;
        let rwlock = Arc::new(RwLock::new(a));
        let rwlock2 = rwlock.clone();
        let rwlock3 = rwlock.clone();

        let thread1 = thread::spawn(move || {
            let mut b = rwlock2.write();
            *b = 8;
        });
        let thread2 = thread::spawn(move || {
            let mut b = rwlock3.write();
            *b = 9;
        });

        // Surface a panic in either writer instead of silently discarding
        // the join result.
        thread1.join().expect("writer thread 1 panicked");
        thread2.join().expect("writer thread 2 panicked");

        // The writers are unordered, so either write may land last, but the
        // initial value must have been overwritten.
        let value = *rwlock.read();
        assert!(
            matches!(value, 8 | 9),
            "unexpected value after concurrent writes: {}",
            value
        );
    }
}