From 2fcb7646c7d33e23037244657d457235aaccc822 Mon Sep 17 00:00:00 2001 From: Dennis Schwerdel Date: Thu, 1 Dec 2022 18:42:57 +0100 Subject: [PATCH] Coms --- src/engine/common.rs | 21 ++--- src/engine/coms.rs | 152 ++++++++++++++++++++++++++++++++++++ src/engine/device_thread.rs | 108 ++----------------------- src/engine/mod.rs | 1 + src/engine/shared.rs | 18 +++++ src/engine/socket_thread.rs | 151 +++++++++++------------------------ 6 files changed, 230 insertions(+), 221 deletions(-) create mode 100644 src/engine/coms.rs diff --git a/src/engine/common.rs b/src/engine/common.rs index a6ae106..81c0cae 100644 --- a/src/engine/common.rs +++ b/src/engine/common.rs @@ -12,8 +12,8 @@ use crate::{ device::Device, engine::{ device_thread::DeviceThread, - shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, socket_thread::{ReconnectEntry, SocketThread}, + coms::Coms }, error::Error, messages::AddrList, @@ -50,26 +50,17 @@ impl GenericCloud, stats_file: Option, ) -> Result { - let table = SharedTable::::new(config); - let traffic = SharedTraffic::new(); - let peer_crypto = SharedPeerCrypto::new(); let running = Arc::new(AtomicBool::new(true)); + let coms = Coms::::new(config, socket.try_clone().map_err(|e| Error::SocketIo("Failed to clone socket", e))?); let device_thread = DeviceThread::::new( - config.clone(), device.duplicate()?, - socket.try_clone().map_err(|e| Error::SocketIo("Failed to clone socket", e))?, - traffic.clone(), - peer_crypto.clone(), - table.clone(), + coms.try_clone()?, running.clone(), ); let mut socket_thread = SocketThread::::new( config.clone(), + coms, device, - socket, - traffic, - peer_crypto, - table, port_forwarding, stats_file, running.clone(), @@ -120,7 +111,7 @@ use std::net::SocketAddr; #[cfg(test)] impl GenericCloud { pub fn socket(&mut self) -> &mut MockSocket { - &mut self.socket_thread.socket + &mut self.socket_thread.coms.socket } pub fn device(&mut self) -> &mut MockDevice { @@ -153,6 +144,6 @@ impl GenericCloud { } pub fn get_num(&self) -> usize { - self.socket_thread.socket.address().unwrap().port() as usize + self.socket_thread.coms.socket.address().unwrap().port() as usize } } diff --git a/src/engine/coms.rs b/src/engine/coms.rs new file mode 100644 index 0000000..2ca7506 --- /dev/null +++ b/src/engine/coms.rs @@ -0,0 +1,152 @@ +use std::{net::SocketAddr, marker::PhantomData, io}; + +use crate::{ + config::Config, + error::Error, + messages::MESSAGE_TYPE_DATA, + net::{Socket, mapped_addr}, + util::{MsgBuffer, TimeSource}, payload::Protocol, +}; + +use super::{ + common::{SPACE_BEFORE, PeerData}, + shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, +}; + +pub struct Coms { + _dummy_p: PhantomData

, + broadcast: bool, + broadcast_buffer: MsgBuffer, + peer_crypto: SharedPeerCrypto, + pub table: SharedTable, + pub traffic: SharedTraffic, + pub socket: S, +} + +impl Coms { + pub fn new(config: &Config, socket: S) -> Self { + Self { + _dummy_p: PhantomData, + broadcast: config.is_broadcasting(), + broadcast_buffer: MsgBuffer::new(SPACE_BEFORE), + traffic: SharedTraffic::new(), + peer_crypto: SharedPeerCrypto::new(), + table: SharedTable::::new(config), + socket, + } + } + + pub fn try_clone(&self) -> Result { + Ok(Self { + _dummy_p: PhantomData, + broadcast: self.broadcast, + broadcast_buffer: MsgBuffer::new(SPACE_BEFORE), + traffic: self.traffic.clone(), + peer_crypto: self.peer_crypto.clone(), + table: self.table.clone(), + socket: self.socket.try_clone().map_err(|e| Error::SocketIo("Failed to clone socket", e))?, + }) + } + + pub fn sync(&mut self) -> Result<(), Error> { + self.peer_crypto.load(); + self.table.sync(); + self.traffic.sync(); + Ok(()) + } + + pub fn get_address(&self) -> Result { + self.socket.address().map(mapped_addr) + } + + pub fn send_raw(&mut self, data: &[u8], addr: SocketAddr) -> Result<(), Error> { + match self.socket.send(data, addr) { + Ok(written) if written == data.len() => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)), + } + } + + pub fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { + self.socket.receive(buffer) + } + + #[inline] + pub fn send_to(&mut self, addr: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { + let size = data.len(); + debug!("Sending msg with {} bytes to {}", size, addr); + self.traffic.count_out_traffic(addr, size); + match self.socket.send(data.message(), addr) { + Ok(written) if written == size => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)), + } + } + + #[inline] + pub fn send_msg(&mut self, addr: SocketAddr, type_: u8, data: &mut MsgBuffer) -> Result<(), Error> { + debug!("Sending msg with {} bytes to {}", data.len(), addr); + data.prepend_byte(type_); + self.peer_crypto.encrypt_for(addr, data)?; + self.send_to(addr, data) + } + + #[inline] + pub fn broadcast_msg(&mut self, type_: u8, data: &mut MsgBuffer) -> Result<(), Error> { + let size = data.len(); + debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, size, self.peer_crypto.count()); + let traffic = &mut self.traffic; + let socket = &mut self.socket; + let peers = self.peer_crypto.get_snapshot(); + for (addr, crypto) in peers { + self.broadcast_buffer.set_start(data.get_start()); + self.broadcast_buffer.set_length(data.len()); + self.broadcast_buffer.message_mut().clone_from_slice(data.message()); + self.broadcast_buffer.prepend_byte(type_); + if let Some(crypto) = crypto { + crypto.encrypt(&mut self.broadcast_buffer); + } + traffic.count_out_traffic(*addr, self.broadcast_buffer.len()); + match socket.send(self.broadcast_buffer.message(), *addr) { + Ok(written) if written == self.broadcast_buffer.len() => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)), + }? + } + Ok(()) + } + + pub fn forward_packet(&mut self, data: &mut MsgBuffer) -> Result<(), Error> { + let (src, dst) = P::parse(data.message())?; + debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, data.len()); + self.traffic.count_out_payload(dst.clone(), src, data.len()); + match self.table.lookup(&dst) { + Some(addr) => { + // Peer found for destination + debug!("Found destination for {} => {}", dst, addr); + self.send_msg(addr, MESSAGE_TYPE_DATA, data)?; + } + //TODO: VIA: find relay peer and relay message + None => { + if self.broadcast { + debug!("No destination for {} found, broadcasting", dst); + self.broadcast_msg(MESSAGE_TYPE_DATA, data)?; + } else { + debug!("No destination for {} found, dropping", dst); + self.traffic.count_dropped_payload(data.len()); + } + } + } + Ok(()) + } + + pub fn add_peer(&mut self, addr: SocketAddr, peer: &PeerData) { + self.peer_crypto.add(addr, peer.crypto.get_core()); + } + + pub fn remove_peer(&mut self, addr: &SocketAddr) { + self.table.remove_claims(*addr); + self.peer_crypto.remove(addr); + } + +} diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index fe22586..bb85fd3 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -1,139 +1,43 @@ -use super::{ - common::SPACE_BEFORE, - shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, -}; +use super::{common::SPACE_BEFORE, coms::Coms}; use crate::{ - config::Config, device::Device, error::Error, - messages::MESSAGE_TYPE_DATA, net::Socket, util::{MsgBuffer, Time, TimeSource}, Protocol, }; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::{marker::PhantomData, net::SocketAddr}; pub struct DeviceThread { - // Read-only fields - _dummy_ts: PhantomData, - _dummy_p: PhantomData

, - broadcast: bool, // Device-only fields - socket: S, + coms: Coms, pub device: D, next_housekeep: Time, buffer: MsgBuffer, - broadcast_buffer: MsgBuffer, // Shared fields - traffic: SharedTraffic, - peer_crypto: SharedPeerCrypto, - table: SharedTable, running: Arc, } impl DeviceThread { - pub fn new( - config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, - table: SharedTable, running: Arc, - ) -> Self { + pub fn new(device: D, coms: Coms, running: Arc) -> Self { Self { - _dummy_ts: PhantomData, - _dummy_p: PhantomData, - broadcast: config.is_broadcasting(), - socket, device, next_housekeep: TS::now(), - traffic, - peer_crypto, - table, buffer: MsgBuffer::new(SPACE_BEFORE), - broadcast_buffer: MsgBuffer::new(SPACE_BEFORE), running, + coms, } } - #[inline] - fn send_to(&mut self, addr: SocketAddr) -> Result<(), Error> { - let size = self.buffer.len(); - debug!("Sending msg with {} bytes to {}", size, addr); - self.traffic.count_out_traffic(addr, size); - match self.socket.send(self.buffer.message(), addr) { - Ok(written) if written == size => Ok(()), - Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)), - } - } - - #[inline] - fn send_msg(&mut self, addr: SocketAddr, type_: u8) -> Result<(), Error> { - debug!("Sending msg with {} bytes to {}", self.buffer.len(), addr); - self.buffer.prepend_byte(type_); - self.peer_crypto.encrypt_for(addr, &mut self.buffer)?; - self.send_to(addr) - } - - #[inline] - fn broadcast_msg(&mut self, type_: u8) -> Result<(), Error> { - let size = self.buffer.len(); - debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, size, self.peer_crypto.count()); - let traffic = &mut self.traffic; - let socket = &mut self.socket; - let peers = self.peer_crypto.get_snapshot(); - for (addr, crypto) in peers { - self.broadcast_buffer.set_start(self.buffer.get_start()); - self.broadcast_buffer.set_length(self.buffer.len()); - self.broadcast_buffer.message_mut().clone_from_slice(self.buffer.message()); - self.broadcast_buffer.prepend_byte(type_); - if let Some(crypto) = crypto { - crypto.encrypt(&mut self.broadcast_buffer); - } - traffic.count_out_traffic(*addr, self.broadcast_buffer.len()); - match socket.send(self.broadcast_buffer.message(), *addr) { - Ok(written) if written == self.broadcast_buffer.len() => Ok(()), - Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)), - }? - } - Ok(()) - } - - fn forward_packet(&mut self) -> Result<(), Error> { - let (src, dst) = P::parse(self.buffer.message())?; - debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, self.buffer.len()); - self.traffic.count_out_payload(dst.clone(), src, self.buffer.len()); - match self.table.lookup(&dst) { - Some(addr) => { - // Peer found for destination - debug!("Found destination for {} => {}", dst, addr); - self.send_msg(addr, MESSAGE_TYPE_DATA)?; - } - //TODO: VIA: find relay peer and relay message - None => { - if self.broadcast { - debug!("No destination for {} found, broadcasting", dst); - self.broadcast_msg(MESSAGE_TYPE_DATA)?; - } else { - debug!("No destination for {} found, dropping", dst); - self.traffic.count_dropped_payload(self.buffer.len()); - } - } - } - Ok(()) - } - pub fn housekeep(&mut self) -> Result<(), Error> { - self.peer_crypto.load(); - self.table.sync(); - self.traffic.sync(); - Ok(()) + self.coms.sync() } pub fn iteration(&mut self) -> bool { if self.device.read(&mut self.buffer).is_ok() { //try_fail!(result, "Failed to read from device: {}"); - if let Err(e) = self.forward_packet() { + if let Err(e) = self.coms.forward_packet(&mut self.buffer) { error!("{}", e); } } diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 9251f96..64c7d57 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -5,4 +5,5 @@ mod device_thread; mod shared; mod socket_thread; +pub mod coms; pub mod common; \ No newline at end of file diff --git a/src/engine/shared.rs b/src/engine/shared.rs index 76f0c99..1b05d0c 100644 --- a/src/engine/shared.rs +++ b/src/engine/shared.rs @@ -49,6 +49,18 @@ impl SharedPeerCrypto { Ok(()) } + pub fn add(&mut self, addr: SocketAddr, crypto: Option>) { + self.cache.insert(addr, crypto.clone()); + let mut peers = self.peers.lock(); + peers.insert(addr, crypto); + } + + pub fn remove(&mut self, addr: &SocketAddr) { + self.cache.remove(addr); + let mut peers = self.peers.lock(); + peers.remove(addr); + } + pub fn store(&mut self, data: &HashMap) { self.cache.clear(); self.cache.extend(data.iter().map(|(k, v)| (*k, v.crypto.get_core()))); @@ -138,6 +150,12 @@ impl SharedTraffic { } } +impl Default for SharedTraffic { + fn default() -> Self { + Self::new() + } +} + #[derive(Clone)] pub struct SharedTable { table: Arc>>, diff --git a/src/engine/socket_thread.rs b/src/engine/socket_thread.rs index 07d13e1..90b280b 100644 --- a/src/engine/socket_thread.rs +++ b/src/engine/socket_thread.rs @@ -1,7 +1,4 @@ -use super::{ - common::SPACE_BEFORE, - shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, -}; +use super::{common::SPACE_BEFORE, coms::Coms}; use crate::{ beacon::BeaconSerializer, @@ -62,7 +59,7 @@ pub struct SocketThread { _dummy_ts: PhantomData, _dummy_p: PhantomData

, // Socket-only fields - pub socket: S, + pub coms: Coms, device: D, next_housekeep: Time, pub own_addresses: AddrList, @@ -78,11 +75,8 @@ pub struct SocketThread { statsd_server: Option, pub reconnect_peers: SmallVec<[ReconnectEntry; 3]>, buffer: MsgBuffer, - broadcast_buffer: MsgBuffer, // Shared fields - peer_crypto: SharedPeerCrypto, - traffic: SharedTraffic, - table: SharedTable, + //table: SharedTable, running: Arc, // Should not be here port_forwarding: Option, // TODO: 3rd thread @@ -91,9 +85,8 @@ pub struct SocketThread { impl SocketThread { #[allow(clippy::too_many_arguments)] pub fn new( - config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, - table: SharedTable, port_forwarding: Option, stats_file: Option, - running: Arc, + config: Config, coms: Coms, device: D, + port_forwarding: Option, stats_file: Option, running: Arc, ) -> Self { let mut claims = SmallVec::with_capacity(config.claims.len()); for s in &config.claims { @@ -122,10 +115,7 @@ impl SocketThread SocketThread Result<(), Error> { - let size = self.buffer.len(); - debug!("Sending msg with {} bytes to {}", size, addr); - self.traffic.count_out_traffic(addr, size); - match self.socket.send(self.buffer.message(), addr) { - Ok(written) if written == size => { - self.buffer.clear(); - Ok(()) - } - Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)), - } - } - - #[inline] - fn broadcast_msg(&mut self, type_: u8) -> Result<(), Error> { - debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, self.buffer.len(), self.peers.len()); - for (addr, peer) in &mut self.peers { - self.broadcast_buffer.set_start(self.buffer.get_start()); - self.broadcast_buffer.set_length(self.buffer.len()); - self.broadcast_buffer.message_mut().clone_from_slice(self.buffer.message()); - self.broadcast_buffer.prepend_byte(type_); - peer.crypto.encrypt_message(&mut self.broadcast_buffer); - self.traffic.count_out_traffic(*addr, self.broadcast_buffer.len()); - match self.socket.send(self.broadcast_buffer.message(), *addr) { - Ok(written) if written == self.broadcast_buffer.len() => Ok(()), - Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)), - }? - } - self.buffer.clear(); - Ok(()) - } - fn connect_sock(&mut self, addr: SocketAddr) -> Result<(), Error> { let addr = mapped_addr(addr); if self.peers.contains_key(&addr) @@ -198,7 +152,7 @@ impl SocketThread(&mut self, addr: Addr) -> Result<(), Error> { @@ -258,7 +212,7 @@ impl SocketThread SocketThread SocketThread SocketThread SocketThread SocketThread val, Err(err) => { - self.traffic.count_invalid_protocol(self.buffer.len()); + self.coms.traffic.count_invalid_protocol(self.buffer.len()); return Err(err); } }; @@ -371,11 +322,11 @@ impl SocketThread { self.buffer.clear(); - self.traffic.count_invalid_protocol(self.buffer.len()); + self.coms.traffic.count_invalid_protocol(self.buffer.len()); return Err(Error::Message("Unknown message type")); } }, - MessageResult::Reply => self.send_to(src)?, + MessageResult::Reply => self.coms.send_to(src, &mut self.buffer)?, MessageResult::None => { self.buffer.clear(); } @@ -387,7 +338,7 @@ impl SocketThread SocketThread SocketThread SocketThread { self.pending_inits.insert(src, init); - self.send_to(src) + self.coms.send_to(src, &mut self.buffer) } Err(err) => { - self.traffic.count_invalid_protocol(self.buffer.len()); + self.coms.traffic.count_invalid_protocol(self.buffer.len()); Err(err) } } @@ -440,10 +391,10 @@ impl SocketThread SocketThread SocketThread SocketThread SocketThread>() { self.buffer.clear(); self.peers.get_mut(&addr).unwrap().crypto.every_second(&mut self.buffer); if !self.buffer.is_empty() { - self.send_to(addr)? + self.coms.send_to(addr, &mut self.buffer)? } } for addr in del { @@ -523,7 +472,7 @@ impl SocketThread io::Result<()> { self.own_addresses.clear(); - let socket_addr = self.socket.address().map(mapped_addr)?; + let socket_addr = self.coms.get_address()?; // 1) Specified advertise addresses for addr in &self.config.advertise_addresses { self.own_addresses.push(parse_listen(addr, socket_addr.port())); @@ -601,9 +550,9 @@ impl SocketThread SocketThread Result<(), Error> { if let Some(ref endpoint) = self.statsd_server { - let peer_traffic = self.traffic.total_peer_traffic(); - let payload_traffic = self.traffic.total_payload_traffic(); - let dropped = &self.traffic.dropped(); + let peer_traffic = self.coms.traffic.total_peer_traffic(); + let payload_traffic = self.coms.traffic.total_payload_traffic(); + let dropped = &self.coms.traffic.dropped(); let prefix = self.config.statsd_prefix.as_ref().map(|s| s as &str).unwrap_or("vpncloud"); let msg = StatsdMsg::new() .with_ns(prefix, |msg| { msg.add("peer_count", self.peers.len(), "g"); - msg.add("table_cache_entries", self.table.cache_len(), "g"); - msg.add("table_claims", self.table.claim_len(), "g"); + msg.add("table_cache_entries", self.coms.table.cache_len(), "g"); + msg.add("table_claims", self.coms.table.claim_len(), "g"); msg.with_ns("traffic", |msg| { msg.with_ns("protocol", |msg| { msg.with_ns("inbound", |msg| { @@ -653,14 +602,9 @@ impl SocketThread Ok(()), - Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)), - }? + self.coms.send_raw(msg.as_bytes(), *addr)?; } else { error!("Failed to resolve statsd server {}", endpoint); } @@ -722,8 +666,7 @@ impl SocketThread bool { - if let Ok(src) = self.socket.receive(&mut self.buffer) - { + if let Ok(src) = self.coms.receive(&mut self.buffer) { match self.handle_message(src) { Err(e @ Error::CryptoInitFatal(_)) => { debug!("Fatal crypto init error from {}: {}", src, e);