diff --git a/src/config.rs b/src/config.rs index 5a9c983..6d9591a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -350,6 +350,34 @@ impl Config { } } + pub fn is_learning(&self) -> bool { + match self.mode { + Mode::Normal => { + match self.device_type { + Type::Tap => true, + Type::Tun => false + } + } + Mode::Router => false, + Mode::Switch => true, + Mode::Hub => false + } + } + + pub fn is_broadcasting(&self) -> bool { + match self.mode { + Mode::Normal => { + match self.device_type { + Type::Tap => true, + Type::Tun => false + } + } + Mode::Router => false, + Mode::Switch => true, + Mode::Hub => true + } + } + pub fn call_hook( &self, event: &'static str, envs: impl IntoIterator)>, detach: bool ) { diff --git a/src/device.rs b/src/device.rs index 42df878..c6a16d1 100644 --- a/src/device.rs +++ b/src/device.rs @@ -2,6 +2,7 @@ // Copyright (C) 2015-2021 Dennis Schwerdel // This software is licensed under GPL-3 or newer (see LICENSE.md) +use parking_lot::Mutex; use std::{ cmp, collections::VecDeque, @@ -12,7 +13,8 @@ use std::{ net::{Ipv4Addr, UdpSocket}, os::unix::io::{AsRawFd, RawFd}, str, - str::FromStr + str::FromStr, + sync::Arc }; use crate::{crypto, error::Error, util::MsgBuffer}; @@ -75,7 +77,7 @@ impl FromStr for Type { } } -pub trait Device: AsRawFd { +pub trait Device: AsRawFd + Clone { /// Returns the type of this device fn get_type(&self) -> Type; @@ -118,6 +120,16 @@ pub struct TunTapDevice { } +impl Clone for TunTapDevice { + fn clone(&self) -> Self { + Self { + fd: try_fail!(self.fd.try_clone(), "Failed to clone device: {}"), + ifname: self.ifname.clone(), + type_: self.type_ + } + } +} + impl TunTapDevice { /// Creates a new tun/tap device /// @@ -300,9 +312,10 @@ impl AsRawFd for TunTapDevice { } +#[derive(Clone)] pub struct MockDevice { - inbound: VecDeque>, - outbound: VecDeque> + inbound: Arc>>>, + outbound: Arc>>> } impl MockDevice { @@ -311,15 +324,15 @@ impl MockDevice { } pub fn put_inbound(&mut self, data: Vec) { - self.inbound.push_back(data) + self.inbound.lock().push_back(data) } pub fn pop_outbound(&mut self) -> Option> { - self.outbound.pop_front() + self.outbound.lock().pop_front() } pub fn has_inbound(&self) -> bool { - !self.inbound.is_empty() + !self.inbound.lock().is_empty() } } @@ -333,7 +346,7 @@ impl Device for MockDevice { } fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { - if let Some(data) = self.inbound.pop_front() { + if let Some(data) = self.inbound.lock().pop_front() { buffer.clear(); buffer.set_length(data.len()); buffer.message_mut().copy_from_slice(&data); @@ -344,7 +357,7 @@ impl Device for MockDevice { } fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { - self.outbound.push_back(buffer.message().into()); + self.outbound.lock().push_back(buffer.message().into()); Ok(()) } @@ -355,7 +368,10 @@ impl Device for MockDevice { impl Default for MockDevice { fn default() -> Self { - Self { outbound: VecDeque::with_capacity(10), inbound: VecDeque::with_capacity(10) } + Self { + outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), + inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))) + } } } diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index ffb05db..8cd7af8 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -8,7 +8,8 @@ use crate::{ messages::MESSAGE_TYPE_DATA, net::Socket, util::{MsgBuffer, Time, TimeSource}, - Protocol + Protocol, + config::Config }; use std::{marker::PhantomData, net::SocketAddr}; @@ -28,6 +29,20 @@ pub struct DeviceThread { } impl DeviceThread { + pub fn new(config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, table: SharedTable) -> Self { + Self { + _dummy_ts: PhantomData, + _dummy_p: PhantomData, + broadcast: config.is_broadcasting(), + socket, + device, + next_housekeep: TS::now(), + traffic, + peer_crypto, + table + } + } + #[inline] fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { debug!("Sending msg with {} bytes to {}", msg.len(), addr); @@ -104,15 +119,16 @@ impl DeviceThread { impl GenericCloud { #[allow(clippy::too_many_arguments)] - pub fn new(config: &Config, socket: S, device: D, port_forwarding: Option, stats_file: Option) -> Self { + pub fn new( + config: &Config, socket: S, device: D, port_forwarding: Option, stats_file: Option + ) -> Self { let (learning, broadcast) = match config.mode { Mode::Normal => { match config.device_type { @@ -288,8 +295,26 @@ impl GenericCloud::new(&self.config); + let traffic = SharedTraffic::new(); + let peer_crypto = SharedPeerCrypto::new(); + let device_thread = DeviceThread::::new( + self.config.clone(), + self.device.clone(), + self.socket.clone(), + traffic.clone(), + peer_crypto.clone(), + table.clone() + ); + + // TODO: create shared data structures + // TODO: create and spawn threads let ctrlc = CtrlC::new(); - let waiter = try_fail!(WaitImpl::new(self.socket.as_raw_fd(), self.device.as_raw_fd(), 1000), "Failed to setup poll: {}"); + // TODO: wait for ctrl-c + let waiter = try_fail!( + WaitImpl::new(self.socket.as_raw_fd(), self.device.as_raw_fd(), 1000), + "Failed to setup poll: {}" + ); let mut buffer = MsgBuffer::new(SPACE_BEFORE); let mut poll_error = false; self.config.call_hook("vpn_started", vec![("IFNAME", self.device.ifname())], true); diff --git a/src/engine/shared.rs b/src/engine/shared.rs index e07781a..594e9d0 100644 --- a/src/engine/shared.rs +++ b/src/engine/shared.rs @@ -6,7 +6,9 @@ use crate::{ table::ClaimTable, traffic::{TrafficStats, TrafficEntry}, types::{Address, NodeId, RangeList}, - util::MsgBuffer + util::MsgBuffer, + util::Duration, + config::Config }; use parking_lot::Mutex; use std::{ @@ -16,11 +18,16 @@ use std::{ sync::Arc }; +#[derive(Clone)] pub struct SharedPeerCrypto { peers: Arc>, Hash>>> } impl SharedPeerCrypto { + pub fn new() -> Self { + SharedPeerCrypto { peers: Arc::new(Mutex::new(HashMap::default())) } + } + pub fn sync(&mut self) { // TODO sync if needed } @@ -50,11 +57,16 @@ impl SharedPeerCrypto { } +#[derive(Clone)] pub struct SharedTraffic { traffic: Arc> } impl SharedTraffic { + pub fn new() -> Self { + Self { traffic: Arc::new(Mutex::new(Default::default())) } + } + pub fn sync(&mut self) { // TODO sync if needed } @@ -105,11 +117,17 @@ impl SharedTraffic { } +#[derive(Clone)] pub struct SharedTable { table: Arc>> } impl SharedTable { + pub fn new(config: &Config) -> Self { + let table = ClaimTable::new(config.switch_timeout as Duration, config.peer_timeout as Duration); + SharedTable { table: Arc::new(Mutex::new(table)) } + } + pub fn sync(&mut self) { // TODO sync if needed } diff --git a/src/engine/socket_thread.rs b/src/engine/socket_thread.rs index 1a3f9ea..9b65d9f 100644 --- a/src/engine/socket_thread.rs +++ b/src/engine/socket_thread.rs @@ -24,7 +24,7 @@ use std::{ fmt, fs::File, io, - io::{Write, Cursor, Seek, SeekFrom}, + io::{Cursor, Seek, SeekFrom, Write}, marker::PhantomData, net::{SocketAddr, ToSocketAddrs} }; @@ -615,28 +615,38 @@ impl SocketThread { - debug!("Fatal crypto init error from {}: {}", src, e); - info!("Closing pending connection to {} due to error in crypto init", addr_nice(src)); - self.pending_inits.remove(&src); + match self.socket.receive(&mut buffer) { + Err(err) => { + if err.kind() == io::ErrorKind::TimedOut || err.kind() == io::ErrorKind::WouldBlock { + // ok, this is a normal timeout + } else { + fail!("Failed to read from network socket: {}", err); + } } - Err(e @ Error::CryptoInit(_)) => { - debug!("Recoverable init error from {}: {}", src, e); - info!("Ignoring invalid init message from peer {}", addr_nice(src)); + Ok(src) => { + match self.handle_message(src, &mut buffer) { + Err(e @ Error::CryptoInitFatal(_)) => { + debug!("Fatal crypto init error from {}: {}", src, e); + info!("Closing pending connection to {} due to error in crypto init", addr_nice(src)); + self.pending_inits.remove(&src); + } + Err(e @ Error::CryptoInit(_)) => { + debug!("Recoverable init error from {}: {}", src, e); + info!("Ignoring invalid init message from peer {}", addr_nice(src)); + } + Err(e) => { + error!("{}", e); + } + Ok(_) => {} + } } - Err(e) => { - error!("{}", e); - } - Ok(_) => {} } let now = TS::now(); if self.next_housekeep < now { if let Err(e) = self.housekeep() { error!("{}", e) } - self.next_housekeep = TS::now() + 1 + self.next_housekeep = now + 1 } } } diff --git a/src/main.rs b/src/main.rs index 9242a67..0ffbfbb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,7 +36,7 @@ use structopt::StructOpt; use std::{ fs::{self, File, Permissions}, io::{self, Write}, - net::{Ipv4Addr, UdpSocket}, + net::{Ipv4Addr}, os::unix::fs::PermissionsExt, path::Path, process, @@ -50,7 +50,7 @@ use crate::{ config::{Args, Command, Config, DEFAULT_PORT}, crypto::Crypto, device::{Device, TunTapDevice, Type}, - net::Socket, + net::{Socket, NetSocket}, oldconfig::OldConfigFile, payload::Protocol, util::SystemTimeSource, @@ -328,7 +328,7 @@ fn main() { } return } - let socket = try_fail!(UdpSocket::listen(&config.listen), "Failed to open socket {}: {}", config.listen); + let socket = try_fail!(NetSocket::listen(&config.listen), "Failed to open socket {}: {}", config.listen); match config.device_type { Type::Tap => run::(config, socket), Type::Tun => run::(config, socket) diff --git a/src/net.rs b/src/net.rs index 7356158..69e19db 100644 --- a/src/net.rs +++ b/src/net.rs @@ -2,17 +2,21 @@ // Copyright (C) 2015-2021 Dennis Schwerdel // This software is licensed under GPL-3 or newer (see LICENSE.md) +use super::util::{MockTimeSource, MsgBuffer, Time, TimeSource}; +use crate::port_forwarding::PortForwarding; +use parking_lot::Mutex; use std::{ collections::{HashMap, VecDeque}, io::{self, ErrorKind}, - net::{IpAddr, SocketAddr, UdpSocket, Ipv6Addr}, + net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket}, os::unix::io::{AsRawFd, RawFd}, - sync::atomic::{AtomicBool, Ordering} + sync::{ + atomic::{AtomicBool, Ordering}, + Arc + }, + time::Duration }; -use super::util::{MockTimeSource, MsgBuffer, Time, TimeSource}; -use crate::port_forwarding::PortForwarding; - pub fn mapped_addr(addr: SocketAddr) -> SocketAddr { // HOT PATH match addr { @@ -27,7 +31,7 @@ pub fn get_ip() -> IpAddr { s.local_addr().unwrap().ip() } -pub trait Socket: AsRawFd + Sized { +pub trait Socket: AsRawFd + Sized + Clone { fn listen(addr: &str) -> Result; fn receive(&mut self, buffer: &mut MsgBuffer) -> Result; fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result; @@ -47,25 +51,42 @@ pub fn parse_listen(addr: &str) -> SocketAddr { } } -impl Socket for UdpSocket { +pub struct NetSocket(UdpSocket); + +impl Clone for NetSocket { + fn clone(&self) -> Self { + Self(try_fail!(self.0.try_clone(), "Failed to clone socket: {}")) + } +} + +impl AsRawFd for NetSocket { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl Socket for NetSocket { fn listen(addr: &str) -> Result { let addr = parse_listen(addr); - UdpSocket::bind(addr) + Ok(NetSocket(UdpSocket::bind(addr).and_then(|s| { + s.set_read_timeout(Some(Duration::from_secs(1)))?; + Ok(s) + })?)) } fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { buffer.clear(); - let (size, addr) = self.recv_from(buffer.buffer())?; + let (size, addr) = self.0.recv_from(buffer.buffer())?; buffer.set_length(size); Ok(addr) } fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result { - self.send_to(data, addr) + self.0.send_to(data, addr) } fn address(&self) -> Result { - let mut addr = self.local_addr()?; + let mut addr = self.0.local_addr()?; addr.set_ip(get_ip()); Ok(addr) } @@ -79,22 +100,24 @@ thread_local! { static MOCK_SOCKET_NAT: AtomicBool = AtomicBool::new(false); } + +#[derive(Clone)] pub struct MockSocket { nat: bool, - nat_peers: HashMap, + nat_peers: Arc>>, address: SocketAddr, - outbound: VecDeque<(SocketAddr, Vec)>, - inbound: VecDeque<(SocketAddr, Vec)> + outbound: Arc)>>>, + inbound: Arc)>>> } impl MockSocket { pub fn new(address: SocketAddr) -> Self { Self { nat: Self::get_nat(), - nat_peers: HashMap::new(), + nat_peers: Default::default(), address, - outbound: VecDeque::with_capacity(10), - inbound: VecDeque::with_capacity(10) + outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), + inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))) } } @@ -108,12 +131,12 @@ impl MockSocket { pub fn put_inbound(&mut self, from: SocketAddr, data: Vec) -> bool { if !self.nat { - self.inbound.push_back((from, data)); + self.inbound.lock().push_back((from, data)); return true } - if let Some(timeout) = self.nat_peers.get(&from) { + if let Some(timeout) = self.nat_peers.lock().get(&from) { if *timeout >= MockTimeSource::now() { - self.inbound.push_back((from, data)); + self.inbound.lock().push_back((from, data)); return true } } @@ -122,7 +145,7 @@ impl MockSocket { } pub fn pop_outbound(&mut self) -> Option<(SocketAddr, Vec)> { - self.outbound.pop_front() + self.outbound.lock().pop_front() } } @@ -138,7 +161,7 @@ impl Socket for MockSocket { } fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { - if let Some((addr, data)) = self.inbound.pop_front() { + if let Some((addr, data)) = self.inbound.lock().pop_front() { buffer.clear(); buffer.set_length(data.len()); buffer.message_mut().copy_from_slice(&data); @@ -149,9 +172,9 @@ impl Socket for MockSocket { } fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result { - self.outbound.push_back((addr, data.into())); + self.outbound.lock().push_back((addr, data.into())); if self.nat { - self.nat_peers.insert(addr, MockTimeSource::now() + 300); + self.nat_peers.lock().insert(addr, MockTimeSource::now() + 300); } Ok(data.len()) } @@ -178,4 +201,4 @@ mod bench { b.iter(|| sock.send_to(&data, &addr).unwrap()); b.bytes = 1400; } -} \ No newline at end of file +} diff --git a/src/wsproxy.rs b/src/wsproxy.rs index 0cc97d5..c784ae9 100644 --- a/src/wsproxy.rs +++ b/src/wsproxy.rs @@ -13,6 +13,7 @@ use std::{ io::{self, Cursor, Read, Write}, net::{Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket}, os::unix::io::{AsRawFd, RawFd}, + sync::Arc, thread::spawn }; use tungstenite::{client::AutoStream, connect, protocol::WebSocket, server::accept, Message}; @@ -108,17 +109,21 @@ pub fn run_proxy(listen: &str) -> Result<(), io::Error> { Ok(()) } +#[derive(Clone)] pub struct ProxyConnection { addr: SocketAddr, - socket: WebSocket + socket: Arc> } impl ProxyConnection { fn read_message(&mut self) -> Result, io::Error> { loop { + unimplemented!(); + /* if let Message::Binary(data) = io_error!(self.socket.read_message(), "Failed to read from ws proxy: {}")? { return Ok(data) } + */ } } } @@ -128,14 +133,14 @@ impl AsRawFd for ProxyConnection { self.socket.get_ref().as_raw_fd() } } - + impl Socket for ProxyConnection { fn listen(url: &str) -> Result { let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; socket.get_mut().set_nodelay(true)?; let addr = "0.0.0.0:0".parse::().unwrap(); - let mut con = ProxyConnection { addr, socket }; + let mut con = ProxyConnection { addr, socket: Arc::new(socket) }; let addr_data = con.read_message()?; con.addr = read_addr(Cursor::new(&addr_data))?; Ok(con) @@ -153,7 +158,8 @@ impl Socket for ProxyConnection { let mut msg = Vec::with_capacity(data.len() + 18); write_addr(addr, &mut msg)?; msg.write_all(data)?; - io_error!(self.socket.write_message(Message::Binary(msg)), "Failed to write to ws proxy: {}")?; + unimplemented!(); + //io_error!(self.socket.write_message(Message::Binary(msg)), "Failed to write to ws proxy: {}")?; Ok(data.len()) }