diff --git a/Cargo.lock b/Cargo.lock index fda0d79..8fb0261 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -219,6 +219,15 @@ dependencies = [ "xmltree", ] +[[package]] +name = "instant" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "itoa" version = "0.4.6" @@ -252,6 +261,15 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a" +[[package]] +name = "lock_api" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.11" @@ -298,6 +316,31 @@ version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13bd41f508810a131401606d54ac32a467c97172d74ba7662562ebba5ad07fa0" +[[package]] +name = "parking_lot" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c6d9b8427445284a09c55be860a15855ab580a417ccad9da88f5a06787ced0" +dependencies = [ + "cfg-if 1.0.0", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -494,6 +537,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "semver" version = "0.9.0" @@ -854,6 +903,7 @@ dependencies = [ "igd", "libc", "log", + "parking_lot", "privdrop", "rand 0.8.0", "ring", diff --git a/Cargo.toml b/Cargo.toml index 9ff9d3a..863c504 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ privdrop = "0.5" byteorder = "1.3" thiserror = "1.0" smallvec = "1.5" +parking_lot = "*" [dev-dependencies] tempfile = "3" diff --git a/src/crypto/core.rs b/src/crypto/core.rs index 9afad84..34c0b17 100644 --- a/src/crypto/core.rs +++ b/src/crypto/core.rs @@ -42,10 +42,12 @@ use ring::{ aead::{self, LessSafeKey, UnboundKey}, rand::{SecureRandom, SystemRandom} }; +use std::cell::UnsafeCell; use std::{ io::{Cursor, Read, Write}, mem, + sync::atomic::{AtomicUsize, Ordering}, time::{Duration, Instant} }; @@ -126,35 +128,42 @@ impl CryptoKey { } } +// Why this is safe: +// Only 2 of the 4 keys are accessed. +// Only the other two keys will ever be replaced and then become current. +// Between two replacements is enough time so that all calls using those keys are long done. pub struct CryptoCore { rand: SystemRandom, - keys: [CryptoKey; 4], - current_key: usize, - nonce_half: bool + keys: [UnsafeCell; 4], + current_key: AtomicUsize, + nonce_half: bool, + algorithm: &'static aead::Algorithm } impl CryptoCore { pub fn new(key: LessSafeKey, nonce_half: bool) -> Self { let rand = SystemRandom::new(); + let algorithm = key.algorithm(); let dummy_key_data = random_data(key.algorithm().key_len()); let dummy_key1 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap()); let dummy_key2 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap()); let dummy_key3 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap()); Self { keys: [ - CryptoKey::new(&rand, key, nonce_half), - CryptoKey::new(&rand, dummy_key1, nonce_half), - CryptoKey::new(&rand, dummy_key2, nonce_half), - CryptoKey::new(&rand, dummy_key3, nonce_half) + UnsafeCell::new(CryptoKey::new(&rand, key, nonce_half)), + UnsafeCell::new(CryptoKey::new(&rand, dummy_key1, nonce_half)), + UnsafeCell::new(CryptoKey::new(&rand, dummy_key2, nonce_half)), + UnsafeCell::new(CryptoKey::new(&rand, dummy_key3, nonce_half)) ], - current_key: 0, + current_key: AtomicUsize::new(0), nonce_half, - rand + rand, + algorithm } } - pub fn encrypt(&mut self, buffer: &mut MsgBuffer) { + pub fn encrypt(&self, buffer: &mut MsgBuffer) { let data_start = buffer.get_start(); let data_length = buffer.len(); assert!(buffer.get_start() >= EXTRA_LEN); @@ -162,11 +171,12 @@ impl CryptoCore { buffer.set_length(data_length + EXTRA_LEN + TAG_LEN); let (extra, data_and_tag) = buffer.message_mut().split_at_mut(EXTRA_LEN); let (data, tag_space) = data_and_tag.split_at_mut(data_length); - let key = &mut self.keys[self.current_key]; + let current_key = self.current_key.load(Ordering::SeqCst); + let key = unsafe { self.keys[current_key].get().as_mut().unwrap() }; key.send_nonce.increment(); { let mut extra = Cursor::new(extra); - extra.write_u8(self.current_key as u8).unwrap(); + extra.write_u8(current_key as u8).unwrap(); extra.write_all(&key.send_nonce.as_bytes()[5..]).unwrap(); } let nonce = aead::Nonce::assume_unique_for_key(*key.send_nonce.as_bytes()); @@ -190,7 +200,7 @@ impl CryptoCore { Ok(()) } - pub fn decrypt(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { + pub fn decrypt(&self, buffer: &mut MsgBuffer) -> Result<(), Error> { assert!(buffer.len() >= EXTRA_LEN + TAG_LEN); let (extra, data_and_tag) = buffer.message_mut().split_at_mut(EXTRA_LEN); let key_id; @@ -202,34 +212,37 @@ impl CryptoCore { extra.read_exact(&mut nonce.0[5..]).map_err(|_| Error::Crypto("Input data too short"))?; nonce.set_msb(if self.nonce_half { 0x00 } else { 0x80 }); } - let key = &mut self.keys[key_id as usize]; + let key = unsafe { self.keys[key_id as usize].get().as_mut().unwrap() }; let result = Self::decrypt_with_key(key, nonce, data_and_tag); buffer.set_start(buffer.get_start() + EXTRA_LEN); buffer.set_length(buffer.len() - TAG_LEN); result } - pub fn rotate_key(&mut self, key: LessSafeKey, id: u64, use_for_sending: bool) { + pub fn rotate_key(&self, key: LessSafeKey, id: u64, use_for_sending: bool) { debug!("Rotated key {} (use for sending: {})", id, use_for_sending); let id = (id % 4) as usize; - self.keys[id] = CryptoKey::new(&self.rand, key, self.nonce_half); + let mut new_key = CryptoKey::new(&self.rand, key, self.nonce_half); + let stored_key = unsafe { self.keys[id].get().as_mut().unwrap() }; + mem::swap(&mut new_key, stored_key); if use_for_sending { - self.current_key = id + self.current_key.store(id, Ordering::SeqCst); } } pub fn algorithm(&self) -> &'static aead::Algorithm { - self.keys[self.current_key].key.algorithm() + self.algorithm } - pub fn every_second(&mut self) { + pub fn every_second(&self) { // Set min nonce on all keys - for k in &mut self.keys { - k.update_min_nonce(); + for k in &self.keys { + unsafe { k.get().as_mut().unwrap().update_min_nonce() }; } } } +unsafe impl Sync for CryptoCore {} pub fn create_dummy_pair(algo: &'static aead::Algorithm) -> (CryptoCore, CryptoCore) { let key_data = random_data(algo.key_len()); diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 75d6bda..dfd8a7e 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -2,7 +2,7 @@ mod core; mod init; mod rotate; -pub use self::core::{EXTRA_LEN, TAG_LEN}; +pub use self::core::{EXTRA_LEN, TAG_LEN, CryptoCore}; use self::{ core::{test_speed, CryptoCore}, init::{InitResult, InitState, CLOSING}, diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index b8ad1bb..70ecb59 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -1,19 +1,16 @@ -use crate::{ - engine::{addr_nice, Hash, PeerData}, - messages::MESSAGE_TYPE_DATA, - net::Socket, - table::ClaimTable, - traffic::TrafficStats, - Protocol +use super::{ + shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, + SPACE_BEFORE }; -use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc}; - -use super::{shared::SharedData, SPACE_BEFORE}; use crate::{ device::Device, error::Error, - util::{MsgBuffer, Time, TimeSource} + messages::MESSAGE_TYPE_DATA, + net::Socket, + util::{MsgBuffer, Time, TimeSource}, + Protocol }; +use std::{marker::PhantomData, net::SocketAddr}; pub struct DeviceThread { // Read-only fields @@ -25,10 +22,9 @@ pub struct DeviceThread { device: D, next_housekeep: Time, // Shared fields - shared: Arc, - peers: HashMap, - traffic: TrafficStats, - table: ClaimTable + traffic: SharedTraffic, + peer_crypto: SharedPeerCrypto, + table: SharedTable } impl DeviceThread { @@ -46,31 +42,29 @@ impl DeviceThread Result<(), Error> { debug!("Sending msg with {} bytes to {}", msg.len(), addr); - let peer = match self.peers.get_mut(&addr) { - Some(peer) => peer, - None => return Err(Error::Message("Sending to node that is not a peer")) - }; - peer.crypto.send_message(type_, msg)?; - self.send_to(addr, msg) + if self.peer_crypto.send_message(addr, type_, msg)? { + self.send_to(addr, msg) + } else { + Err(Error::Message("Sending to node that is not a peer")) + } } #[inline] fn broadcast_msg(&mut self, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { - debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peers.len()); + debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peer_crypto.count()); let mut msg_data = MsgBuffer::new(100); - for (addr, peer) in &mut self.peers { + self.peer_crypto.for_each(|addr, crypto| { msg_data.set_start(msg.get_start()); msg_data.set_length(msg.len()); msg_data.message_mut().clone_from_slice(msg.message()); - peer.crypto.send_message(type_, &mut msg_data)?; - self.traffic.count_out_traffic(*addr, msg_data.len()); - match self.socket.send(msg_data.message(), *addr) { + crypto.send_message(type_, &mut msg_data)?; + self.traffic.count_out_traffic(addr, msg_data.len()); + match self.socket.send(msg_data.message(), addr) { Ok(written) if written == msg_data.len() => Ok(()), Ok(_) => Err(Error::Socket("Sent out truncated packet")), Err(e) => Err(Error::SocketIo("IOError when sending", e)) - }? - } - Ok(()) + } + }) } fn forward_packet(&mut self, data: &mut MsgBuffer) -> Result<(), Error> { @@ -97,7 +91,9 @@ impl DeviceThread Result<(), Error> { - // TODO: sync + self.peer_crypto.sync(); + self.table.sync(); + self.traffic.sync(); unimplemented!(); } diff --git a/src/engine/shared.rs b/src/engine/shared.rs index b015690..eaadf33 100644 --- a/src/engine/shared.rs +++ b/src/engine/shared.rs @@ -1 +1,106 @@ -pub struct SharedData {} +use crate::error::Error; +use crate::{ + crypto::CryptoCore, + engine::{Hash, PeerData, TimeSource}, + messages::NodeInfo, + table::ClaimTable, + traffic::TrafficStats, + types::{Address, NodeId, RangeList}, + util::MsgBuffer +}; +use parking_lot::Mutex; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; + +pub struct SharedPeerCrypto { + peers: Arc>> +} + +impl SharedPeerCrypto { + pub fn sync(&mut self) { + // TODO sync if needed + } + + pub fn send_message(&mut self, peer: SocketAddr, type_: u8, data: &mut MsgBuffer) -> Result { + let mut peers = self.peers.lock(); + if let Some(peer) = peers.get_mut(&peer) { + peer.send_message(type_, data); + Ok(true) + } else { + Ok(false) + } + } + + pub fn for_each(&mut self, mut callback: impl FnMut(SocketAddr, &mut PeerCrypto) -> Result<(), Error>) -> Result<(), Error> { + let mut peers = self.peers.lock(); + for (k, v) in peers.iter_mut() { + callback(*k, v)? + } + Ok(()) + } + + pub fn count(&self) -> usize { + self.peers.lock().len() + } +} + + +pub struct SharedTraffic { + traffic: Arc> +} + +impl SharedTraffic { + pub fn sync(&mut self) { + // TODO sync if needed + } + + pub fn count_out_traffic(&self, peer: SocketAddr, bytes: usize) { + self.traffic.lock().count_out_traffic(peer, bytes); + } + + pub fn count_in_traffic(&self, peer: SocketAddr, bytes: usize) { + self.traffic.lock().count_in_traffic(peer, bytes); + } + + pub fn count_out_payload(&self, remote: Address, local: Address, bytes: usize) { + self.traffic.lock().count_out_payload(remote, local, bytes); + } + + pub fn count_in_payload(&self, remote: Address, local: Address, bytes: usize) { + self.traffic.lock().count_in_payload(remote, local, bytes); + } + + pub fn count_dropped_payload(&self, bytes: usize) { + self.traffic.lock().count_dropped_payload(bytes); + } + + pub fn count_invalid_protocol(&self, bytes: usize) { + self.traffic.lock().count_invalid_protocol(bytes); + } +} + + +pub struct SharedTable { + table: Arc>> +} + +impl SharedTable { + pub fn sync(&mut self) { + // TODO sync if needed + } + + pub fn lookup(&self, addr: Address) -> Option { + self.table.lock().lookup(addr) + } + + pub fn set_claims(&self, peer: SocketAddr, claims: RangeList) { + self.table.lock().set_claims(peer, claims) + } + + pub fn remove_claims(&self, peer: SocketAddr) { + self.table.lock().remove_claims(peer) + } + + pub fn cache(&self, addr: Address, peer: SocketAddr) { + self.table.lock().cache(addr, peer) + } +} diff --git a/src/engine/socket_thread.rs b/src/engine/socket_thread.rs index 9f99dc9..b439583 100644 --- a/src/engine/socket_thread.rs +++ b/src/engine/socket_thread.rs @@ -1,4 +1,8 @@ -use super::{shared::SharedData, SPACE_BEFORE}; +use super::{ + shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, + SPACE_BEFORE +}; + use crate::{ config::DEFAULT_PEER_TIMEOUT, crypto::{is_init_message, MessageResult, PeerCrypto}, @@ -6,13 +10,11 @@ use crate::{ error::Error, messages::{AddrList, NodeInfo, PeerInfo}, net::{mapped_addr, Socket}, - table::ClaimTable, - traffic::TrafficStats, types::{NodeId, RangeList}, util::{MsgBuffer, Time, TimeSource}, Config, Crypto, Device, Protocol }; -use rand::{random, seq::SliceRandom, thread_rng}; +use rand::{seq::SliceRandom}; use smallvec::{smallvec, SmallVec}; use std::{ collections::HashMap, @@ -20,7 +22,6 @@ use std::{ io::Cursor, marker::PhantomData, net::{SocketAddr, ToSocketAddrs}, - sync::Arc }; pub struct SocketThread { @@ -38,12 +39,11 @@ pub struct SocketThread { next_housekeep: Time, own_addresses: AddrList, pending_inits: HashMap, Hash>, - // Shared fields - shared: Arc, - traffic: TrafficStats, - peers: HashMap, crypto: Crypto, - table: ClaimTable + peers: HashMap, + // Shared fields + traffic: SharedTraffic, + table: SharedTable } impl SocketThread { @@ -273,6 +273,7 @@ impl SocketThread Result<(), Error> { + // self.shared.sync(); // TODO: sync unimplemented!(); }