diff --git a/Cargo.lock b/Cargo.lock index eecd6f4..3f7085d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "async-trait" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "attohttpc" version = "0.16.3" @@ -571,6 +582,29 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mio" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5dede4e2065b3842b8b0af444119f3aa331cc7cc2dd20388bfb0f5d5a38823a" +dependencies = [ + "libc", + "log", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a33c1b55807fbed163481b5ba66db4b2fa6cde694a5027be10fb724206c5897" +dependencies = [ + "socket2", + "winapi", +] + [[package]] name = "nix" version = "0.14.1" @@ -596,6 +630,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -664,6 +707,12 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pin-project-lite" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439697af366c49a6d0a010c56a0d97685bc140ce0d377b13a2ea2aa42d64a827" + [[package]] name = "plotters" version = "0.3.0" @@ -1017,12 +1066,32 @@ dependencies = [ "nix 0.14.1", ] +[[package]] +name = "signal-hook-registry" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6" +dependencies = [ + "libc", +] + [[package]] name = "smallvec" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.5.2" @@ -1244,6 +1313,37 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +[[package]] +name = "tokio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8190d04c665ea9e6b6a0dc45523ade572c088d2e6566244c1122671dbf4ae3a" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "once_cell", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tungstenite" version = "0.13.0" @@ -1352,6 +1452,7 @@ checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" name = "vpncloud" version = "2.1.0" dependencies = [ + "async-trait", "byteorder", "criterion", "daemonize", @@ -1373,6 +1474,7 @@ dependencies = [ "tempfile", "thiserror", "time", + "tokio", "tungstenite", "url", "yaml-rust", diff --git a/Cargo.toml b/Cargo.toml index 473cba7..55dea1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ dialoguer = { version = "0.7", optional = true } tungstenite = { version = "0.13", optional = true, default-features = false } url = { version = "2.2", optional = true } igd = { version = "0.12", optional = true } +tokio = { version = "1", features = ["full"] } +async-trait = "0.1" [dev-dependencies] tempfile = "3" diff --git a/src/beacon.rs b/src/beacon.rs index 8fdef16..ccf58d9 100644 --- a/src/beacon.rs +++ b/src/beacon.rs @@ -322,7 +322,7 @@ impl BeaconSerializer { #[cfg(test)] use std::time::Duration; #[test] -fn encode() { +async fn encode() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let mut peers = vec![ @@ -340,7 +340,7 @@ fn encode() { } #[test] -fn decode() { +async fn decode() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let mut peers = vec![ @@ -356,7 +356,7 @@ fn decode() { } #[test] -fn decode_split() { +async fn decode_split() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ @@ -374,7 +374,7 @@ fn decode_split() { } #[test] -fn decode_offset() { +async fn decode_offset() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ @@ -388,7 +388,7 @@ fn decode_offset() { } #[test] -fn decode_multiple() { +async fn decode_multiple() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ @@ -402,7 +402,7 @@ fn decode_multiple() { } #[test] -fn decode_ttl() { +async fn decode_ttl() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); MockTimeSource::set_time(2000 * 3600); @@ -426,7 +426,7 @@ fn decode_ttl() { } #[test] -fn decode_invalid() { +async fn decode_invalid() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); assert_eq!(0, ser.decode("", None).len()); @@ -440,7 +440,7 @@ fn decode_invalid() { #[test] -fn encode_decode() { +async fn encode_decode() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ @@ -453,7 +453,7 @@ fn encode_decode() { } #[test] -fn encode_decode_file() { +async fn encode_decode_file() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ @@ -468,7 +468,7 @@ fn encode_decode_file() { } #[test] -fn encode_decode_cmd() { +async fn encode_decode_cmd() { MockTimeSource::set_time(2000 * 3600); let ser = BeaconSerializer::::new(b"mysecretkey"); let peers = vec![ diff --git a/src/config.rs b/src/config.rs index 6d9591a..c3a7c2c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -669,7 +669,7 @@ pub struct ConfigFile { } #[test] -fn config_file() { +async fn config_file() { let config_file = " device: type: tun @@ -744,12 +744,12 @@ statsd: } #[test] -fn parse_example_config() { +async fn parse_example_config() { serde_yaml::from_str::(include_str!("../assets/example.net.disabled")).unwrap(); } #[test] -fn config_merge() { +async fn config_merge() { let mut config = Config::default(); config.merge_file(ConfigFile { device: Some(ConfigFileDevice { diff --git a/src/crypto/common.rs b/src/crypto/common.rs index e9e50cf..d884953 100644 --- a/src/crypto/common.rs +++ b/src/crypto/common.rs @@ -339,7 +339,7 @@ mod tests { } #[test] - fn normal() { + async fn normal() { let config = Config { password: Some("test".to_string()), ..Default::default() }; let mut node1 = create_node(&config); let mut node2 = create_node(&config); diff --git a/src/crypto/core.rs b/src/crypto/core.rs index 1ffe826..8d16464 100644 --- a/src/crypto/core.rs +++ b/src/crypto/core.rs @@ -280,7 +280,7 @@ mod tests { use ring::aead::{self, LessSafeKey, UnboundKey}; #[test] - fn test_nonce() { + async fn test_nonce() { let mut nonce = Nonce::zero(); assert_eq!(nonce.as_bytes(), &[0; 12]); nonce.increment(); @@ -302,17 +302,17 @@ mod tests { } #[test] - fn test_encrypt_decrypt_aes128() { + async fn test_encrypt_decrypt_aes128() { test_encrypt_decrypt(&aead::AES_128_GCM) } #[test] - fn test_encrypt_decrypt_aes256() { + async fn test_encrypt_decrypt_aes256() { test_encrypt_decrypt(&aead::AES_256_GCM) } #[test] - fn test_encrypt_decrypt_chacha() { + async fn test_encrypt_decrypt_chacha() { test_encrypt_decrypt(&aead::CHACHA20_POLY1305) } @@ -343,17 +343,17 @@ mod tests { } #[test] - fn test_tampering_aes128() { + async fn test_tampering_aes128() { test_tampering(&aead::AES_128_GCM) } #[test] - fn test_tampering_aes256() { + async fn test_tampering_aes256() { test_tampering(&aead::AES_256_GCM) } #[test] - fn test_tampering_chacha() { + async fn test_tampering_chacha() { test_tampering(&aead::CHACHA20_POLY1305) } @@ -384,17 +384,17 @@ mod tests { } #[test] - fn test_nonce_pinning_aes128() { + async fn test_nonce_pinning_aes128() { test_nonce_pinning(&aead::AES_128_GCM) } #[test] - fn test_nonce_pinning_aes256() { + async fn test_nonce_pinning_aes256() { test_nonce_pinning(&aead::AES_256_GCM) } #[test] - fn test_nonce_pinning_chacha() { + async fn test_nonce_pinning_chacha() { test_nonce_pinning(&aead::CHACHA20_POLY1305) } @@ -433,40 +433,40 @@ mod tests { } #[test] - fn test_key_rotation_aes128() { + async fn test_key_rotation_aes128() { test_key_rotation(&aead::AES_128_GCM); } #[test] - fn test_key_rotation_aes256() { + async fn test_key_rotation_aes256() { test_key_rotation(&aead::AES_256_GCM); } #[test] - fn test_key_rotation_chacha() { + async fn test_key_rotation_chacha() { test_key_rotation(&aead::CHACHA20_POLY1305); } #[test] - fn test_core_size() { - assert_eq!(2384, mem::size_of::()); + async fn test_core_size() { + assert_eq!(2400, mem::size_of::()); } #[test] - fn test_speed_aes128() { + async fn test_speed_aes128() { let speed = test_speed(&aead::AES_128_GCM, &Duration::from_secs_f32(0.2)); assert!(speed > 10.0); } #[test] - fn test_speed_aes256() { + async fn test_speed_aes256() { let speed = test_speed(&aead::AES_256_GCM, &Duration::from_secs_f32(0.2)); assert!(speed > 10.0); } #[test] - fn test_speed_chacha() { + async fn test_speed_chacha() { let speed = test_speed(&aead::CHACHA20_POLY1305, &Duration::from_secs_f32(0.2)); assert!(speed > 10.0); } diff --git a/src/crypto/init.rs b/src/crypto/init.rs index 0741a4a..86d30af 100644 --- a/src/crypto/init.rs +++ b/src/crypto/init.rs @@ -739,7 +739,7 @@ mod tests { } #[test] - fn normal_init() { + async fn normal_init() { let (mut sender, mut receiver) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -761,7 +761,7 @@ mod tests { } #[test] - fn lost_init_sender_recovers() { + async fn lost_init_sender_recovers() { let (mut sender, mut receiver) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -794,7 +794,7 @@ mod tests { } #[test] - fn lost_init_receiver_recovers() { + async fn lost_init_receiver_recovers() { let (mut sender, mut receiver) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -826,7 +826,7 @@ mod tests { } #[test] - fn timeout() { + async fn timeout() { let (mut sender, _receiver) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -841,7 +841,7 @@ mod tests { } #[test] - fn untrusted_peer() { + async fn untrusted_peer() { let (mut sender, _) = create_pair(); let (_, mut receiver) = create_pair(); let mut out = MsgBuffer::new(8); @@ -851,7 +851,7 @@ mod tests { } #[test] - fn manipulated_message() { + async fn manipulated_message() { let (mut sender, mut receiver) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -861,7 +861,7 @@ mod tests { } #[test] - fn connect_to_self() { + async fn connect_to_self() { let (mut sender, _) = create_pair(); let mut out = MsgBuffer::new(8); sender.send_ping(&mut out); @@ -889,7 +889,7 @@ mod tests { } #[test] - fn algorithm_negotiation() { + async fn algorithm_negotiation() { // Equal algorithms test_algorithm_negotiation( Algorithms { diff --git a/src/crypto/rotate.rs b/src/crypto/rotate.rs index e47aa59..7355153 100644 --- a/src/crypto/rotate.rs +++ b/src/crypto/rotate.rs @@ -230,7 +230,7 @@ mod tests { } #[test] - fn test_encode_decode_message() { + async fn test_encode_decode_message() { let mut data = Vec::with_capacity(100); let (_, key) = RotationState::create_key(); let msg = RotationMessage { message_id: 1, propose: key, confirm: None }; @@ -251,7 +251,7 @@ mod tests { } #[test] - fn test_normal_rotation() { + async fn test_normal_rotation() { let mut out1 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8); @@ -325,7 +325,7 @@ mod tests { } #[test] - fn test_duplication() { + async fn test_duplication() { let mut out1 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8); @@ -361,7 +361,7 @@ mod tests { } #[test] - fn test_lost_message() { + async fn test_lost_message() { let mut out1 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8); @@ -387,7 +387,7 @@ mod tests { } #[test] - fn test_reflect_back() { + async fn test_reflect_back() { let mut out1 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8); diff --git a/src/device.rs b/src/device.rs index 2404871..bfaf8c4 100644 --- a/src/device.rs +++ b/src/device.rs @@ -2,20 +2,22 @@ // Copyright (C) 2015-2021 Dennis Schwerdel // This software is licensed under GPL-3 or newer (see LICENSE.md) +use async_trait::async_trait; use parking_lot::Mutex; use std::{ cmp, collections::VecDeque, convert::TryInto, fmt, - fs::{self, File}, - io::{self, BufRead, BufReader, Cursor, Error as IoError, Read, Write}, + io::{self, Cursor, Error as IoError}, net::{Ipv4Addr, UdpSocket}, - os::unix::io::{AsRawFd, RawFd}, + os::unix::io::AsRawFd, str, str::FromStr, - sync::Arc + sync::Arc, }; +use tokio::fs::{self, File}; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use crate::{crypto, error::Error, util::MsgBuffer}; @@ -26,13 +28,13 @@ union IfReqData { flags: libc::c_short, value: libc::c_int, addr: (libc::c_short, Ipv4Addr), - _dummy: [u8; 24] + _dummy: [u8; 24], } #[repr(C)] struct IfReq { ifr_name: [u8; libc::IF_NAMESIZE], - data: IfReqData + data: IfReqData, } impl IfReq { @@ -44,7 +46,6 @@ impl IfReq { } } - /// The type of a tun/tap device #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq)] pub enum Type { @@ -53,14 +54,14 @@ pub enum Type { Tun, /// Tap interface: This interface transports Ethernet frames. #[serde(rename = "tap")] - Tap + Tap, } impl fmt::Display for Type { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { Type::Tun => write!(formatter, "tun"), - Type::Tap => write!(formatter, "tap") + Type::Tap => write!(formatter, "tap"), } } } @@ -72,12 +73,13 @@ impl FromStr for Type { Ok(match &text.to_lowercase() as &str { "tun" => Self::Tun, "tap" => Self::Tap, - _ => return Err("Unknown device type") + _ => return Err("Unknown device type"), }) } } -pub trait Device: AsRawFd + Clone + Send + 'static { +#[async_trait] +pub trait Device: Send + 'static + Sized { /// Returns the type of this device fn get_type(&self) -> Type; @@ -95,7 +97,7 @@ pub trait Device: AsRawFd + Clone + Send + 'static { /// /// # Errors /// This method will return an error if the underlying read call fails. - fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; + async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; /// Writes a packet/frame to the device /// @@ -106,28 +108,18 @@ pub trait Device: AsRawFd + Clone + Send + 'static { /// /// # Errors /// This method will return an error if the underlying read call fails. - fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; + async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; + + async fn duplicate(&self) -> Result; fn get_ip(&self) -> Result; } - /// Represents a tun/tap device pub struct TunTapDevice { fd: File, ifname: String, - type_: Type -} - - -impl Clone for TunTapDevice { - fn clone(&self) -> Self { - Self { - fd: try_fail!(self.fd.try_clone(), "Failed to clone device: {}"), - ifname: self.ifname.clone(), - type_: self.type_ - } - } + type_: Type, } impl TunTapDevice { @@ -149,12 +141,12 @@ impl TunTapDevice { /// # Panics /// This method panics if the interface name is longer than 31 bytes. #[allow(clippy::useless_conversion)] - pub fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result { + pub async fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result { let path = path.unwrap_or_else(|| Self::default_path(type_)); - let fd = fs::OpenOptions::new().read(true).write(true).open(path)?; + let fd = fs::OpenOptions::new().read(true).write(true).open(path).await?; let flags = match type_ { Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI, - Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI + Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI, }; let mut ifreq = IfReq::new(ifname); ifreq.data.flags = flags as libc::c_short; @@ -163,11 +155,11 @@ impl TunTapDevice { 0 => { let mut ifname = String::with_capacity(32); let mut cursor = Cursor::new(ifreq.ifr_name); - cursor.read_to_string(&mut ifname)?; + cursor.read_to_string(&mut ifname).await?; ifname = ifname.trim_end_matches('\0').to_owned(); Ok(Self { fd, ifname, type_ }) } - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -175,7 +167,7 @@ impl TunTapDevice { #[inline] pub fn default_path(type_: Type) -> &'static str { match type_ { - Type::Tun | Type::Tap => "/dev/net/tun" + Type::Tun | Type::Tap => "/dev/net/tun", } } @@ -223,7 +215,7 @@ impl TunTapDevice { // IP version 4 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x08, 0x00]), 6 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x86, 0xdd]), - _ => unreachable!() + _ => unreachable!(), } } } @@ -239,11 +231,11 @@ impl TunTapDevice { } } - pub fn set_mtu(&self, value: Option) -> io::Result<()> { + pub async fn set_mtu(&self, value: Option) -> io::Result<()> { let value = match value { Some(value) => value, None => { - let default_device = get_default_device()?; + let default_device = get_default_device().await?; get_device_mtu(&default_device)? - self.get_overhead() } }; @@ -257,23 +249,24 @@ impl TunTapDevice { set_device_enabled(&self.ifname, true) } - pub fn get_rp_filter(&self) -> io::Result { - Ok(cmp::max(get_rp_filter("all")?, get_rp_filter(&self.ifname)?)) + pub async fn get_rp_filter(&self) -> io::Result { + Ok(cmp::max(get_rp_filter("all").await?, get_rp_filter(&self.ifname).await?)) } - pub fn fix_rp_filter(&self) -> io::Result<()> { - if get_rp_filter("all")? > 1 { + pub async fn fix_rp_filter(&self) -> io::Result<()> { + if get_rp_filter("all").await? > 1 { info!("Setting net.ipv4.conf.all.rp_filter=1"); - set_rp_filter("all", 1)? + set_rp_filter("all", 1).await? } - if get_rp_filter(&self.ifname)? != 1 { + if get_rp_filter(&self.ifname).await? != 1 { info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname); - set_rp_filter(&self.ifname, 1)? + set_rp_filter(&self.ifname, 1).await? } Ok(()) } } +#[async_trait] impl Device for TunTapDevice { fn get_type(&self) -> Type { self.type_ @@ -283,19 +276,27 @@ impl Device for TunTapDevice { &self.ifname } - fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { + async fn duplicate(&self) -> Result { + Ok(Self { + fd: self.fd.try_clone().await.map_err(|e| Error::DeviceIo("Failed to clone device", e))?, + ifname: self.ifname.clone(), + type_: self.type_, + }) + } + + async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { buffer.clear(); - let read = self.fd.read(buffer.buffer()).map_err(|e| Error::DeviceIo("Read error", e))?; + let read = self.fd.read(buffer.buffer()).await.map_err(|e| Error::DeviceIo("Read error", e))?; buffer.set_length(read); self.correct_data_after_read(buffer); Ok(()) } - fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { + async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { self.correct_data_before_write(buffer); - match self.fd.write_all(buffer.message()) { - Ok(_) => self.fd.flush().map_err(|e| Error::DeviceIo("Flush error", e)), - Err(e) => Err(Error::DeviceIo("Write error", e)) + match self.fd.write_all(buffer.message()).await { + Ok(_) => self.fd.flush().await.map_err(|e| Error::DeviceIo("Flush error", e)), + Err(e) => Err(Error::DeviceIo("Write error", e)), } } @@ -304,18 +305,10 @@ impl Device for TunTapDevice { } } -impl AsRawFd for TunTapDevice { - #[inline] - fn as_raw_fd(&self) -> RawFd { - self.fd.as_raw_fd() - } -} - - #[derive(Clone)] pub struct MockDevice { inbound: Arc>>>, - outbound: Arc>>> + outbound: Arc>>>, } impl MockDevice { @@ -336,7 +329,12 @@ impl MockDevice { } } +#[async_trait] impl Device for MockDevice { + async fn duplicate(&self) -> Result { + Ok(self.clone()) + } + fn get_type(&self) -> Type { Type::Tun } @@ -345,7 +343,7 @@ impl Device for MockDevice { "mock0" } - fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { + async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { if let Some(data) = self.inbound.lock().pop_front() { buffer.clear(); buffer.set_length(data.len()); @@ -356,7 +354,7 @@ impl Device for MockDevice { } } - fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { + async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { self.outbound.lock().push_back(buffer.message().into()); Ok(()) } @@ -370,19 +368,11 @@ impl Default for MockDevice { fn default() -> Self { Self { outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), - inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))) + inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), } } } -impl AsRawFd for MockDevice { - #[inline] - fn as_raw_fd(&self) -> RawFd { - unimplemented!() - } -} - - #[allow(clippy::useless_conversion)] fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> { let sock = UdpSocket::bind("0.0.0.0:0")?; @@ -391,7 +381,7 @@ fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> { let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU.try_into().unwrap(), &mut ifreq) }; match res { 0 => Ok(()), - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -402,7 +392,7 @@ fn get_device_mtu(ifname: &str) -> io::Result { let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFMTU.try_into().unwrap(), &mut ifreq) }; match res { 0 => Ok(unsafe { ifreq.data.value as usize }), - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -415,12 +405,12 @@ fn get_device_addr(ifname: &str) -> io::Result { 0 => { let af = unsafe { ifreq.data.addr.0 }; if af as libc::c_int != libc::AF_INET { - return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())) + return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())); } let ip = unsafe { ifreq.data.addr.1 }; Ok(ip) } - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -432,7 +422,7 @@ fn set_device_addr(ifname: &str, addr: Ipv4Addr) -> io::Result<()> { let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFADDR.try_into().unwrap(), &mut ifreq) }; match res { 0 => Ok(()), - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -446,12 +436,12 @@ fn get_device_netmask(ifname: &str) -> io::Result { 0 => { let af = unsafe { ifreq.data.addr.0 }; if af as libc::c_int != libc::AF_INET { - return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())) + return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())); } let ip = unsafe { ifreq.data.addr.1 }; Ok(ip) } - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -463,7 +453,7 @@ fn set_device_netmask(ifname: &str, addr: Ipv4Addr) -> io::Result<()> { let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFNETMASK.try_into().unwrap(), &mut ifreq) }; match res { 0 => Ok(()), - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } @@ -472,7 +462,7 @@ fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> { let sock = UdpSocket::bind("0.0.0.0:0")?; let mut ifreq = IfReq::new(ifname); if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFFLAGS.try_into().unwrap(), &mut ifreq) } != 0 { - return Err(IoError::last_os_error()) + return Err(IoError::last_os_error()); } if up { unsafe { ifreq.data.value |= libc::IFF_UP | libc::IFF_RUNNING } @@ -482,20 +472,22 @@ fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> { let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFFLAGS.try_into().unwrap(), &mut ifreq) }; match res { 0 => Ok(()), - _ => Err(IoError::last_os_error()) + _ => Err(IoError::last_os_error()), } } - -fn get_default_device() -> io::Result { - let fd = BufReader::new(File::open("/proc/net/route")?); +async fn get_default_device() -> io::Result { + let mut fd = BufReader::new(File::open("/proc/net/route").await?); let mut best = None; - for line in fd.lines() { - let line = line?; + let mut line = String::with_capacity(80); + while let Ok(read) = fd.read_line(&mut line).await { + if read == 0 { + break; + } let parts = line.split('\t').collect::>(); if parts[1] == "00000000" { best = Some(parts[0].to_string()); - break + break; } if parts[2] != "00000000" { best = Some(parts[0].to_string()) @@ -508,14 +500,14 @@ fn get_default_device() -> io::Result { } } -fn get_rp_filter(device: &str) -> io::Result { - let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?; +async fn get_rp_filter(device: &str) -> io::Result { + let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?; let mut contents = String::with_capacity(10); - fd.read_to_string(&mut contents)?; + fd.read_to_string(&mut contents).await?; u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value")) } -fn set_rp_filter(device: &str, val: u8) -> io::Result<()> { - let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?; - writeln!(fd, "{}", val) +async fn set_rp_filter(device: &str, val: u8) -> io::Result<()> { + let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?; + fd.write_all(format!("{}", val).as_bytes()).await } diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index 8cd7af8..89134a7 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -1,17 +1,18 @@ use super::{ shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, - SPACE_BEFORE + SPACE_BEFORE, }; use crate::{ + config::Config, device::Device, error::Error, messages::MESSAGE_TYPE_DATA, net::Socket, util::{MsgBuffer, Time, TimeSource}, Protocol, - config::Config }; use std::{marker::PhantomData, net::SocketAddr}; +use tokio::time::timeout; pub struct DeviceThread { // Read-only fields @@ -20,16 +21,21 @@ pub struct DeviceThread { broadcast: bool, // Device-only fields socket: S, - device: D, + pub device: D, next_housekeep: Time, + buffer: MsgBuffer, + broadcast_buffer: MsgBuffer, // Shared fields traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, - table: SharedTable + table: SharedTable, } impl DeviceThread { - pub fn new(config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, table: SharedTable) -> Self { + pub fn new( + config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, + table: SharedTable, + ) -> Self { Self { _dummy_ts: PhantomData, _dummy_p: PhantomData, @@ -39,97 +45,106 @@ impl DeviceThread Result<(), Error> { - debug!("Sending msg with {} bytes to {}", msg.len(), addr); - self.traffic.count_out_traffic(addr, msg.len()); - match self.socket.send(msg.message(), addr) { - Ok(written) if written == msg.len() => Ok(()), + async fn send_to(&mut self, addr: SocketAddr) -> Result<(), Error> { + let size = self.buffer.len(); + debug!("Sending msg with {} bytes to {}", size, addr); + self.traffic.count_out_traffic(addr, size); + match self.socket.send(self.buffer.message(), addr).await { + Ok(written) if written == size => Ok(()), Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)) + Err(e) => Err(Error::SocketIo("IOError when sending", e)), } } #[inline] - fn send_msg(&mut self, addr: SocketAddr, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { - debug!("Sending msg with {} bytes to {}", msg.len(), addr); - msg.prepend_byte(type_); - self.peer_crypto.encrypt_for(addr, msg)?; - self.send_to(addr, msg) + async fn send_msg(&mut self, addr: SocketAddr, type_: u8) -> Result<(), Error> { + debug!("Sending msg with {} bytes to {}", self.buffer.len(), addr); + self.buffer.prepend_byte(type_); + self.peer_crypto.encrypt_for(addr, &mut self.buffer)?; + self.send_to(addr).await } #[inline] - fn broadcast_msg(&mut self, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { - debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peer_crypto.count()); - let mut msg_data = MsgBuffer::new(100); + async fn broadcast_msg(&mut self, type_: u8) -> Result<(), Error> { + let size = self.buffer.len(); + debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, size, self.peer_crypto.count()); let traffic = &mut self.traffic; let socket = &mut self.socket; - self.peer_crypto.for_each(|addr, crypto| { - msg_data.set_start(msg.get_start()); - msg_data.set_length(msg.len()); - msg_data.message_mut().clone_from_slice(msg.message()); - msg_data.prepend_byte(type_); + let peers = self.peer_crypto.get_snapshot(); + for (addr, crypto) in peers { + self.broadcast_buffer.set_start(self.buffer.get_start()); + self.broadcast_buffer.set_length(self.buffer.len()); + self.broadcast_buffer.message_mut().clone_from_slice(self.buffer.message()); + self.broadcast_buffer.prepend_byte(type_); if let Some(crypto) = crypto { - crypto.encrypt(&mut msg_data); + crypto.encrypt(&mut self.broadcast_buffer); } - traffic.count_out_traffic(addr, msg_data.len()); - match socket.send(msg_data.message(), addr) { - Ok(written) if written == msg_data.len() => Ok(()), + traffic.count_out_traffic(addr, self.broadcast_buffer.len()); + match socket.send(self.broadcast_buffer.message(), addr).await { + Ok(written) if written == self.broadcast_buffer.len() => Ok(()), Ok(_) => Err(Error::Socket("Sent out truncated packet")), - Err(e) => Err(Error::SocketIo("IOError when sending", e)) - } - }) + Err(e) => Err(Error::SocketIo("IOError when sending", e)), + }? + } + Ok(()) } - fn forward_packet(&mut self, data: &mut MsgBuffer) -> Result<(), Error> { - let (src, dst) = P::parse(data.message())?; - debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, data.len()); - self.traffic.count_out_payload(dst, src, data.len()); + async fn forward_packet(&mut self) -> Result<(), Error> { + let (src, dst) = P::parse(self.buffer.message())?; + debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, self.buffer.len()); + self.traffic.count_out_payload(dst, src, self.buffer.len()); match self.table.lookup(dst) { Some(addr) => { // Peer found for destination debug!("Found destination for {} => {}", dst, addr); - self.send_msg(addr, MESSAGE_TYPE_DATA, data)?; + self.send_msg(addr, MESSAGE_TYPE_DATA).await?; } None => { if self.broadcast { debug!("No destination for {} found, broadcasting", dst); - self.broadcast_msg(MESSAGE_TYPE_DATA, data)?; + self.broadcast_msg(MESSAGE_TYPE_DATA).await?; } else { debug!("No destination for {} found, dropping", dst); - self.traffic.count_dropped_payload(data.len()); + self.traffic.count_dropped_payload(self.buffer.len()); } } } Ok(()) } - fn housekeep(&mut self) -> Result<(), Error> { + pub async fn housekeep(&mut self) -> Result<(), Error> { self.peer_crypto.sync(); self.table.sync(); self.traffic.sync(); - unimplemented!(); + Ok(()) } - pub fn run(mut self) { - let mut buffer = MsgBuffer::new(SPACE_BEFORE); - loop { - try_fail!(self.device.read(&mut buffer), "Failed to read from device: {}"); - //TODO: set and handle timeout - if let Err(e) = self.forward_packet(&mut buffer) { + pub async fn iteration(&mut self) { + if let Ok(result) = timeout(std::time::Duration::from_millis(1000), self.device.read(&mut self.buffer)).await { + try_fail!(result, "Failed to read from device: {}"); + if let Err(e) = self.forward_packet().await { error!("{}", e); } - let now = TS::now(); - if self.next_housekeep < now { - if let Err(e) = self.housekeep() { - error!("{}", e) - } - self.next_housekeep = now + 1 + } + let now = TS::now(); + if self.next_housekeep < now { + if let Err(e) = self.housekeep().await { + error!("{}", e) } + self.next_housekeep = now + 1 + } + } + + pub async fn run(mut self) { + loop { + self.iteration().await } } } diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 7ad1464..5225764 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -6,7 +6,8 @@ mod device_thread; mod shared; mod socket_thread; -use std::{fs::File, hash::BuildHasherDefault, thread}; +use std::{fs::File, hash::BuildHasherDefault}; +use tokio; use fnv::FnvHasher; @@ -17,7 +18,7 @@ use crate::{ engine::{ device_thread::DeviceThread, shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, - socket_thread::SocketThread + socket_thread::SocketThread, }, error::Error, messages::AddrList, @@ -25,7 +26,7 @@ use crate::{ payload::Protocol, port_forwarding::PortForwarding, types::NodeId, - util::{addr_nice, resolve, CtrlC, Time, TimeSource} + util::{addr_nice, resolve, CtrlC, Time, TimeSource}, }; pub type Hash = BuildHasherDefault; @@ -40,7 +41,7 @@ struct PeerData { timeout: Time, peer_timeout: u16, node_id: NodeId, - crypto: PeerCrypto + crypto: PeerCrypto, } #[derive(Clone)] @@ -50,30 +51,29 @@ pub struct ReconnectEntry { tries: u16, timeout: u16, next: Time, - final_timeout: Option