Using tokio

This commit is contained in:
Dennis Schwerdel 2021-03-02 15:31:40 +00:00
parent 4653ed02eb
commit 87e715f475
23 changed files with 744 additions and 623 deletions

102
Cargo.lock generated
View File

@ -9,6 +9,17 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "async-trait"
version = "0.1.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "attohttpc" name = "attohttpc"
version = "0.16.3" version = "0.16.3"
@ -571,6 +582,29 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "mio"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5dede4e2065b3842b8b0af444119f3aa331cc7cc2dd20388bfb0f5d5a38823a"
dependencies = [
"libc",
"log",
"miow",
"ntapi",
"winapi",
]
[[package]]
name = "miow"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a33c1b55807fbed163481b5ba66db4b2fa6cde694a5027be10fb724206c5897"
dependencies = [
"socket2",
"winapi",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.14.1" version = "0.14.1"
@ -596,6 +630,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "ntapi"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.14" version = "0.2.14"
@ -664,6 +707,12 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
[[package]]
name = "pin-project-lite"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439697af366c49a6d0a010c56a0d97685bc140ce0d377b13a2ea2aa42d64a827"
[[package]] [[package]]
name = "plotters" name = "plotters"
version = "0.3.0" version = "0.3.0"
@ -1017,12 +1066,32 @@ dependencies = [
"nix 0.14.1", "nix 0.14.1",
] ]
[[package]]
name = "signal-hook-registry"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.6.1" version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
[[package]]
name = "socket2"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e"
dependencies = [
"cfg-if 1.0.0",
"libc",
"winapi",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.5.2"
@ -1244,6 +1313,37 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokio"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8190d04c665ea9e6b6a0dc45523ade572c088d2e6566244c1122671dbf4ae3a"
dependencies = [
"autocfg",
"bytes",
"libc",
"memchr",
"mio",
"num_cpus",
"once_cell",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"tokio-macros",
"winapi",
]
[[package]]
name = "tokio-macros"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "tungstenite" name = "tungstenite"
version = "0.13.0" version = "0.13.0"
@ -1352,6 +1452,7 @@ checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
name = "vpncloud" name = "vpncloud"
version = "2.1.0" version = "2.1.0"
dependencies = [ dependencies = [
"async-trait",
"byteorder", "byteorder",
"criterion", "criterion",
"daemonize", "daemonize",
@ -1373,6 +1474,7 @@ dependencies = [
"tempfile", "tempfile",
"thiserror", "thiserror",
"time", "time",
"tokio",
"tungstenite", "tungstenite",
"url", "url",
"yaml-rust", "yaml-rust",

View File

@ -37,6 +37,8 @@ dialoguer = { version = "0.7", optional = true }
tungstenite = { version = "0.13", optional = true, default-features = false } tungstenite = { version = "0.13", optional = true, default-features = false }
url = { version = "2.2", optional = true } url = { version = "2.2", optional = true }
igd = { version = "0.12", optional = true } igd = { version = "0.12", optional = true }
tokio = { version = "1", features = ["full"] }
async-trait = "0.1"
[dev-dependencies] [dev-dependencies]
tempfile = "3" tempfile = "3"

View File

@ -322,7 +322,7 @@ impl<TS: TimeSource> BeaconSerializer<TS> {
#[cfg(test)] use std::time::Duration; #[cfg(test)] use std::time::Duration;
#[test] #[test]
fn encode() { async fn encode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = vec![ let mut peers = vec![
@ -340,7 +340,7 @@ fn encode() {
} }
#[test] #[test]
fn decode() { async fn decode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = vec![ let mut peers = vec![
@ -356,7 +356,7 @@ fn decode() {
} }
#[test] #[test]
fn decode_split() { async fn decode_split() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![
@ -374,7 +374,7 @@ fn decode_split() {
} }
#[test] #[test]
fn decode_offset() { async fn decode_offset() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![
@ -388,7 +388,7 @@ fn decode_offset() {
} }
#[test] #[test]
fn decode_multiple() { async fn decode_multiple() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![
@ -402,7 +402,7 @@ fn decode_multiple() {
} }
#[test] #[test]
fn decode_ttl() { async fn decode_ttl() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
@ -426,7 +426,7 @@ fn decode_ttl() {
} }
#[test] #[test]
fn decode_invalid() { async fn decode_invalid() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
assert_eq!(0, ser.decode("", None).len()); assert_eq!(0, ser.decode("", None).len());
@ -440,7 +440,7 @@ fn decode_invalid() {
#[test] #[test]
fn encode_decode() { async fn encode_decode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![
@ -453,7 +453,7 @@ fn encode_decode() {
} }
#[test] #[test]
fn encode_decode_file() { async fn encode_decode_file() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![
@ -468,7 +468,7 @@ fn encode_decode_file() {
} }
#[test] #[test]
fn encode_decode_cmd() { async fn encode_decode_cmd() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let peers = vec![ let peers = vec![

View File

@ -669,7 +669,7 @@ pub struct ConfigFile {
} }
#[test] #[test]
fn config_file() { async fn config_file() {
let config_file = " let config_file = "
device: device:
type: tun type: tun
@ -744,12 +744,12 @@ statsd:
} }
#[test] #[test]
fn parse_example_config() { async fn parse_example_config() {
serde_yaml::from_str::<ConfigFile>(include_str!("../assets/example.net.disabled")).unwrap(); serde_yaml::from_str::<ConfigFile>(include_str!("../assets/example.net.disabled")).unwrap();
} }
#[test] #[test]
fn config_merge() { async fn config_merge() {
let mut config = Config::default(); let mut config = Config::default();
config.merge_file(ConfigFile { config.merge_file(ConfigFile {
device: Some(ConfigFileDevice { device: Some(ConfigFileDevice {

View File

@ -339,7 +339,7 @@ mod tests {
} }
#[test] #[test]
fn normal() { async fn normal() {
let config = Config { password: Some("test".to_string()), ..Default::default() }; let config = Config { password: Some("test".to_string()), ..Default::default() };
let mut node1 = create_node(&config); let mut node1 = create_node(&config);
let mut node2 = create_node(&config); let mut node2 = create_node(&config);

View File

@ -280,7 +280,7 @@ mod tests {
use ring::aead::{self, LessSafeKey, UnboundKey}; use ring::aead::{self, LessSafeKey, UnboundKey};
#[test] #[test]
fn test_nonce() { async fn test_nonce() {
let mut nonce = Nonce::zero(); let mut nonce = Nonce::zero();
assert_eq!(nonce.as_bytes(), &[0; 12]); assert_eq!(nonce.as_bytes(), &[0; 12]);
nonce.increment(); nonce.increment();
@ -302,17 +302,17 @@ mod tests {
} }
#[test] #[test]
fn test_encrypt_decrypt_aes128() { async fn test_encrypt_decrypt_aes128() {
test_encrypt_decrypt(&aead::AES_128_GCM) test_encrypt_decrypt(&aead::AES_128_GCM)
} }
#[test] #[test]
fn test_encrypt_decrypt_aes256() { async fn test_encrypt_decrypt_aes256() {
test_encrypt_decrypt(&aead::AES_256_GCM) test_encrypt_decrypt(&aead::AES_256_GCM)
} }
#[test] #[test]
fn test_encrypt_decrypt_chacha() { async fn test_encrypt_decrypt_chacha() {
test_encrypt_decrypt(&aead::CHACHA20_POLY1305) test_encrypt_decrypt(&aead::CHACHA20_POLY1305)
} }
@ -343,17 +343,17 @@ mod tests {
} }
#[test] #[test]
fn test_tampering_aes128() { async fn test_tampering_aes128() {
test_tampering(&aead::AES_128_GCM) test_tampering(&aead::AES_128_GCM)
} }
#[test] #[test]
fn test_tampering_aes256() { async fn test_tampering_aes256() {
test_tampering(&aead::AES_256_GCM) test_tampering(&aead::AES_256_GCM)
} }
#[test] #[test]
fn test_tampering_chacha() { async fn test_tampering_chacha() {
test_tampering(&aead::CHACHA20_POLY1305) test_tampering(&aead::CHACHA20_POLY1305)
} }
@ -384,17 +384,17 @@ mod tests {
} }
#[test] #[test]
fn test_nonce_pinning_aes128() { async fn test_nonce_pinning_aes128() {
test_nonce_pinning(&aead::AES_128_GCM) test_nonce_pinning(&aead::AES_128_GCM)
} }
#[test] #[test]
fn test_nonce_pinning_aes256() { async fn test_nonce_pinning_aes256() {
test_nonce_pinning(&aead::AES_256_GCM) test_nonce_pinning(&aead::AES_256_GCM)
} }
#[test] #[test]
fn test_nonce_pinning_chacha() { async fn test_nonce_pinning_chacha() {
test_nonce_pinning(&aead::CHACHA20_POLY1305) test_nonce_pinning(&aead::CHACHA20_POLY1305)
} }
@ -433,40 +433,40 @@ mod tests {
} }
#[test] #[test]
fn test_key_rotation_aes128() { async fn test_key_rotation_aes128() {
test_key_rotation(&aead::AES_128_GCM); test_key_rotation(&aead::AES_128_GCM);
} }
#[test] #[test]
fn test_key_rotation_aes256() { async fn test_key_rotation_aes256() {
test_key_rotation(&aead::AES_256_GCM); test_key_rotation(&aead::AES_256_GCM);
} }
#[test] #[test]
fn test_key_rotation_chacha() { async fn test_key_rotation_chacha() {
test_key_rotation(&aead::CHACHA20_POLY1305); test_key_rotation(&aead::CHACHA20_POLY1305);
} }
#[test] #[test]
fn test_core_size() { async fn test_core_size() {
assert_eq!(2384, mem::size_of::<CryptoCore>()); assert_eq!(2400, mem::size_of::<CryptoCore>());
} }
#[test] #[test]
fn test_speed_aes128() { async fn test_speed_aes128() {
let speed = test_speed(&aead::AES_128_GCM, &Duration::from_secs_f32(0.2)); let speed = test_speed(&aead::AES_128_GCM, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0); assert!(speed > 10.0);
} }
#[test] #[test]
fn test_speed_aes256() { async fn test_speed_aes256() {
let speed = test_speed(&aead::AES_256_GCM, &Duration::from_secs_f32(0.2)); let speed = test_speed(&aead::AES_256_GCM, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0); assert!(speed > 10.0);
} }
#[test] #[test]
fn test_speed_chacha() { async fn test_speed_chacha() {
let speed = test_speed(&aead::CHACHA20_POLY1305, &Duration::from_secs_f32(0.2)); let speed = test_speed(&aead::CHACHA20_POLY1305, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0); assert!(speed > 10.0);
} }

View File

@ -739,7 +739,7 @@ mod tests {
} }
#[test] #[test]
fn normal_init() { async fn normal_init() {
let (mut sender, mut receiver) = create_pair(); let (mut sender, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -761,7 +761,7 @@ mod tests {
} }
#[test] #[test]
fn lost_init_sender_recovers() { async fn lost_init_sender_recovers() {
let (mut sender, mut receiver) = create_pair(); let (mut sender, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -794,7 +794,7 @@ mod tests {
} }
#[test] #[test]
fn lost_init_receiver_recovers() { async fn lost_init_receiver_recovers() {
let (mut sender, mut receiver) = create_pair(); let (mut sender, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -826,7 +826,7 @@ mod tests {
} }
#[test] #[test]
fn timeout() { async fn timeout() {
let (mut sender, _receiver) = create_pair(); let (mut sender, _receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -841,7 +841,7 @@ mod tests {
} }
#[test] #[test]
fn untrusted_peer() { async fn untrusted_peer() {
let (mut sender, _) = create_pair(); let (mut sender, _) = create_pair();
let (_, mut receiver) = create_pair(); let (_, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
@ -851,7 +851,7 @@ mod tests {
} }
#[test] #[test]
fn manipulated_message() { async fn manipulated_message() {
let (mut sender, mut receiver) = create_pair(); let (mut sender, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -861,7 +861,7 @@ mod tests {
} }
#[test] #[test]
fn connect_to_self() { async fn connect_to_self() {
let (mut sender, _) = create_pair(); let (mut sender, _) = create_pair();
let mut out = MsgBuffer::new(8); let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out); sender.send_ping(&mut out);
@ -889,7 +889,7 @@ mod tests {
} }
#[test] #[test]
fn algorithm_negotiation() { async fn algorithm_negotiation() {
// Equal algorithms // Equal algorithms
test_algorithm_negotiation( test_algorithm_negotiation(
Algorithms { Algorithms {

View File

@ -230,7 +230,7 @@ mod tests {
} }
#[test] #[test]
fn test_encode_decode_message() { async fn test_encode_decode_message() {
let mut data = Vec::with_capacity(100); let mut data = Vec::with_capacity(100);
let (_, key) = RotationState::create_key(); let (_, key) = RotationState::create_key();
let msg = RotationMessage { message_id: 1, propose: key, confirm: None }; let msg = RotationMessage { message_id: 1, propose: key, confirm: None };
@ -251,7 +251,7 @@ mod tests {
} }
#[test] #[test]
fn test_normal_rotation() { async fn test_normal_rotation() {
let mut out1 = MsgBuffer::new(8); let mut out1 = MsgBuffer::new(8);
let mut out2 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8);
@ -325,7 +325,7 @@ mod tests {
} }
#[test] #[test]
fn test_duplication() { async fn test_duplication() {
let mut out1 = MsgBuffer::new(8); let mut out1 = MsgBuffer::new(8);
let mut out2 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8);
@ -361,7 +361,7 @@ mod tests {
} }
#[test] #[test]
fn test_lost_message() { async fn test_lost_message() {
let mut out1 = MsgBuffer::new(8); let mut out1 = MsgBuffer::new(8);
let mut out2 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8);
@ -387,7 +387,7 @@ mod tests {
} }
#[test] #[test]
fn test_reflect_back() { async fn test_reflect_back() {
let mut out1 = MsgBuffer::new(8); let mut out1 = MsgBuffer::new(8);
let mut out2 = MsgBuffer::new(8); let mut out2 = MsgBuffer::new(8);

View File

@ -2,20 +2,22 @@
// Copyright (C) 2015-2021 Dennis Schwerdel // Copyright (C) 2015-2021 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
use async_trait::async_trait;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{ use std::{
cmp, cmp,
collections::VecDeque, collections::VecDeque,
convert::TryInto, convert::TryInto,
fmt, fmt,
fs::{self, File}, io::{self, Cursor, Error as IoError},
io::{self, BufRead, BufReader, Cursor, Error as IoError, Read, Write},
net::{Ipv4Addr, UdpSocket}, net::{Ipv4Addr, UdpSocket},
os::unix::io::{AsRawFd, RawFd}, os::unix::io::AsRawFd,
str, str,
str::FromStr, str::FromStr,
sync::Arc sync::Arc,
}; };
use tokio::fs::{self, File};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use crate::{crypto, error::Error, util::MsgBuffer}; use crate::{crypto, error::Error, util::MsgBuffer};
@ -26,13 +28,13 @@ union IfReqData {
flags: libc::c_short, flags: libc::c_short,
value: libc::c_int, value: libc::c_int,
addr: (libc::c_short, Ipv4Addr), addr: (libc::c_short, Ipv4Addr),
_dummy: [u8; 24] _dummy: [u8; 24],
} }
#[repr(C)] #[repr(C)]
struct IfReq { struct IfReq {
ifr_name: [u8; libc::IF_NAMESIZE], ifr_name: [u8; libc::IF_NAMESIZE],
data: IfReqData data: IfReqData,
} }
impl IfReq { impl IfReq {
@ -44,7 +46,6 @@ impl IfReq {
} }
} }
/// The type of a tun/tap device /// The type of a tun/tap device
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq)]
pub enum Type { pub enum Type {
@ -53,14 +54,14 @@ pub enum Type {
Tun, Tun,
/// Tap interface: This interface transports Ethernet frames. /// Tap interface: This interface transports Ethernet frames.
#[serde(rename = "tap")] #[serde(rename = "tap")]
Tap Tap,
} }
impl fmt::Display for Type { impl fmt::Display for Type {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match *self { match *self {
Type::Tun => write!(formatter, "tun"), Type::Tun => write!(formatter, "tun"),
Type::Tap => write!(formatter, "tap") Type::Tap => write!(formatter, "tap"),
} }
} }
} }
@ -72,12 +73,13 @@ impl FromStr for Type {
Ok(match &text.to_lowercase() as &str { Ok(match &text.to_lowercase() as &str {
"tun" => Self::Tun, "tun" => Self::Tun,
"tap" => Self::Tap, "tap" => Self::Tap,
_ => return Err("Unknown device type") _ => return Err("Unknown device type"),
}) })
} }
} }
pub trait Device: AsRawFd + Clone + Send + 'static { #[async_trait]
pub trait Device: Send + 'static + Sized {
/// Returns the type of this device /// Returns the type of this device
fn get_type(&self) -> Type; fn get_type(&self) -> Type;
@ -95,7 +97,7 @@ pub trait Device: AsRawFd + Clone + Send + 'static {
/// ///
/// # Errors /// # Errors
/// This method will return an error if the underlying read call fails. /// This method will return an error if the underlying read call fails.
fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>;
/// Writes a packet/frame to the device /// Writes a packet/frame to the device
/// ///
@ -106,28 +108,18 @@ pub trait Device: AsRawFd + Clone + Send + 'static {
/// ///
/// # Errors /// # Errors
/// This method will return an error if the underlying read call fails. /// This method will return an error if the underlying read call fails.
fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>; async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>;
async fn duplicate(&self) -> Result<Self, Error>;
fn get_ip(&self) -> Result<Ipv4Addr, Error>; fn get_ip(&self) -> Result<Ipv4Addr, Error>;
} }
/// Represents a tun/tap device /// Represents a tun/tap device
pub struct TunTapDevice { pub struct TunTapDevice {
fd: File, fd: File,
ifname: String, ifname: String,
type_: Type type_: Type,
}
impl Clone for TunTapDevice {
fn clone(&self) -> Self {
Self {
fd: try_fail!(self.fd.try_clone(), "Failed to clone device: {}"),
ifname: self.ifname.clone(),
type_: self.type_
}
}
} }
impl TunTapDevice { impl TunTapDevice {
@ -149,12 +141,12 @@ impl TunTapDevice {
/// # Panics /// # Panics
/// This method panics if the interface name is longer than 31 bytes. /// This method panics if the interface name is longer than 31 bytes.
#[allow(clippy::useless_conversion)] #[allow(clippy::useless_conversion)]
pub fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result<Self> { pub async fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result<Self> {
let path = path.unwrap_or_else(|| Self::default_path(type_)); let path = path.unwrap_or_else(|| Self::default_path(type_));
let fd = fs::OpenOptions::new().read(true).write(true).open(path)?; let fd = fs::OpenOptions::new().read(true).write(true).open(path).await?;
let flags = match type_ { let flags = match type_ {
Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI, Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI,
Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI,
}; };
let mut ifreq = IfReq::new(ifname); let mut ifreq = IfReq::new(ifname);
ifreq.data.flags = flags as libc::c_short; ifreq.data.flags = flags as libc::c_short;
@ -163,11 +155,11 @@ impl TunTapDevice {
0 => { 0 => {
let mut ifname = String::with_capacity(32); let mut ifname = String::with_capacity(32);
let mut cursor = Cursor::new(ifreq.ifr_name); let mut cursor = Cursor::new(ifreq.ifr_name);
cursor.read_to_string(&mut ifname)?; cursor.read_to_string(&mut ifname).await?;
ifname = ifname.trim_end_matches('\0').to_owned(); ifname = ifname.trim_end_matches('\0').to_owned();
Ok(Self { fd, ifname, type_ }) Ok(Self { fd, ifname, type_ })
} }
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -175,7 +167,7 @@ impl TunTapDevice {
#[inline] #[inline]
pub fn default_path(type_: Type) -> &'static str { pub fn default_path(type_: Type) -> &'static str {
match type_ { match type_ {
Type::Tun | Type::Tap => "/dev/net/tun" Type::Tun | Type::Tap => "/dev/net/tun",
} }
} }
@ -223,7 +215,7 @@ impl TunTapDevice {
// IP version // IP version
4 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x08, 0x00]), 4 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x08, 0x00]),
6 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x86, 0xdd]), 6 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x86, 0xdd]),
_ => unreachable!() _ => unreachable!(),
} }
} }
} }
@ -239,11 +231,11 @@ impl TunTapDevice {
} }
} }
pub fn set_mtu(&self, value: Option<usize>) -> io::Result<()> { pub async fn set_mtu(&self, value: Option<usize>) -> io::Result<()> {
let value = match value { let value = match value {
Some(value) => value, Some(value) => value,
None => { None => {
let default_device = get_default_device()?; let default_device = get_default_device().await?;
get_device_mtu(&default_device)? - self.get_overhead() get_device_mtu(&default_device)? - self.get_overhead()
} }
}; };
@ -257,23 +249,24 @@ impl TunTapDevice {
set_device_enabled(&self.ifname, true) set_device_enabled(&self.ifname, true)
} }
pub fn get_rp_filter(&self) -> io::Result<u8> { pub async fn get_rp_filter(&self) -> io::Result<u8> {
Ok(cmp::max(get_rp_filter("all")?, get_rp_filter(&self.ifname)?)) Ok(cmp::max(get_rp_filter("all").await?, get_rp_filter(&self.ifname).await?))
} }
pub fn fix_rp_filter(&self) -> io::Result<()> { pub async fn fix_rp_filter(&self) -> io::Result<()> {
if get_rp_filter("all")? > 1 { if get_rp_filter("all").await? > 1 {
info!("Setting net.ipv4.conf.all.rp_filter=1"); info!("Setting net.ipv4.conf.all.rp_filter=1");
set_rp_filter("all", 1)? set_rp_filter("all", 1).await?
} }
if get_rp_filter(&self.ifname)? != 1 { if get_rp_filter(&self.ifname).await? != 1 {
info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname); info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname);
set_rp_filter(&self.ifname, 1)? set_rp_filter(&self.ifname, 1).await?
} }
Ok(()) Ok(())
} }
} }
#[async_trait]
impl Device for TunTapDevice { impl Device for TunTapDevice {
fn get_type(&self) -> Type { fn get_type(&self) -> Type {
self.type_ self.type_
@ -283,19 +276,27 @@ impl Device for TunTapDevice {
&self.ifname &self.ifname
} }
fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { async fn duplicate(&self) -> Result<Self, Error> {
Ok(Self {
fd: self.fd.try_clone().await.map_err(|e| Error::DeviceIo("Failed to clone device", e))?,
ifname: self.ifname.clone(),
type_: self.type_,
})
}
async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
buffer.clear(); buffer.clear();
let read = self.fd.read(buffer.buffer()).map_err(|e| Error::DeviceIo("Read error", e))?; let read = self.fd.read(buffer.buffer()).await.map_err(|e| Error::DeviceIo("Read error", e))?;
buffer.set_length(read); buffer.set_length(read);
self.correct_data_after_read(buffer); self.correct_data_after_read(buffer);
Ok(()) Ok(())
} }
fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
self.correct_data_before_write(buffer); self.correct_data_before_write(buffer);
match self.fd.write_all(buffer.message()) { match self.fd.write_all(buffer.message()).await {
Ok(_) => self.fd.flush().map_err(|e| Error::DeviceIo("Flush error", e)), Ok(_) => self.fd.flush().await.map_err(|e| Error::DeviceIo("Flush error", e)),
Err(e) => Err(Error::DeviceIo("Write error", e)) Err(e) => Err(Error::DeviceIo("Write error", e)),
} }
} }
@ -304,18 +305,10 @@ impl Device for TunTapDevice {
} }
} }
impl AsRawFd for TunTapDevice {
#[inline]
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct MockDevice { pub struct MockDevice {
inbound: Arc<Mutex<VecDeque<Vec<u8>>>>, inbound: Arc<Mutex<VecDeque<Vec<u8>>>>,
outbound: Arc<Mutex<VecDeque<Vec<u8>>>> outbound: Arc<Mutex<VecDeque<Vec<u8>>>>,
} }
impl MockDevice { impl MockDevice {
@ -336,7 +329,12 @@ impl MockDevice {
} }
} }
#[async_trait]
impl Device for MockDevice { impl Device for MockDevice {
async fn duplicate(&self) -> Result<Self, Error> {
Ok(self.clone())
}
fn get_type(&self) -> Type { fn get_type(&self) -> Type {
Type::Tun Type::Tun
} }
@ -345,7 +343,7 @@ impl Device for MockDevice {
"mock0" "mock0"
} }
fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
if let Some(data) = self.inbound.lock().pop_front() { if let Some(data) = self.inbound.lock().pop_front() {
buffer.clear(); buffer.clear();
buffer.set_length(data.len()); buffer.set_length(data.len());
@ -356,7 +354,7 @@ impl Device for MockDevice {
} }
} }
fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { async fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
self.outbound.lock().push_back(buffer.message().into()); self.outbound.lock().push_back(buffer.message().into());
Ok(()) Ok(())
} }
@ -370,19 +368,11 @@ impl Default for MockDevice {
fn default() -> Self { fn default() -> Self {
Self { Self {
outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))),
inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))) inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))),
} }
} }
} }
impl AsRawFd for MockDevice {
#[inline]
fn as_raw_fd(&self) -> RawFd {
unimplemented!()
}
}
#[allow(clippy::useless_conversion)] #[allow(clippy::useless_conversion)]
fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> { fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?; let sock = UdpSocket::bind("0.0.0.0:0")?;
@ -391,7 +381,7 @@ fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> {
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU.try_into().unwrap(), &mut ifreq) }; let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU.try_into().unwrap(), &mut ifreq) };
match res { match res {
0 => Ok(()), 0 => Ok(()),
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -402,7 +392,7 @@ fn get_device_mtu(ifname: &str) -> io::Result<usize> {
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFMTU.try_into().unwrap(), &mut ifreq) }; let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFMTU.try_into().unwrap(), &mut ifreq) };
match res { match res {
0 => Ok(unsafe { ifreq.data.value as usize }), 0 => Ok(unsafe { ifreq.data.value as usize }),
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -415,12 +405,12 @@ fn get_device_addr(ifname: &str) -> io::Result<Ipv4Addr> {
0 => { 0 => {
let af = unsafe { ifreq.data.addr.0 }; let af = unsafe { ifreq.data.addr.0 };
if af as libc::c_int != libc::AF_INET { if af as libc::c_int != libc::AF_INET {
return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())) return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned()));
} }
let ip = unsafe { ifreq.data.addr.1 }; let ip = unsafe { ifreq.data.addr.1 };
Ok(ip) Ok(ip)
} }
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -432,7 +422,7 @@ fn set_device_addr(ifname: &str, addr: Ipv4Addr) -> io::Result<()> {
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFADDR.try_into().unwrap(), &mut ifreq) }; let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFADDR.try_into().unwrap(), &mut ifreq) };
match res { match res {
0 => Ok(()), 0 => Ok(()),
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -446,12 +436,12 @@ fn get_device_netmask(ifname: &str) -> io::Result<Ipv4Addr> {
0 => { 0 => {
let af = unsafe { ifreq.data.addr.0 }; let af = unsafe { ifreq.data.addr.0 };
if af as libc::c_int != libc::AF_INET { if af as libc::c_int != libc::AF_INET {
return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned())) return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned()));
} }
let ip = unsafe { ifreq.data.addr.1 }; let ip = unsafe { ifreq.data.addr.1 };
Ok(ip) Ok(ip)
} }
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -463,7 +453,7 @@ fn set_device_netmask(ifname: &str, addr: Ipv4Addr) -> io::Result<()> {
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFNETMASK.try_into().unwrap(), &mut ifreq) }; let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFNETMASK.try_into().unwrap(), &mut ifreq) };
match res { match res {
0 => Ok(()), 0 => Ok(()),
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
@ -472,7 +462,7 @@ fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?; let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname); let mut ifreq = IfReq::new(ifname);
if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFFLAGS.try_into().unwrap(), &mut ifreq) } != 0 { if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFFLAGS.try_into().unwrap(), &mut ifreq) } != 0 {
return Err(IoError::last_os_error()) return Err(IoError::last_os_error());
} }
if up { if up {
unsafe { ifreq.data.value |= libc::IFF_UP | libc::IFF_RUNNING } unsafe { ifreq.data.value |= libc::IFF_UP | libc::IFF_RUNNING }
@ -482,20 +472,22 @@ fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> {
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFFLAGS.try_into().unwrap(), &mut ifreq) }; let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFFLAGS.try_into().unwrap(), &mut ifreq) };
match res { match res {
0 => Ok(()), 0 => Ok(()),
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error()),
} }
} }
async fn get_default_device() -> io::Result<String> {
fn get_default_device() -> io::Result<String> { let mut fd = BufReader::new(File::open("/proc/net/route").await?);
let fd = BufReader::new(File::open("/proc/net/route")?);
let mut best = None; let mut best = None;
for line in fd.lines() { let mut line = String::with_capacity(80);
let line = line?; while let Ok(read) = fd.read_line(&mut line).await {
if read == 0 {
break;
}
let parts = line.split('\t').collect::<Vec<_>>(); let parts = line.split('\t').collect::<Vec<_>>();
if parts[1] == "00000000" { if parts[1] == "00000000" {
best = Some(parts[0].to_string()); best = Some(parts[0].to_string());
break break;
} }
if parts[2] != "00000000" { if parts[2] != "00000000" {
best = Some(parts[0].to_string()) best = Some(parts[0].to_string())
@ -508,14 +500,14 @@ fn get_default_device() -> io::Result<String> {
} }
} }
fn get_rp_filter(device: &str) -> io::Result<u8> { async fn get_rp_filter(device: &str) -> io::Result<u8> {
let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?; let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?;
let mut contents = String::with_capacity(10); let mut contents = String::with_capacity(10);
fd.read_to_string(&mut contents)?; fd.read_to_string(&mut contents).await?;
u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value")) u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value"))
} }
fn set_rp_filter(device: &str, val: u8) -> io::Result<()> { async fn set_rp_filter(device: &str, val: u8) -> io::Result<()> {
let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?; let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?;
writeln!(fd, "{}", val) fd.write_all(format!("{}", val).as_bytes()).await
} }

View File

@ -1,17 +1,18 @@
use super::{ use super::{
shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, shared::{SharedPeerCrypto, SharedTable, SharedTraffic},
SPACE_BEFORE SPACE_BEFORE,
}; };
use crate::{ use crate::{
config::Config,
device::Device, device::Device,
error::Error, error::Error,
messages::MESSAGE_TYPE_DATA, messages::MESSAGE_TYPE_DATA,
net::Socket, net::Socket,
util::{MsgBuffer, Time, TimeSource}, util::{MsgBuffer, Time, TimeSource},
Protocol, Protocol,
config::Config
}; };
use std::{marker::PhantomData, net::SocketAddr}; use std::{marker::PhantomData, net::SocketAddr};
use tokio::time::timeout;
pub struct DeviceThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> { pub struct DeviceThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
// Read-only fields // Read-only fields
@ -20,16 +21,21 @@ pub struct DeviceThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
broadcast: bool, broadcast: bool,
// Device-only fields // Device-only fields
socket: S, socket: S,
device: D, pub device: D,
next_housekeep: Time, next_housekeep: Time,
buffer: MsgBuffer,
broadcast_buffer: MsgBuffer,
// Shared fields // Shared fields
traffic: SharedTraffic, traffic: SharedTraffic,
peer_crypto: SharedPeerCrypto, peer_crypto: SharedPeerCrypto,
table: SharedTable<TS> table: SharedTable<TS>,
} }
impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> DeviceThread<S, D, P, TS> { impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> DeviceThread<S, D, P, TS> {
pub fn new(config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, table: SharedTable<TS>) -> Self { pub fn new(
config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto,
table: SharedTable<TS>,
) -> Self {
Self { Self {
_dummy_ts: PhantomData, _dummy_ts: PhantomData,
_dummy_p: PhantomData, _dummy_p: PhantomData,
@ -39,97 +45,106 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> DeviceThread<S, D, P, TS
next_housekeep: TS::now(), next_housekeep: TS::now(),
traffic, traffic,
peer_crypto, peer_crypto,
table table,
buffer: MsgBuffer::new(SPACE_BEFORE),
broadcast_buffer: MsgBuffer::new(SPACE_BEFORE)
} }
} }
#[inline] #[inline]
fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { async fn send_to(&mut self, addr: SocketAddr) -> Result<(), Error> {
debug!("Sending msg with {} bytes to {}", msg.len(), addr); let size = self.buffer.len();
self.traffic.count_out_traffic(addr, msg.len()); debug!("Sending msg with {} bytes to {}", size, addr);
match self.socket.send(msg.message(), addr) { self.traffic.count_out_traffic(addr, size);
Ok(written) if written == msg.len() => Ok(()), match self.socket.send(self.buffer.message(), addr).await {
Ok(written) if written == size => Ok(()),
Ok(_) => Err(Error::Socket("Sent out truncated packet")), Ok(_) => Err(Error::Socket("Sent out truncated packet")),
Err(e) => Err(Error::SocketIo("IOError when sending", e)) Err(e) => Err(Error::SocketIo("IOError when sending", e)),
} }
} }
#[inline] #[inline]
fn send_msg(&mut self, addr: SocketAddr, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { async fn send_msg(&mut self, addr: SocketAddr, type_: u8) -> Result<(), Error> {
debug!("Sending msg with {} bytes to {}", msg.len(), addr); debug!("Sending msg with {} bytes to {}", self.buffer.len(), addr);
msg.prepend_byte(type_); self.buffer.prepend_byte(type_);
self.peer_crypto.encrypt_for(addr, msg)?; self.peer_crypto.encrypt_for(addr, &mut self.buffer)?;
self.send_to(addr, msg) self.send_to(addr).await
} }
#[inline] #[inline]
fn broadcast_msg(&mut self, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { async fn broadcast_msg(&mut self, type_: u8) -> Result<(), Error> {
debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peer_crypto.count()); let size = self.buffer.len();
let mut msg_data = MsgBuffer::new(100); debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, size, self.peer_crypto.count());
let traffic = &mut self.traffic; let traffic = &mut self.traffic;
let socket = &mut self.socket; let socket = &mut self.socket;
self.peer_crypto.for_each(|addr, crypto| { let peers = self.peer_crypto.get_snapshot();
msg_data.set_start(msg.get_start()); for (addr, crypto) in peers {
msg_data.set_length(msg.len()); self.broadcast_buffer.set_start(self.buffer.get_start());
msg_data.message_mut().clone_from_slice(msg.message()); self.broadcast_buffer.set_length(self.buffer.len());
msg_data.prepend_byte(type_); self.broadcast_buffer.message_mut().clone_from_slice(self.buffer.message());
self.broadcast_buffer.prepend_byte(type_);
if let Some(crypto) = crypto { if let Some(crypto) = crypto {
crypto.encrypt(&mut msg_data); crypto.encrypt(&mut self.broadcast_buffer);
} }
traffic.count_out_traffic(addr, msg_data.len()); traffic.count_out_traffic(addr, self.broadcast_buffer.len());
match socket.send(msg_data.message(), addr) { match socket.send(self.broadcast_buffer.message(), addr).await {
Ok(written) if written == msg_data.len() => Ok(()), Ok(written) if written == self.broadcast_buffer.len() => Ok(()),
Ok(_) => Err(Error::Socket("Sent out truncated packet")), Ok(_) => Err(Error::Socket("Sent out truncated packet")),
Err(e) => Err(Error::SocketIo("IOError when sending", e)) Err(e) => Err(Error::SocketIo("IOError when sending", e)),
}?
} }
}) Ok(())
} }
fn forward_packet(&mut self, data: &mut MsgBuffer) -> Result<(), Error> { async fn forward_packet(&mut self) -> Result<(), Error> {
let (src, dst) = P::parse(data.message())?; let (src, dst) = P::parse(self.buffer.message())?;
debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, data.len()); debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, self.buffer.len());
self.traffic.count_out_payload(dst, src, data.len()); self.traffic.count_out_payload(dst, src, self.buffer.len());
match self.table.lookup(dst) { match self.table.lookup(dst) {
Some(addr) => { Some(addr) => {
// Peer found for destination // Peer found for destination
debug!("Found destination for {} => {}", dst, addr); debug!("Found destination for {} => {}", dst, addr);
self.send_msg(addr, MESSAGE_TYPE_DATA, data)?; self.send_msg(addr, MESSAGE_TYPE_DATA).await?;
} }
None => { None => {
if self.broadcast { if self.broadcast {
debug!("No destination for {} found, broadcasting", dst); debug!("No destination for {} found, broadcasting", dst);
self.broadcast_msg(MESSAGE_TYPE_DATA, data)?; self.broadcast_msg(MESSAGE_TYPE_DATA).await?;
} else { } else {
debug!("No destination for {} found, dropping", dst); debug!("No destination for {} found, dropping", dst);
self.traffic.count_dropped_payload(data.len()); self.traffic.count_dropped_payload(self.buffer.len());
} }
} }
} }
Ok(()) Ok(())
} }
fn housekeep(&mut self) -> Result<(), Error> { pub async fn housekeep(&mut self) -> Result<(), Error> {
self.peer_crypto.sync(); self.peer_crypto.sync();
self.table.sync(); self.table.sync();
self.traffic.sync(); self.traffic.sync();
unimplemented!(); Ok(())
} }
pub fn run(mut self) { pub async fn iteration(&mut self) {
let mut buffer = MsgBuffer::new(SPACE_BEFORE); if let Ok(result) = timeout(std::time::Duration::from_millis(1000), self.device.read(&mut self.buffer)).await {
loop { try_fail!(result, "Failed to read from device: {}");
try_fail!(self.device.read(&mut buffer), "Failed to read from device: {}"); if let Err(e) = self.forward_packet().await {
//TODO: set and handle timeout
if let Err(e) = self.forward_packet(&mut buffer) {
error!("{}", e); error!("{}", e);
} }
}
let now = TS::now(); let now = TS::now();
if self.next_housekeep < now { if self.next_housekeep < now {
if let Err(e) = self.housekeep() { if let Err(e) = self.housekeep().await {
error!("{}", e) error!("{}", e)
} }
self.next_housekeep = now + 1 self.next_housekeep = now + 1
} }
} }
pub async fn run(mut self) {
loop {
self.iteration().await
}
} }
} }

View File

@ -6,7 +6,8 @@ mod device_thread;
mod shared; mod shared;
mod socket_thread; mod socket_thread;
use std::{fs::File, hash::BuildHasherDefault, thread}; use std::{fs::File, hash::BuildHasherDefault};
use tokio;
use fnv::FnvHasher; use fnv::FnvHasher;
@ -17,7 +18,7 @@ use crate::{
engine::{ engine::{
device_thread::DeviceThread, device_thread::DeviceThread,
shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, shared::{SharedPeerCrypto, SharedTable, SharedTraffic},
socket_thread::SocketThread socket_thread::SocketThread,
}, },
error::Error, error::Error,
messages::AddrList, messages::AddrList,
@ -25,7 +26,7 @@ use crate::{
payload::Protocol, payload::Protocol,
port_forwarding::PortForwarding, port_forwarding::PortForwarding,
types::NodeId, types::NodeId,
util::{addr_nice, resolve, CtrlC, Time, TimeSource} util::{addr_nice, resolve, CtrlC, Time, TimeSource},
}; };
pub type Hash = BuildHasherDefault<FnvHasher>; pub type Hash = BuildHasherDefault<FnvHasher>;
@ -40,7 +41,7 @@ struct PeerData {
timeout: Time, timeout: Time,
peer_timeout: u16, peer_timeout: u16,
node_id: NodeId, node_id: NodeId,
crypto: PeerCrypto crypto: PeerCrypto,
} }
#[derive(Clone)] #[derive(Clone)]
@ -50,30 +51,29 @@ pub struct ReconnectEntry {
tries: u16, tries: u16,
timeout: u16, timeout: u16,
next: Time, next: Time,
final_timeout: Option<Time> final_timeout: Option<Time>,
} }
pub struct GenericCloud<D: Device, P: Protocol, S: Socket, TS: TimeSource> { pub struct GenericCloud<D: Device, P: Protocol, S: Socket, TS: TimeSource> {
socket_thread: SocketThread<S, D, P, TS>, socket_thread: SocketThread<S, D, P, TS>,
device_thread: DeviceThread<S, D, P, TS> device_thread: DeviceThread<S, D, P, TS>,
} }
impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS> { impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS> {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub async fn new(
config: &Config, socket: S, device: D, port_forwarding: Option<PortForwarding>, stats_file: Option<File> config: &Config, socket: S, device: D, port_forwarding: Option<PortForwarding>, stats_file: Option<File>,
) -> Self { ) -> Result<Self, Error> {
let table = SharedTable::<TS>::new(&config); let table = SharedTable::<TS>::new(&config);
let traffic = SharedTraffic::new(); let traffic = SharedTraffic::new();
let peer_crypto = SharedPeerCrypto::new(); let peer_crypto = SharedPeerCrypto::new();
let device_thread = DeviceThread::<S, D, P, TS>::new( let device_thread = DeviceThread::<S, D, P, TS>::new(
config.clone(), config.clone(),
device.clone(), device.duplicate().await?,
socket.clone(), socket.clone(),
traffic.clone(), traffic.clone(),
peer_crypto.clone(), peer_crypto.clone(),
table.clone() table.clone(),
); );
let socket_thread = SocketThread::<S, D, P, TS>::new( let socket_thread = SocketThread::<S, D, P, TS>::new(
config.clone(), config.clone(),
@ -83,58 +83,60 @@ impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS
peer_crypto, peer_crypto,
table, table,
port_forwarding, port_forwarding,
stats_file stats_file,
); );
Self { socket_thread, device_thread } Ok(Self { socket_thread, device_thread })
} }
pub fn add_peer(&mut self, addr: String) -> Result<(), Error> { pub fn add_peer(&mut self, addr: String) -> Result<(), Error> {
unimplemented!() unimplemented!()
} }
pub fn run(self) { pub async fn run(self) {
// TODO: spawn threads
let ctrlc = CtrlC::new(); let ctrlc = CtrlC::new();
let device_thread = self.device_thread; let device_thread_handle = tokio::spawn(self.device_thread.run());
let device_thread_handle = thread::spawn(move || device_thread.run()); let socket_thread_handle = tokio::spawn(self.socket_thread.run());
let socket_thread = self.socket_thread;
let socket_thread_handle = thread::spawn(move || socket_thread.run());
// TODO: wait for ctrl-c // TODO: wait for ctrl-c
device_thread_handle.join().unwrap(); let (dev_ret, sock_ret) = join!(device_thread_handle, socket_thread_handle);
socket_thread_handle.join().unwrap(); dev_ret.unwrap();
sock_ret.unwrap();
} }
} }
#[cfg(test)]
#[cfg(test)] use super::device::MockDevice; use super::device::MockDevice;
#[cfg(test)] use super::net::MockSocket; #[cfg(test)]
#[cfg(test)] use super::util::{MockTimeSource, MsgBuffer}; use super::net::MockSocket;
#[cfg(test)] use std::net::SocketAddr; #[cfg(test)]
use super::util::MockTimeSource;
#[cfg(test)]
use std::net::SocketAddr;
#[cfg(test)] #[cfg(test)]
impl<P: Protocol> GenericCloud<MockDevice, P, MockSocket, MockTimeSource> { impl<P: Protocol> GenericCloud<MockDevice, P, MockSocket, MockTimeSource> {
pub fn socket(&mut self) -> &mut MockSocket { pub fn socket(&mut self) -> &mut MockSocket {
unimplemented!() &mut self.socket_thread.socket
//&mut self.socket
} }
pub fn device(&mut self) -> &mut MockDevice { pub fn device(&mut self) -> &mut MockDevice {
unimplemented!() &mut self.device_thread.device
//&mut self.device
} }
pub fn trigger_socket_event(&mut self) { pub fn connect(&mut self, addr: SocketAddr) -> Result<(), Error> {
let mut buffer = MsgBuffer::new(SPACE_BEFORE);
unimplemented!() unimplemented!()
} }
pub fn trigger_device_event(&mut self) { pub async fn trigger_socket_event(&mut self) {
let mut buffer = MsgBuffer::new(SPACE_BEFORE); self.socket_thread.iteration().await
unimplemented!()
} }
pub fn trigger_housekeep(&mut self) { pub async fn trigger_device_event(&mut self) {
unimplemented!() self.device_thread.iteration().await
}
pub async fn trigger_housekeep(&mut self) {
try_fail!(self.socket_thread.housekeep().await, "Housekeep failed: {}");
try_fail!(self.device_thread.housekeep().await, "Housekeep failed: {}");
} }
pub fn is_connected(&self, addr: &SocketAddr) -> bool { pub fn is_connected(&self, addr: &SocketAddr) -> bool {

View File

@ -42,14 +42,8 @@ impl SharedPeerCrypto {
} }
} }
pub fn for_each( pub fn get_snapshot(&self) -> HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash> {
&mut self, mut callback: impl FnMut(SocketAddr, Option<Arc<CryptoCore>>) -> Result<(), Error> self.peers.lock().clone()
) -> Result<(), Error> {
let mut peers = self.peers.lock();
for (k, v) in peers.iter_mut() {
callback(*k, v.clone())?
}
Ok(())
} }
pub fn count(&self) -> usize { pub fn count(&self) -> usize {

View File

@ -1,6 +1,6 @@
use super::{ use super::{
shared::{SharedPeerCrypto, SharedTable, SharedTraffic}, shared::{SharedPeerCrypto, SharedTable, SharedTraffic},
SPACE_BEFORE SPACE_BEFORE,
}; };
use crate::{ use crate::{
@ -12,13 +12,13 @@ use crate::{
error::Error, error::Error,
messages::{ messages::{
AddrList, NodeInfo, PeerInfo, MESSAGE_TYPE_CLOSE, MESSAGE_TYPE_DATA, MESSAGE_TYPE_KEEPALIVE, AddrList, NodeInfo, PeerInfo, MESSAGE_TYPE_CLOSE, MESSAGE_TYPE_DATA, MESSAGE_TYPE_KEEPALIVE,
MESSAGE_TYPE_NODE_INFO MESSAGE_TYPE_NODE_INFO,
}, },
net::{mapped_addr, Socket}, net::{mapped_addr, Socket},
port_forwarding::PortForwarding, port_forwarding::PortForwarding,
types::{Address, NodeId, Range, RangeList}, types::{Address, NodeId, Range, RangeList},
util::{MsgBuffer, StatsdMsg, Time, TimeSource}, util::{MsgBuffer, StatsdMsg, Time, TimeSource},
Config, Crypto, Device, Protocol Config, Crypto, Device, Protocol,
}; };
use rand::{random, seq::SliceRandom, thread_rng}; use rand::{random, seq::SliceRandom, thread_rng};
use smallvec::{smallvec, SmallVec}; use smallvec::{smallvec, SmallVec};
@ -31,16 +31,15 @@ use std::{
io::{Cursor, Seek, SeekFrom, Write}, io::{Cursor, Seek, SeekFrom, Write},
marker::PhantomData, marker::PhantomData,
net::{SocketAddr, ToSocketAddrs}, net::{SocketAddr, ToSocketAddrs},
str::FromStr str::FromStr,
}; };
use tokio::time::timeout;
const MAX_RECONNECT_INTERVAL: u16 = 3600; const MAX_RECONNECT_INTERVAL: u16 = 3600;
const RESOLVE_INTERVAL: Time = 300; const RESOLVE_INTERVAL: Time = 300;
const OWN_ADDRESS_RESET_INTERVAL: Time = 300; const OWN_ADDRESS_RESET_INTERVAL: Time = 300;
pub const STATS_INTERVAL: Time = 60; pub const STATS_INTERVAL: Time = 60;
#[derive(Clone)] #[derive(Clone)]
pub struct ReconnectEntry { pub struct ReconnectEntry {
address: Option<(String, Time)>, address: Option<(String, Time)>,
@ -48,7 +47,7 @@ pub struct ReconnectEntry {
tries: u16, tries: u16,
timeout: u16, timeout: u16,
next: Time, next: Time,
final_timeout: Option<Time> final_timeout: Option<Time>,
} }
pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> { pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
@ -62,7 +61,7 @@ pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
_dummy_ts: PhantomData<TS>, _dummy_ts: PhantomData<TS>,
_dummy_p: PhantomData<P>, _dummy_p: PhantomData<P>,
// Socket-only fields // Socket-only fields
socket: S, pub socket: S,
device: D, device: D,
next_housekeep: Time, next_housekeep: Time,
own_addresses: AddrList, own_addresses: AddrList,
@ -77,18 +76,20 @@ pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
stats_file: Option<File>, stats_file: Option<File>,
statsd_server: Option<String>, statsd_server: Option<String>,
reconnect_peers: SmallVec<[ReconnectEntry; 3]>, reconnect_peers: SmallVec<[ReconnectEntry; 3]>,
buffer: MsgBuffer,
broadcast_buffer: MsgBuffer,
// Shared fields // Shared fields
peer_crypto: SharedPeerCrypto, peer_crypto: SharedPeerCrypto,
traffic: SharedTraffic, traffic: SharedTraffic,
table: SharedTable<TS>, table: SharedTable<TS>,
// Should not be here // Should not be here
port_forwarding: Option<PortForwarding> // TODO: 3rd thread port_forwarding: Option<PortForwarding>, // TODO: 3rd thread
} }
impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS> { impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS> {
pub fn new( pub fn new(
config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto, config: Config, device: D, socket: S, traffic: SharedTraffic, peer_crypto: SharedPeerCrypto,
table: SharedTable<TS>, port_forwarding: Option<PortForwarding>, stats_file: Option<File> table: SharedTable<TS>, port_forwarding: Option<PortForwarding>, stats_file: Option<File>,
) -> Self { ) -> Self {
let mut claims = SmallVec::with_capacity(config.claims.len()); let mut claims = SmallVec::with_capacity(config.claims.len());
for s in &config.claims { for s in &config.claims {
@ -104,7 +105,7 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Err(Error::DeviceIo(_, e)) if e.kind() == io::ErrorKind::AddrNotAvailable => { Err(Error::DeviceIo(_, e)) if e.kind() == io::ErrorKind::AddrNotAvailable => {
info!("No address set on interface.") info!("No address set on interface.")
} }
Err(e) => error!("{}", e) Err(e) => error!("{}", e),
} }
} }
let now = TS::now(); let now = TS::now();
@ -139,72 +140,73 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
update_freq, update_freq,
statsd_server: config.statsd_server.clone(), statsd_server: config.statsd_server.clone(),
crypto: Crypto::new(node_id, &config.crypto).unwrap(), crypto: Crypto::new(node_id, &config.crypto).unwrap(),
config config,
buffer: MsgBuffer::new(SPACE_BEFORE),
broadcast_buffer: MsgBuffer::new(SPACE_BEFORE),
} }
} }
#[inline] #[inline]
fn send_to(&mut self, addr: SocketAddr, msg: &MsgBuffer) -> Result<(), Error> { async fn send_to(&mut self, addr: SocketAddr) -> Result<(), Error> {
debug!("Sending msg with {} bytes to {}", msg.len(), addr); let size = self.buffer.len();
self.traffic.count_out_traffic(addr, msg.len()); debug!("Sending msg with {} bytes to {}", size, addr);
match self.socket.send(msg.message(), addr) { self.traffic.count_out_traffic(addr, size);
Ok(written) if written == msg.len() => Ok(()), match self.socket.send(self.buffer.message(), addr).await {
Ok(written) if written == size => Ok(()),
Ok(_) => Err(Error::Socket("Sent out truncated packet")), Ok(_) => Err(Error::Socket("Sent out truncated packet")),
Err(e) => Err(Error::SocketIo("IOError when sending", e)) Err(e) => Err(Error::SocketIo("IOError when sending", e)),
} }
} }
#[inline] #[inline]
fn broadcast_msg(&mut self, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { async fn broadcast_msg(&mut self, type_: u8) -> Result<(), Error> {
debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peers.len()); debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, self.buffer.len(), self.peers.len());
let mut msg_data = MsgBuffer::new(100);
for (addr, peer) in &mut self.peers { for (addr, peer) in &mut self.peers {
msg_data.set_start(msg.get_start()); self.broadcast_buffer.set_start(self.buffer.get_start());
msg_data.set_length(msg.len()); self.broadcast_buffer.set_length(self.buffer.len());
msg_data.message_mut().clone_from_slice(msg.message()); self.broadcast_buffer.message_mut().clone_from_slice(self.buffer.message());
msg_data.prepend_byte(type_); self.broadcast_buffer.prepend_byte(type_);
peer.crypto.encrypt_message(&mut msg_data); peer.crypto.encrypt_message(&mut self.broadcast_buffer);
self.traffic.count_out_traffic(*addr, msg_data.len()); self.traffic.count_out_traffic(*addr, self.broadcast_buffer.len());
match self.socket.send(msg_data.message(), *addr) { match self.socket.send(self.broadcast_buffer.message(), *addr).await {
Ok(written) if written == msg_data.len() => Ok(()), Ok(written) if written == self.broadcast_buffer.len() => Ok(()),
Ok(_) => Err(Error::Socket("Sent out truncated packet")), Ok(_) => Err(Error::Socket("Sent out truncated packet")),
Err(e) => Err(Error::SocketIo("IOError when sending", e)) Err(e) => Err(Error::SocketIo("IOError when sending", e)),
}? }?
} }
Ok(()) Ok(())
} }
fn connect_sock(&mut self, addr: SocketAddr) -> Result<(), Error> { async fn connect_sock(&mut self, addr: SocketAddr) -> Result<(), Error> {
let addr = mapped_addr(addr); let addr = mapped_addr(addr);
if self.peers.contains_key(&addr) if self.peers.contains_key(&addr)
|| self.own_addresses.contains(&addr) || self.own_addresses.contains(&addr)
|| self.pending_inits.contains_key(&addr) || self.pending_inits.contains_key(&addr)
{ {
return Ok(()) return Ok(());
} }
debug!("Connecting to {:?}", addr); debug!("Connecting to {:?}", addr);
let payload = self.create_node_info(); let payload = self.create_node_info();
let mut init = self.crypto.peer_instance(payload); let mut init = self.crypto.peer_instance(payload);
let mut msg = MsgBuffer::new(SPACE_BEFORE); init.send_ping(&mut self.buffer);
init.send_ping(&mut msg);
self.pending_inits.insert(addr, init); self.pending_inits.insert(addr, init);
self.send_to(addr, &msg) self.send_to(addr).await
} }
pub fn connect<Addr: ToSocketAddrs + fmt::Debug + Clone>(&mut self, addr: Addr) -> Result<(), Error> { pub async fn connect<Addr: ToSocketAddrs + fmt::Debug + Clone>(&mut self, addr: Addr) -> Result<(), Error> {
let addrs = resolve(&addr)?.into_iter().map(mapped_addr).collect::<SmallVec<[SocketAddr; 3]>>(); let addrs = resolve(&addr)?.into_iter().map(mapped_addr).collect::<SmallVec<[SocketAddr; 3]>>();
for addr in &addrs { for addr in &addrs {
if self.own_addresses.contains(addr) if self.own_addresses.contains(addr)
|| self.peers.contains_key(addr) || self.peers.contains_key(addr)
|| self.pending_inits.contains_key(addr) || self.pending_inits.contains_key(addr)
{ {
return Ok(()) return Ok(());
} }
} }
// Send a message to each resolved address // Send a message to each resolved address
for a in addrs { for a in addrs {
// Ignore error this time // Ignore error this time
self.connect_sock(a).ok(); self.connect_sock(a).await.ok();
} }
Ok(()) Ok(())
} }
@ -224,43 +226,46 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
peers, peers,
claims: self.claims.clone(), claims: self.claims.clone(),
peer_timeout: Some(self.peer_timeout_publish), peer_timeout: Some(self.peer_timeout_publish),
addrs: self.own_addresses.clone() addrs: self.own_addresses.clone(),
} }
} }
fn update_peer_info(&mut self, addr: SocketAddr, info: Option<NodeInfo>) -> Result<(), Error> { async fn update_peer_info(&mut self, addr: SocketAddr, info: Option<NodeInfo>) -> Result<(), Error> {
if let Some(peer) = self.peers.get_mut(&addr) { if let Some(peer) = self.peers.get_mut(&addr) {
peer.last_seen = TS::now(); peer.last_seen = TS::now();
peer.timeout = TS::now() + self.config.peer_timeout as Time peer.timeout = TS::now() + self.config.peer_timeout as Time
} else { } else {
error!("Received peer update from non peer {}", addr_nice(addr)); error!("Received peer update from non peer {}", addr_nice(addr));
return Ok(()) return Ok(());
} }
if let Some(info) = info { if let Some(info) = info {
debug!("Adding claims of peer {}: {:?}", addr_nice(addr), info.claims); debug!("Adding claims of peer {}: {:?}", addr_nice(addr), info.claims);
self.table.set_claims(addr, info.claims); self.table.set_claims(addr, info.claims);
debug!("Received {} peers from {}: {:?}", info.peers.len(), addr_nice(addr), info.peers); debug!("Received {} peers from {}: {:?}", info.peers.len(), addr_nice(addr), info.peers);
self.connect_to_peers(&info.peers)?; self.connect_to_peers(&info.peers).await?;
} }
Ok(()) Ok(())
} }
fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo, msg: &mut MsgBuffer) -> Result<(), Error> { async fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo) -> Result<(), Error> {
info!("Added peer {}", addr_nice(addr)); info!("Added peer {}", addr_nice(addr));
if let Some(init) = self.pending_inits.remove(&addr) { if let Some(init) = self.pending_inits.remove(&addr) {
msg.clear(); self.buffer.clear();
let crypto = init.finish(msg); let crypto = init.finish(&mut self.buffer);
self.peers.insert(addr, PeerData { self.peers.insert(
addr,
PeerData {
addrs: info.addrs.clone(), addrs: info.addrs.clone(),
crypto, crypto,
node_id: info.node_id, node_id: info.node_id,
peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT), peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT),
last_seen: TS::now(), last_seen: TS::now(),
timeout: TS::now() + self.config.peer_timeout as Time timeout: TS::now() + self.config.peer_timeout as Time,
}); },
self.update_peer_info(addr, Some(info))?; );
if !msg.is_empty() { self.update_peer_info(addr, Some(info)).await?;
self.send_to(addr, msg)?; if !self.buffer.is_empty() {
self.send_to(addr).await?;
} }
} else { } else {
error!("No init for new peer {}", addr_nice(addr)); error!("No init for new peer {}", addr_nice(addr));
@ -268,24 +273,24 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(()) Ok(())
} }
fn connect_to_peers(&mut self, peers: &[PeerInfo]) -> Result<(), Error> { async fn connect_to_peers(&mut self, peers: &[PeerInfo]) -> Result<(), Error> {
'outer: for peer in peers { 'outer: for peer in peers {
for addr in &peer.addrs { for addr in &peer.addrs {
if self.peers.contains_key(addr) { if self.peers.contains_key(addr) {
continue 'outer continue 'outer;
} }
} }
if let Some(node_id) = peer.node_id { if let Some(node_id) = peer.node_id {
if self.node_id == node_id { if self.node_id == node_id {
continue 'outer continue 'outer;
} }
for p in self.peers.values() { for p in self.peers.values() {
if p.node_id == node_id { if p.node_id == node_id {
continue 'outer continue 'outer;
} }
} }
} }
self.connect(&peer.addrs as &[SocketAddr])?; self.connect(&peer.addrs as &[SocketAddr]).await?;
} }
Ok(()) Ok(())
} }
@ -297,14 +302,14 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
} }
fn handle_payload_from(&mut self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { async fn handle_payload_from(&mut self, peer: SocketAddr) -> Result<(), Error> {
let (src, dst) = P::parse(data.message())?; let (src, dst) = P::parse(self.buffer.message())?;
let len = data.len(); let len = self.buffer.len();
debug!("Writing data to device: {} bytes", len); debug!("Writing data to device: {} bytes", len);
self.traffic.count_in_payload(src, dst, len); self.traffic.count_in_payload(src, dst, len);
if let Err(e) = self.device.write(data) { if let Err(e) = self.device.write(&mut self.buffer).await {
error!("Failed to send via device: {}", e); error!("Failed to send via device: {}", e);
return Err(e) return Err(e);
} }
if self.learning { if self.learning {
// Learn single address // Learn single address
@ -313,85 +318,81 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(()) Ok(())
} }
fn process_message( async fn process_message(&mut self, src: SocketAddr, msg_result: MessageResult) -> Result<(), Error> {
&mut self, src: SocketAddr, msg_result: MessageResult, data: &mut MsgBuffer
) -> Result<(), Error> {
match msg_result { match msg_result {
MessageResult::Message(type_) => { MessageResult::Message(type_) => match type_ {
match type_ { MESSAGE_TYPE_DATA => self.handle_payload_from(src).await?,
MESSAGE_TYPE_DATA => self.handle_payload_from(src, data)?,
MESSAGE_TYPE_NODE_INFO => { MESSAGE_TYPE_NODE_INFO => {
let info = match NodeInfo::decode(Cursor::new(data.message())) { let info = match NodeInfo::decode(Cursor::new(self.buffer.message())) {
Ok(val) => val, Ok(val) => val,
Err(err) => { Err(err) => {
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(self.buffer.len());
return Err(err) return Err(err);
} }
}; };
self.update_peer_info(src, Some(info))? self.update_peer_info(src, Some(info)).await?
} }
MESSAGE_TYPE_KEEPALIVE => self.update_peer_info(src, None)?, MESSAGE_TYPE_KEEPALIVE => self.update_peer_info(src, None).await?,
MESSAGE_TYPE_CLOSE => self.remove_peer(src), MESSAGE_TYPE_CLOSE => self.remove_peer(src),
_ => { _ => {
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(self.buffer.len());
return Err(Error::Message("Unknown message type")) return Err(Error::Message("Unknown message type"));
} }
} },
} MessageResult::Reply => self.send_to(src).await?,
MessageResult::Reply => self.send_to(src, data)?, MessageResult::None => (),
MessageResult::None => ()
} }
Ok(()) Ok(())
} }
fn handle_message(&mut self, src: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { async fn handle_message(&mut self, src: SocketAddr) -> Result<(), Error> {
let src = mapped_addr(src); let src = mapped_addr(src);
debug!("Received {} bytes from {}", data.len(), src); debug!("Received {} bytes from {}", self.buffer.len(), src);
if let Some(result) = self.peers.get_mut(&src).map(|peer| peer.crypto.handle_message(data)) { let buffer = &mut self.buffer;
return self.process_message(src, result?, data) if let Some(result) = self.peers.get_mut(&src).map(|peer| peer.crypto.handle_message(buffer)) {
return self.process_message(src, result?).await;
} }
let is_init = is_init_message(data.message()); let is_init = is_init_message(buffer.message());
if let Some(result) = self.pending_inits.get_mut(&src).map(|init| { if let Some(result) = self.pending_inits.get_mut(&src).map(|init| {
if is_init { if is_init {
init.handle_init(data) init.handle_init(buffer)
} else { } else {
data.clear(); buffer.clear();
init.repeat_last_message(data); init.repeat_last_message(buffer);
Ok(InitResult::Continue) Ok(InitResult::Continue)
} }
}) { }) {
match result? { match result? {
InitResult::Continue => { InitResult::Continue => {
if !data.is_empty() { if !buffer.is_empty() {
self.send_to(src, data)? self.send_to(src).await?
} }
} }
InitResult::Success { peer_payload, .. } => self.add_new_peer(src, peer_payload, data)? InitResult::Success { peer_payload, .. } => self.add_new_peer(src, peer_payload).await?,
} }
return Ok(()) return Ok(());
} }
if !is_init_message(data.message()) { if !is_init_message(self.buffer.message()) {
info!("Ignoring non-init message from unknown peer {}", addr_nice(src)); info!("Ignoring non-init message from unknown peer {}", addr_nice(src));
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(self.buffer.len());
return Ok(()) return Ok(());
} }
let mut init = self.crypto.peer_instance(self.create_node_info()); let mut init = self.crypto.peer_instance(self.create_node_info());
let msg_result = init.handle_init(data); let msg_result = init.handle_init(&mut self.buffer);
match msg_result { match msg_result {
Ok(_) => { Ok(_) => {
self.pending_inits.insert(src, init); self.pending_inits.insert(src, init);
self.send_to(src, data) self.send_to(src).await
} }
Err(err) => { Err(err) => {
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(self.buffer.len());
Err(err) Err(err)
} }
} }
} }
fn housekeep(&mut self) -> Result<(), Error> { pub async fn housekeep(&mut self) -> Result<(), Error> {
let now = TS::now(); let now = TS::now();
let mut buffer = MsgBuffer::new(SPACE_BEFORE);
let mut del: SmallVec<[SocketAddr; 3]> = SmallVec::new(); let mut del: SmallVec<[SocketAddr; 3]> = SmallVec::new();
for (&addr, ref data) in &self.peers { for (&addr, ref data) in &self.peers {
if data.timeout < now { if data.timeout < now {
@ -402,10 +403,10 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
info!("Forgot peer {} due to timeout", addr_nice(addr)); info!("Forgot peer {} due to timeout", addr_nice(addr));
self.peers.remove(&addr); self.peers.remove(&addr);
self.table.remove_claims(addr); self.table.remove_claims(addr);
self.connect_sock(addr)?; // Try to reconnect self.connect_sock(addr).await?; // Try to reconnect
} }
self.table.housekeep(); self.table.housekeep();
self.crypto_housekeep()?; self.crypto_housekeep().await?;
// Periodically extend the port-forwarding // Periodically extend the port-forwarding
if let Some(ref mut pfw) = self.port_forwarding { if let Some(ref mut pfw) = self.port_forwarding {
pfw.check_extend(); pfw.check_extend();
@ -413,37 +414,37 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
let now = TS::now(); let now = TS::now();
// Periodically reset own peers // Periodically reset own peers
if self.next_own_address_reset <= now { if self.next_own_address_reset <= now {
self.reset_own_addresses().map_err(|err| Error::SocketIo("Failed to get own addresses", err))?; self.reset_own_addresses().await.map_err(|err| Error::SocketIo("Failed to get own addresses", err))?;
self.next_own_address_reset = now + OWN_ADDRESS_RESET_INTERVAL; self.next_own_address_reset = now + OWN_ADDRESS_RESET_INTERVAL;
} }
// Periodically send peer list to peers // Periodically send peer list to peers
if self.next_peers <= now { if self.next_peers <= now {
debug!("Send peer list to all peers"); debug!("Send peer list to all peers");
let info = self.create_node_info(); let info = self.create_node_info();
info.encode(&mut buffer); info.encode(&mut self.buffer);
self.broadcast_msg(MESSAGE_TYPE_NODE_INFO, &mut buffer)?; self.broadcast_msg(MESSAGE_TYPE_NODE_INFO).await?;
// Reschedule for next update // Reschedule for next update
let min_peer_timeout = self.peers.iter().map(|p| p.1.peer_timeout).min().unwrap_or(DEFAULT_PEER_TIMEOUT); let min_peer_timeout = self.peers.iter().map(|p| p.1.peer_timeout).min().unwrap_or(DEFAULT_PEER_TIMEOUT);
let interval = min(self.update_freq as u16, max(min_peer_timeout / 2 - 60, 1)); let interval = min(self.update_freq as u16, max(min_peer_timeout / 2 - 60, 1));
self.next_peers = now + Time::from(interval); self.next_peers = now + Time::from(interval);
} }
self.reconnect_to_peers()?; self.reconnect_to_peers().await?;
if self.next_stats_out < now { if self.next_stats_out < now {
// Write out the statistics // Write out the statistics
self.write_out_stats().map_err(|err| Error::FileIo("Failed to write stats file", err))?; self.write_out_stats().map_err(|err| Error::FileIo("Failed to write stats file", err))?;
self.send_stats_to_statsd()?; self.send_stats_to_statsd().await?;
self.next_stats_out = now + STATS_INTERVAL; self.next_stats_out = now + STATS_INTERVAL;
self.traffic.period(Some(5)); self.traffic.period(Some(5));
} }
if let Some(peers) = self.beacon_serializer.get_cmd_results() { if let Some(peers) = self.beacon_serializer.get_cmd_results() {
debug!("Loaded beacon with peers: {:?}", peers); debug!("Loaded beacon with peers: {:?}", peers);
for peer in peers { for peer in peers {
self.connect_sock(peer)?; self.connect_sock(peer).await?;
} }
} }
if self.next_beacon < now { if self.next_beacon < now {
self.store_beacon()?; self.store_beacon()?;
self.load_beacon()?; self.load_beacon().await?;
self.next_beacon = now + Time::from(self.config.beacon_interval); self.next_beacon = now + Time::from(self.config.beacon_interval);
} }
// TODO: sync peer_crypto // TODO: sync peer_crypto
@ -452,36 +453,35 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
unimplemented!(); unimplemented!();
} }
fn crypto_housekeep(&mut self) -> Result<(), Error> { async fn crypto_housekeep(&mut self) -> Result<(), Error> {
let mut msg = MsgBuffer::new(SPACE_BEFORE);
let mut del: SmallVec<[SocketAddr; 4]> = smallvec![]; let mut del: SmallVec<[SocketAddr; 4]> = smallvec![];
for addr in self.pending_inits.keys().copied().collect::<SmallVec<[SocketAddr; 4]>>() { for addr in self.pending_inits.keys().copied().collect::<SmallVec<[SocketAddr; 4]>>() {
msg.clear(); self.buffer.clear();
if self.pending_inits.get_mut(&addr).unwrap().every_second(&mut msg).is_err() { if self.pending_inits.get_mut(&addr).unwrap().every_second(&mut self.buffer).is_err() {
del.push(addr) del.push(addr)
} else if !msg.is_empty() { } else if !self.buffer.is_empty() {
self.send_to(addr, &msg)? self.send_to(addr).await?
} }
} }
for addr in self.peers.keys().copied().collect::<SmallVec<[SocketAddr; 16]>>() { for addr in self.peers.keys().copied().collect::<SmallVec<[SocketAddr; 16]>>() {
msg.clear(); self.buffer.clear();
self.peers.get_mut(&addr).unwrap().crypto.every_second(&mut msg); self.peers.get_mut(&addr).unwrap().crypto.every_second(&mut self.buffer);
if !msg.is_empty() { if !self.buffer.is_empty() {
self.send_to(addr, &msg)? self.send_to(addr).await?
} }
} }
for addr in del { for addr in del {
self.pending_inits.remove(&addr); self.pending_inits.remove(&addr);
if self.peers.remove(&addr).is_some() { if self.peers.remove(&addr).is_some() {
self.connect_sock(addr)?; self.connect_sock(addr).await?;
} }
} }
Ok(()) Ok(())
} }
fn reset_own_addresses(&mut self) -> io::Result<()> { async fn reset_own_addresses(&mut self) -> io::Result<()> {
self.own_addresses.clear(); self.own_addresses.clear();
self.own_addresses.push(self.socket.address().map(mapped_addr)?); self.own_addresses.push(self.socket.address().await.map(mapped_addr)?);
if let Some(ref pfw) = self.port_forwarding { if let Some(ref pfw) = self.port_forwarding {
self.own_addresses.push(pfw.get_internal_ip().into()); self.own_addresses.push(pfw.get_internal_ip().into());
self.own_addresses.push(pfw.get_external_ip().into()); self.own_addresses.push(pfw.get_external_ip().into());
@ -510,14 +510,14 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
/// Loads the beacon /// Loads the beacon
fn load_beacon(&mut self) -> Result<(), Error> { async fn load_beacon(&mut self) -> Result<(), Error> {
let peers; let peers;
if let Some(ref path) = self.config.beacon_load { if let Some(ref path) = self.config.beacon_load {
if let Some(path) = path.strip_prefix('|') { if let Some(path) = path.strip_prefix('|') {
self.beacon_serializer self.beacon_serializer
.read_from_cmd(path, Some(50)) .read_from_cmd(path, Some(50))
.map_err(|e| Error::BeaconIo("Failed to call beacon command", e))?; .map_err(|e| Error::BeaconIo("Failed to call beacon command", e))?;
return Ok(()) return Ok(());
} else { } else {
peers = self peers = self
.beacon_serializer .beacon_serializer
@ -525,11 +525,11 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
.map_err(|e| Error::BeaconIo("Failed to read beacon from file", e))?; .map_err(|e| Error::BeaconIo("Failed to read beacon from file", e))?;
} }
} else { } else {
return Ok(()) return Ok(());
} }
debug!("Loaded beacon with peers: {:?}", peers); debug!("Loaded beacon with peers: {:?}", peers);
for peer in peers { for peer in peers {
self.connect_sock(peer)?; self.connect_sock(peer).await?;
} }
Ok(()) Ok(())
} }
@ -561,7 +561,7 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
/// Sends the statistics to a statsd endpoint /// Sends the statistics to a statsd endpoint
fn send_stats_to_statsd(&mut self) -> Result<(), Error> { async fn send_stats_to_statsd(&mut self) -> Result<(), Error> {
if let Some(ref endpoint) = self.statsd_server { if let Some(ref endpoint) = self.statsd_server {
let peer_traffic = self.traffic.total_peer_traffic(); let peer_traffic = self.traffic.total_peer_traffic();
let payload_traffic = self.traffic.total_payload_traffic(); let payload_traffic = self.traffic.total_payload_traffic();
@ -607,10 +607,10 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
let msg_data = msg.as_bytes(); let msg_data = msg.as_bytes();
let addrs = resolve(endpoint)?; let addrs = resolve(endpoint)?;
if let Some(addr) = addrs.first() { if let Some(addr) = addrs.first() {
match self.socket.send(msg_data, *addr) { match self.socket.send(msg_data, *addr).await {
Ok(written) if written == msg_data.len() => Ok(()), Ok(written) if written == msg_data.len() => Ok(()),
Ok(_) => Err(Error::Socket("Sent out truncated packet")), Ok(_) => Err(Error::Socket("Sent out truncated packet")),
Err(e) => Err(Error::SocketIo("IOError when sending", e)) Err(e) => Err(Error::SocketIo("IOError when sending", e)),
}? }?
} else { } else {
error!("Failed to resolve statsd server {}", endpoint); error!("Failed to resolve statsd server {}", endpoint);
@ -619,14 +619,14 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(()) Ok(())
} }
fn reconnect_to_peers(&mut self) -> Result<(), Error> { async fn reconnect_to_peers(&mut self) -> Result<(), Error> {
let now = TS::now(); let now = TS::now();
// Connect to those reconnect_peers that are due // Connect to those reconnect_peers that are due
for entry in self.reconnect_peers.clone() { for entry in self.reconnect_peers.clone() {
if entry.next > now { if entry.next > now {
continue continue;
} }
self.connect(&entry.resolved as &[SocketAddr])?; self.connect(&entry.resolved as &[SocketAddr]).await?;
} }
for entry in &mut self.reconnect_peers { for entry in &mut self.reconnect_peers {
// Schedule for next second if node is connected // Schedule for next second if node is connected
@ -635,7 +635,7 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
entry.tries = 0; entry.tries = 0;
entry.timeout = 1; entry.timeout = 1;
entry.next = now + 1; entry.next = now + 1;
continue continue;
} }
} }
// Resolve entries anew // Resolve entries anew
@ -643,19 +643,17 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
if *next_resolve <= now { if *next_resolve <= now {
match resolve(address as &str) { match resolve(address as &str) {
Ok(addrs) => entry.resolved = addrs, Ok(addrs) => entry.resolved = addrs,
Err(_) => { Err(_) => match resolve(&format!("{}:{}", address, DEFAULT_PORT)) {
match resolve(&format!("{}:{}", address, DEFAULT_PORT)) {
Ok(addrs) => entry.resolved = addrs, Ok(addrs) => entry.resolved = addrs,
Err(err) => warn!("Failed to resolve {}: {}", address, err) Err(err) => warn!("Failed to resolve {}: {}", address, err),
} },
}
} }
*next_resolve = now + RESOLVE_INTERVAL; *next_resolve = now + RESOLVE_INTERVAL;
} }
} }
// Ignore if next attempt is already in the future // Ignore if next attempt is already in the future
if entry.next > now { if entry.next > now {
continue continue;
} }
// Exponential back-off: every 10 tries, the interval doubles // Exponential back-off: every 10 tries, the interval doubles
entry.tries += 1; entry.tries += 1;
@ -674,19 +672,10 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(()) Ok(())
} }
pub fn run(mut self) { pub async fn iteration(&mut self) {
let mut buffer = MsgBuffer::new(SPACE_BEFORE); if let Ok(result) = timeout(std::time::Duration::from_millis(1000), self.socket.receive(&mut self.buffer)).await {
loop { let src = try_fail!(result, "Failed to read from network socket: {}");
match self.socket.receive(&mut buffer) { match self.handle_message(src).await {
Err(err) => {
if err.kind() == io::ErrorKind::TimedOut || err.kind() == io::ErrorKind::WouldBlock {
// ok, this is a normal timeout
} else {
fail!("Failed to read from network socket: {}", err);
}
}
Ok(src) => {
match self.handle_message(src, &mut buffer) {
Err(e @ Error::CryptoInitFatal(_)) => { Err(e @ Error::CryptoInitFatal(_)) => {
debug!("Fatal crypto init error from {}: {}", src, e); debug!("Fatal crypto init error from {}: {}", src, e);
info!("Closing pending connection to {} due to error in crypto init", addr_nice(src)); info!("Closing pending connection to {} due to error in crypto init", addr_nice(src));
@ -702,14 +691,18 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(_) => {} Ok(_) => {}
} }
} }
}
let now = TS::now(); let now = TS::now();
if self.next_housekeep < now { if self.next_housekeep < now {
if let Err(e) = self.housekeep() { if let Err(e) = self.housekeep().await {
error!("{}", e) error!("{}", e)
} }
self.next_housekeep = now + 1 self.next_housekeep = now + 1
} }
} }
pub async fn run(mut self) {
loop {
self.iteration().await
}
} }
} }

View File

@ -4,6 +4,7 @@
#[macro_use] extern crate log; #[macro_use] extern crate log;
#[macro_use] extern crate serde; #[macro_use] extern crate serde;
#[macro_use] extern crate tokio;
#[cfg(test)] extern crate tempfile; #[cfg(test)] extern crate tempfile;
@ -134,16 +135,16 @@ fn parse_ip_netmask(addr: &str) -> Result<(Ipv4Addr, Ipv4Addr), String> {
Ok((ip, netmask)) Ok((ip, netmask))
} }
fn setup_device(config: &Config) -> TunTapDevice { async fn setup_device(config: &Config) -> TunTapDevice {
let device = try_fail!( let device = try_fail!(
TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)), TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)).await,
"Failed to open virtual {} interface {}: {}", "Failed to open virtual {} interface {}: {}",
config.device_type, config.device_type,
config.device_name config.device_name
); );
info!("Opened device {}", device.ifname()); info!("Opened device {}", device.ifname());
config.call_hook("device_setup", vec![("IFNAME", device.ifname())], true); config.call_hook("device_setup", vec![("IFNAME", device.ifname())], true);
if let Err(err) = device.set_mtu(None) { if let Err(err) = device.set_mtu(None).await {
error!("Error setting optimal MTU on {}: {}", device.ifname(), err); error!("Error setting optimal MTU on {}: {}", device.ifname(), err);
} }
if let Some(ip) = &config.ip { if let Some(ip) = &config.ip {
@ -155,9 +156,9 @@ fn setup_device(config: &Config) -> TunTapDevice {
run_script(script, device.ifname()); run_script(script, device.ifname());
} }
if config.fix_rp_filter { if config.fix_rp_filter {
try_fail!(device.fix_rp_filter(), "Failed to change rp_filter settings: {}"); try_fail!(device.fix_rp_filter().await, "Failed to change rp_filter settings: {}");
} }
if let Ok(val) = device.get_rp_filter() { if let Ok(val) = device.get_rp_filter().await {
if val != 1 { if val != 1 {
warn!("Your networking configuration might be affected by a vulnerability (https://vpncloud.ddswd.de/docs/security/cve-2019-14899/), please change your rp_filter setting to 1 (currently {}).", val); warn!("Your networking configuration might be affected by a vulnerability (https://vpncloud.ddswd.de/docs/security/cve-2019-14899/), please change your rp_filter setting to 1 (currently {}).", val);
} }
@ -167,9 +168,9 @@ fn setup_device(config: &Config) -> TunTapDevice {
} }
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
fn run<P: Protocol, S: Socket>(config: Config, socket: S) { async fn run<P: Protocol, S: Socket>(config: Config, socket: S) {
let device = setup_device(&config); let device = setup_device(&config).await;
let port_forwarding = if config.port_forwarding { socket.create_port_forwarding() } else { None }; let port_forwarding = if config.port_forwarding { socket.create_port_forwarding().await } else { None };
let stats_file = match config.stats_file { let stats_file = match config.stats_file {
None => None, None => None,
Some(ref name) => { Some(ref name) => {
@ -187,7 +188,7 @@ fn run<P: Protocol, S: Socket>(config: Config, socket: S) {
}; };
let ifname = device.ifname().to_string(); let ifname = device.ifname().to_string();
let mut cloud = let mut cloud =
GenericCloud::<TunTapDevice, P, S, SystemTimeSource>::new(&config, socket, device, port_forwarding, stats_file); try_fail!(GenericCloud::<TunTapDevice, P, S, SystemTimeSource>::new(&config, socket, device, port_forwarding, stats_file).await, "Failed to create engine: {}");
for mut addr in config.peers { for mut addr in config.peers {
if addr.find(':').unwrap_or(0) <= addr.find(']').unwrap_or(0) { if addr.find(':').unwrap_or(0) <= addr.find(']').unwrap_or(0) {
// : not present or only in IPv6 address // : not present or only in IPv6 address
@ -221,13 +222,14 @@ fn run<P: Protocol, S: Socket>(config: Config, socket: S) {
} }
try_fail!(pd.apply(), "Failed to drop privileges: {}"); try_fail!(pd.apply(), "Failed to drop privileges: {}");
} }
cloud.run(); cloud.run().await;
if let Some(script) = config.ifdown { if let Some(script) = config.ifdown {
run_script(&script, &ifname); run_script(&script, &ifname);
} }
} }
fn main() { #[tokio::main]
async fn main() {
let args: Args = Args::from_args(); let args: Args = Args::from_args();
if args.version { if args.version {
println!("VpnCloud v{}", env!("CARGO_PKG_VERSION")); println!("VpnCloud v{}", env!("CARGO_PKG_VERSION"));
@ -321,16 +323,16 @@ fn main() {
} }
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
if config.listen.starts_with("ws://") { if config.listen.starts_with("ws://") {
let socket = try_fail!(ProxyConnection::listen(&config.listen), "Failed to open socket {}: {}", config.listen); let socket = try_fail!(ProxyConnection::listen(&config.listen).await, "Failed to open socket {}: {}", config.listen);
match config.device_type { match config.device_type {
Type::Tap => run::<payload::Frame, _>(config, socket), Type::Tap => run::<payload::Frame, _>(config, socket).await,
Type::Tun => run::<payload::Packet, _>(config, socket) Type::Tun => run::<payload::Packet, _>(config, socket).await
} }
return return
} }
let socket = try_fail!(NetSocket::listen(&config.listen), "Failed to open socket {}: {}", config.listen); let socket = try_fail!(NetSocket::listen(&config.listen).await, "Failed to open socket {}: {}", config.listen);
match config.device_type { match config.device_type {
Type::Tap => run::<payload::Frame, _>(config, socket), Type::Tap => run::<payload::Frame, _>(config, socket).await,
Type::Tun => run::<payload::Packet, _>(config, socket) Type::Tun => run::<payload::Packet, _>(config, socket).await
} }
} }

View File

@ -4,24 +4,24 @@
use super::util::{MockTimeSource, MsgBuffer, Time, TimeSource}; use super::util::{MockTimeSource, MsgBuffer, Time, TimeSource};
use crate::port_forwarding::PortForwarding; use crate::port_forwarding::PortForwarding;
use async_trait::async_trait;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{ use std::{
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
io::{self, ErrorKind}, io::{self, ErrorKind},
net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket}, net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket},
os::unix::io::{AsRawFd, RawFd}, os::unix::io::AsRawFd,
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc Arc,
}, },
time::Duration
}; };
pub fn mapped_addr(addr: SocketAddr) -> SocketAddr { pub fn mapped_addr(addr: SocketAddr) -> SocketAddr {
// HOT PATH // HOT PATH
match addr { match addr {
SocketAddr::V4(addr4) => SocketAddr::new(IpAddr::V6(addr4.ip().to_ipv6_mapped()), addr4.port()), SocketAddr::V4(addr4) => SocketAddr::new(IpAddr::V6(addr4.ip().to_ipv6_mapped()), addr4.port()),
_ => addr _ => addr,
} }
} }
@ -31,12 +31,13 @@ pub fn get_ip() -> IpAddr {
s.local_addr().unwrap().ip() s.local_addr().unwrap().ip()
} }
pub trait Socket: AsRawFd + Sized + Clone + Send + 'static { #[async_trait]
fn listen(addr: &str) -> Result<Self, io::Error>; pub trait Socket: Sized + Clone + Send + Sync + 'static {
fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error>; async fn listen(addr: &str) -> Result<Self, io::Error>;
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>; async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error>;
fn address(&self) -> Result<SocketAddr, io::Error>; async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>;
fn create_port_forwarding(&self) -> Option<PortForwarding>; async fn address(&self) -> Result<SocketAddr, io::Error>;
async fn create_port_forwarding(&self) -> Option<PortForwarding>;
} }
pub fn parse_listen(addr: &str) -> SocketAddr { pub fn parse_listen(addr: &str) -> SocketAddr {
@ -59,40 +60,33 @@ impl Clone for NetSocket {
} }
} }
impl AsRawFd for NetSocket {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}
#[async_trait]
impl Socket for NetSocket { impl Socket for NetSocket {
fn listen(addr: &str) -> Result<Self, io::Error> { async fn listen(addr: &str) -> Result<Self, io::Error> {
let addr = parse_listen(addr); let addr = parse_listen(addr);
Ok(NetSocket(UdpSocket::bind(addr).and_then(|s| { Ok(NetSocket(UdpSocket::bind(addr)?))
s.set_read_timeout(Some(Duration::from_secs(1)))?;
Ok(s)
})?))
} }
fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
buffer.clear(); buffer.clear();
let (size, addr) = self.0.recv_from(buffer.buffer())?; let (size, addr) = self.0.recv_from(buffer.buffer())?;
buffer.set_length(size); buffer.set_length(size);
Ok(addr) Ok(addr)
} }
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> { async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> {
self.0.send_to(data, addr) self.0.send_to(data, addr)
} }
fn address(&self) -> Result<SocketAddr, io::Error> { async fn address(&self) -> Result<SocketAddr, io::Error> {
let mut addr = self.0.local_addr()?; let mut addr = self.0.local_addr()?;
addr.set_ip(get_ip()); addr.set_ip(get_ip());
Ok(addr) Ok(addr)
} }
fn create_port_forwarding(&self) -> Option<PortForwarding> { async fn create_port_forwarding(&self) -> Option<PortForwarding> {
PortForwarding::new(self.address().unwrap().port()) PortForwarding::new(self.address().await.unwrap().port())
} }
} }
@ -100,14 +94,13 @@ thread_local! {
static MOCK_SOCKET_NAT: AtomicBool = AtomicBool::new(false); static MOCK_SOCKET_NAT: AtomicBool = AtomicBool::new(false);
} }
#[derive(Clone)] #[derive(Clone)]
pub struct MockSocket { pub struct MockSocket {
nat: bool, nat: bool,
nat_peers: Arc<Mutex<HashMap<SocketAddr, Time>>>, nat_peers: Arc<Mutex<HashMap<SocketAddr, Time>>>,
address: SocketAddr, address: SocketAddr,
outbound: Arc<Mutex<VecDeque<(SocketAddr, Vec<u8>)>>>, outbound: Arc<Mutex<VecDeque<(SocketAddr, Vec<u8>)>>>,
inbound: Arc<Mutex<VecDeque<(SocketAddr, Vec<u8>)>>> inbound: Arc<Mutex<VecDeque<(SocketAddr, Vec<u8>)>>>,
} }
impl MockSocket { impl MockSocket {
@ -117,7 +110,7 @@ impl MockSocket {
nat_peers: Default::default(), nat_peers: Default::default(),
address, address,
outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))), outbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))),
inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))) inbound: Arc::new(Mutex::new(VecDeque::with_capacity(10))),
} }
} }
@ -132,12 +125,12 @@ impl MockSocket {
pub fn put_inbound(&mut self, from: SocketAddr, data: Vec<u8>) -> bool { pub fn put_inbound(&mut self, from: SocketAddr, data: Vec<u8>) -> bool {
if !self.nat { if !self.nat {
self.inbound.lock().push_back((from, data)); self.inbound.lock().push_back((from, data));
return true return true;
} }
if let Some(timeout) = self.nat_peers.lock().get(&from) { if let Some(timeout) = self.nat_peers.lock().get(&from) {
if *timeout >= MockTimeSource::now() { if *timeout >= MockTimeSource::now() {
self.inbound.lock().push_back((from, data)); self.inbound.lock().push_back((from, data));
return true return true;
} }
} }
warn!("Sender {:?} is filtered out by NAT", from); warn!("Sender {:?} is filtered out by NAT", from);
@ -149,18 +142,14 @@ impl MockSocket {
} }
} }
impl AsRawFd for MockSocket {
fn as_raw_fd(&self) -> RawFd {
unimplemented!()
}
}
#[async_trait]
impl Socket for MockSocket { impl Socket for MockSocket {
fn listen(addr: &str) -> Result<Self, io::Error> { async fn listen(addr: &str) -> Result<Self, io::Error> {
Ok(Self::new(parse_listen(addr))) Ok(Self::new(parse_listen(addr)))
} }
fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
if let Some((addr, data)) = self.inbound.lock().pop_front() { if let Some((addr, data)) = self.inbound.lock().pop_front() {
buffer.clear(); buffer.clear();
buffer.set_length(data.len()); buffer.set_length(data.len());
@ -171,7 +160,7 @@ impl Socket for MockSocket {
} }
} }
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> { async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> {
self.outbound.lock().push_back((addr, data.into())); self.outbound.lock().push_back((addr, data.into()));
if self.nat { if self.nat {
self.nat_peers.lock().insert(addr, MockTimeSource::now() + 300); self.nat_peers.lock().insert(addr, MockTimeSource::now() + 300);
@ -179,11 +168,11 @@ impl Socket for MockSocket {
Ok(data.len()) Ok(data.len())
} }
fn address(&self) -> Result<SocketAddr, io::Error> { async fn address(&self) -> Result<SocketAddr, io::Error> {
Ok(self.address) Ok(self.address)
} }
fn create_port_forwarding(&self) -> Option<PortForwarding> { async fn create_port_forwarding(&self) -> Option<PortForwarding> {
None None
} }
} }

View File

@ -54,7 +54,7 @@ impl Protocol for Frame {
#[test] #[test]
fn decode_frame_without_vlan() { async fn decode_frame_without_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8]; let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap(); let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 }); assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 });
@ -62,7 +62,7 @@ fn decode_frame_without_vlan() {
} }
#[test] #[test]
fn decode_frame_with_vlan() { async fn decode_frame_with_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0, 4, 210, 1, 2, 3, 4, 5, 6, 7, 8]; let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0, 4, 210, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap(); let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [4, 210, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 }); assert_eq!(src, Address { data: [4, 210, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 });
@ -70,7 +70,7 @@ fn decode_frame_with_vlan() {
} }
#[test] #[test]
fn decode_invalid_frame() { async fn decode_invalid_frame() {
assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8]).is_ok()); assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8]).is_ok());
// truncated frame // truncated frame
assert!(Frame::parse(&[]).is_err()); assert!(Frame::parse(&[]).is_err());
@ -120,7 +120,7 @@ impl Protocol for Packet {
#[test] #[test]
fn decode_ipv4_packet() { async fn decode_ipv4_packet() {
let data = [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2]; let data = [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2];
let (src, dst) = Packet::parse(&data).unwrap(); let (src, dst) = Packet::parse(&data).unwrap();
assert_eq!(src, Address { data: [192, 168, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }); assert_eq!(src, Address { data: [192, 168, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 });
@ -128,7 +128,7 @@ fn decode_ipv4_packet() {
} }
#[test] #[test]
fn decode_ipv6_packet() { async fn decode_ipv6_packet() {
let data = [ let data = [
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5, 0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2, 1 4, 3, 2, 1
@ -139,7 +139,7 @@ fn decode_ipv6_packet() {
} }
#[test] #[test]
fn decode_invalid_packet() { async fn decode_invalid_packet() {
assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2]).is_ok()); assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2]).is_ok());
assert!(Packet::parse(&[ assert!(Packet::parse(&[
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5, 0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,

View File

@ -82,7 +82,7 @@ impl<P: Protocol> Simulator<P> {
Self { next_port: 1, nodes: HashMap::default(), messages: VecDeque::default() } Self { next_port: 1, nodes: HashMap::default(), messages: VecDeque::default() }
} }
pub fn add_node(&mut self, nat: bool, config: &Config) -> SocketAddr { pub async fn add_node(&mut self, nat: bool, config: &Config) -> SocketAddr {
let mut config = config.clone(); let mut config = config.clone();
MockSocket::set_nat(nat); MockSocket::set_nat(nat);
config.listen = format!("[::]:{}", self.next_port); config.listen = format!("[::]:{}", self.next_port);
@ -92,7 +92,7 @@ impl<P: Protocol> Simulator<P> {
} }
DebugLogger::set_node(self.next_port as usize); DebugLogger::set_node(self.next_port as usize);
self.next_port += 1; self.next_port += 1;
let node = TestNode::new(&config, MockSocket::new(addr), MockDevice::new(), None, None); let node = TestNode::new(&config, MockSocket::new(addr), MockDevice::new(), None, None).await.unwrap();
DebugLogger::set_node(0); DebugLogger::set_node(0);
self.nodes.insert(addr, node); self.nodes.insert(addr, node);
addr addr
@ -105,12 +105,12 @@ impl<P: Protocol> Simulator<P> {
node node
} }
pub fn simulate_next_message(&mut self) { pub async fn simulate_next_message(&mut self) {
if let Some((src, dst, data)) = self.messages.pop_front() { if let Some((src, dst, data)) = self.messages.pop_front() {
if let Some(node) = self.nodes.get_mut(&dst) { if let Some(node) = self.nodes.get_mut(&dst) {
if node.socket().put_inbound(src, data) { if node.socket().put_inbound(src, data) {
DebugLogger::set_node(node.get_num()); DebugLogger::set_node(node.get_num());
node.trigger_socket_event(); node.trigger_socket_event().await;
DebugLogger::set_node(0); DebugLogger::set_node(0);
let sock = node.socket(); let sock = node.socket();
let src = dst; let src = dst;
@ -124,16 +124,16 @@ impl<P: Protocol> Simulator<P> {
} }
} }
pub fn simulate_all_messages(&mut self) { pub async fn simulate_all_messages(&mut self) {
while !self.messages.is_empty() { while !self.messages.is_empty() {
self.simulate_next_message() self.simulate_next_message().await
} }
} }
pub fn trigger_node_housekeep(&mut self, addr: SocketAddr) { pub async fn trigger_node_housekeep(&mut self, addr: SocketAddr) {
let node = self.nodes.get_mut(&addr).unwrap(); let node = self.nodes.get_mut(&addr).unwrap();
DebugLogger::set_node(node.get_num()); DebugLogger::set_node(node.get_num());
node.trigger_housekeep(); node.trigger_housekeep().await;
DebugLogger::set_node(0); DebugLogger::set_node(0);
let sock = node.socket(); let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() { while let Some((dst, data)) = sock.pop_outbound() {
@ -141,10 +141,10 @@ impl<P: Protocol> Simulator<P> {
} }
} }
pub fn trigger_housekeep(&mut self) { pub async fn trigger_housekeep(&mut self) {
for (src, node) in &mut self.nodes { for (src, node) in &mut self.nodes {
DebugLogger::set_node(node.get_num()); DebugLogger::set_node(node.get_num());
node.trigger_housekeep(); node.trigger_housekeep().await;
DebugLogger::set_node(0); DebugLogger::set_node(0);
let sock = node.socket(); let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() { while let Some((dst, data)) = sock.pop_outbound() {
@ -157,13 +157,13 @@ impl<P: Protocol> Simulator<P> {
MockTimeSource::set_time(time); MockTimeSource::set_time(time);
} }
pub fn simulate_time(&mut self, time: Time) { pub async fn simulate_time(&mut self, time: Time) {
let mut t = MockTimeSource::now(); let mut t = MockTimeSource::now();
while t < time { while t < time {
t += 1; t += 1;
self.set_time(t); self.set_time(t);
self.trigger_housekeep(); self.trigger_housekeep().await;
self.simulate_all_messages(); self.simulate_all_messages().await;
} }
} }
@ -192,11 +192,11 @@ impl<P: Protocol> Simulator<P> {
self.messages.len() self.messages.len()
} }
pub fn put_payload(&mut self, addr: SocketAddr, data: Vec<u8>) { pub async fn put_payload(&mut self, addr: SocketAddr, data: Vec<u8>) {
let node = self.nodes.get_mut(&addr).unwrap(); let node = self.nodes.get_mut(&addr).unwrap();
node.device().put_inbound(data); node.device().put_inbound(data);
DebugLogger::set_node(node.get_num()); DebugLogger::set_node(node.get_num());
node.trigger_device_event(); node.trigger_device_event().await;
DebugLogger::set_node(0); DebugLogger::set_node(0);
let sock = node.socket(); let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() { while let Some((dst, data)) = sock.pop_outbound() {

View File

@ -5,35 +5,35 @@
use super::common::*; use super::common::*;
#[test] #[test]
fn connect_nat_2_peers() { async fn connect_nat_2_peers() {
let config = Config { port_forwarding: false, ..Default::default() }; let config = Config { port_forwarding: false, ..Default::default() };
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(true, &config); let node1 = sim.add_node(true, &config).await;
let node2 = sim.add_node(true, &config); let node2 = sim.add_node(true, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node2, node1); sim.connect(node2, node1);
sim.simulate_time(60); sim.simulate_time(60).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn connect_nat_3_peers() { async fn connect_nat_3_peers() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(true, &config); let node1 = sim.add_node(true, &config).await;
let node2 = sim.add_node(true, &config); let node2 = sim.add_node(true, &config).await;
let node3 = sim.add_node(true, &config); let node3 = sim.add_node(true, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node2, node1); sim.connect(node2, node1);
sim.connect(node1, node3); sim.connect(node1, node3);
sim.connect(node3, node1); sim.connect(node3, node1);
sim.simulate_time(300); sim.simulate_time(300).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3)); assert!(sim.is_connected(node1, node3));
@ -43,19 +43,19 @@ fn connect_nat_3_peers() {
} }
#[test] #[test]
fn nat_keepalive() { async fn nat_keepalive() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(true, &config); let node1 = sim.add_node(true, &config).await;
let node2 = sim.add_node(true, &config); let node2 = sim.add_node(true, &config).await;
let node3 = sim.add_node(true, &config); let node3 = sim.add_node(true, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node2, node1); sim.connect(node2, node1);
sim.connect(node1, node3); sim.connect(node1, node3);
sim.connect(node3, node1); sim.connect(node3, node1);
sim.simulate_time(1000); sim.simulate_time(1000).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3)); assert!(sim.is_connected(node1, node3));
@ -63,7 +63,7 @@ fn nat_keepalive() {
assert!(sim.is_connected(node2, node3)); assert!(sim.is_connected(node2, node3));
assert!(sim.is_connected(node3, node2)); assert!(sim.is_connected(node3, node2));
sim.simulate_time(10000); sim.simulate_time(10000).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3)); assert!(sim.is_connected(node1, node3));

View File

@ -5,37 +5,37 @@
use super::common::*; use super::common::*;
#[test] #[test]
fn switch_delivers() { async fn switch_delivers() {
let config = Config { device_type: Type::Tap, ..Config::default() }; let config = Config { device_type: Type::Tap, ..Config::default() };
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5]; let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5];
sim.put_payload(node1, payload.clone()); sim.put_payload(node1, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload), sim.pop_payload(node2)); assert_eq!(Some(payload), sim.pop_payload(node2));
} }
#[test] #[test]
fn switch_learns() { async fn switch_learns() {
let config = Config { device_type: Type::Tap, ..Config::default() }; let config = Config { device_type: Type::Tap, ..Config::default() };
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
let node3 = sim.add_node(false, &config); let node3 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node1, node3); sim.connect(node1, node3);
sim.connect(node2, node3); sim.connect(node2, node3);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3)); assert!(sim.is_connected(node1, node3));
@ -47,8 +47,8 @@ fn switch_learns() {
// Nothing learnt so far, node1 broadcasts // Nothing learnt so far, node1 broadcasts
sim.put_payload(node1, payload.clone()); sim.put_payload(node1, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload.clone()), sim.pop_payload(node2)); assert_eq!(Some(payload.clone()), sim.pop_payload(node2));
assert_eq!(Some(payload), sim.pop_payload(node3)); assert_eq!(Some(payload), sim.pop_payload(node3));
@ -57,25 +57,25 @@ fn switch_learns() {
// Node 2 learned the address by receiving it, does not broadcast // Node 2 learned the address by receiving it, does not broadcast
sim.put_payload(node2, payload.clone()); sim.put_payload(node2, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload), sim.pop_payload(node1)); assert_eq!(Some(payload), sim.pop_payload(node1));
assert_eq!(None, sim.pop_payload(node3)); assert_eq!(None, sim.pop_payload(node3));
} }
#[test] #[test]
fn switch_honours_vlans() { async fn switch_honours_vlans() {
let config = Config { device_type: Type::Tap, ..Config::default() }; let config = Config { device_type: Type::Tap, ..Config::default() };
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
let node3 = sim.add_node(false, &config); let node3 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node1, node3); sim.connect(node1, node3);
sim.connect(node2, node3); sim.connect(node2, node3);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3)); assert!(sim.is_connected(node1, node3));
@ -87,8 +87,8 @@ fn switch_honours_vlans() {
// Nothing learnt so far, node1 broadcasts // Nothing learnt so far, node1 broadcasts
sim.put_payload(node1, payload.clone()); sim.put_payload(node1, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload.clone()), sim.pop_payload(node2)); assert_eq!(Some(payload.clone()), sim.pop_payload(node2));
assert_eq!(Some(payload), sim.pop_payload(node3)); assert_eq!(Some(payload), sim.pop_payload(node3));
@ -97,8 +97,8 @@ fn switch_honours_vlans() {
// Node 2 learned the address by receiving it, does not broadcast // Node 2 learned the address by receiving it, does not broadcast
sim.put_payload(node2, payload.clone()); sim.put_payload(node2, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload), sim.pop_payload(node1)); assert_eq!(Some(payload), sim.pop_payload(node1));
assert_eq!(None, sim.pop_payload(node3)); assert_eq!(None, sim.pop_payload(node3));
@ -107,8 +107,8 @@ fn switch_honours_vlans() {
// Different VLANs, node 2 does not learn, still broadcasts // Different VLANs, node 2 does not learn, still broadcasts
sim.put_payload(node2, payload.clone()); sim.put_payload(node2, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload.clone()), sim.pop_payload(node1)); assert_eq!(Some(payload.clone()), sim.pop_payload(node1));
assert_eq!(Some(payload), sim.pop_payload(node3)); assert_eq!(Some(payload), sim.pop_payload(node3));
@ -116,13 +116,14 @@ fn switch_honours_vlans() {
#[test] #[test]
#[ignore] #[ignore]
fn switch_forgets() { async fn switch_forgets() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }
#[test] #[test]
fn router_delivers() { async fn size() {
let future = async {
let config1 = Config { let config1 = Config {
device_type: Type::Tun, device_type: Type::Tun,
auto_claim: false, auto_claim: false,
@ -136,24 +137,27 @@ fn router_delivers() {
..Config::default() ..Config::default()
}; };
let mut sim = TunSimulator::new(); let mut sim = TunSimulator::new();
let node1 = sim.add_node(false, &config1); let node1 = sim.add_node(false, &config1).await;
let node2 = sim.add_node(false, &config2); let node2 = sim.add_node(false, &config2).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
let payload = vec![0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]; let payload = vec![0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
sim.put_payload(node1, payload.clone()); sim.put_payload(node1, payload.clone()).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(Some(payload), sim.pop_payload(node2)); assert_eq!(Some(payload), sim.pop_payload(node2));
};
assert_eq!(std::mem::size_of_val(&future), 100);
future.await
} }
#[test] #[test]
fn router_drops_unknown_dest() { async fn router_delivers() {
let config1 = Config { let config1 = Config {
device_type: Type::Tun, device_type: Type::Tun,
auto_claim: false, auto_claim: false,
@ -167,18 +171,49 @@ fn router_drops_unknown_dest() {
..Config::default() ..Config::default()
}; };
let mut sim = TunSimulator::new(); let mut sim = TunSimulator::new();
let node1 = sim.add_node(false, &config1); let node1 = sim.add_node(false, &config1).await;
let node2 = sim.add_node(false, &config2); let node2 = sim.add_node(false, &config2).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
let payload = vec![0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
sim.put_payload(node1, payload.clone()).await;
sim.simulate_all_messages().await;
assert_eq!(Some(payload), sim.pop_payload(node2));
}
#[test]
async fn router_drops_unknown_dest() {
let config1 = Config {
device_type: Type::Tun,
auto_claim: false,
claims: vec!["1.1.1.1/32".to_string()],
..Config::default()
};
let config2 = Config {
device_type: Type::Tun,
auto_claim: false,
claims: vec!["2.2.2.2/32".to_string()],
..Config::default()
};
let mut sim = TunSimulator::new();
let node1 = sim.add_node(false, &config1).await;
let node2 = sim.add_node(false, &config2).await;
sim.connect(node1, node2);
sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
let payload = vec![0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 3, 3, 3, 3]; let payload = vec![0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 3, 3, 3, 3];
sim.put_payload(node1, payload); sim.put_payload(node1, payload).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert_eq!(None, sim.pop_payload(node2)); assert_eq!(None, sim.pop_payload(node2));
} }

View File

@ -5,47 +5,47 @@
use super::common::*; use super::common::*;
#[test] #[test]
fn direct_connect() { async fn direct_connect() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn direct_connect_unencrypted() { async fn direct_connect_unencrypted() {
let config = Config { let config = Config {
crypto: CryptoConfig { algorithms: vec!["plain".to_string()], ..CryptoConfig::default() }, crypto: CryptoConfig { algorithms: vec!["plain".to_string()], ..CryptoConfig::default() },
..Config::default() ..Config::default()
}; };
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn cross_connect() { async fn cross_connect() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
let node3 = sim.add_node(false, &config); let node3 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.connect(node1, node3); sim.connect(node1, node3);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
sim.simulate_time(120); sim.simulate_time(120).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
@ -56,124 +56,124 @@ fn cross_connect() {
} }
#[test] #[test]
fn connect_via_beacons() { async fn connect_via_beacons() {
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let beacon_path = "target/.vpncloud_test"; let beacon_path = "target/.vpncloud_test";
let config1 = Config { beacon_store: Some(beacon_path.to_string()), ..Default::default() }; let config1 = Config { beacon_store: Some(beacon_path.to_string()), ..Default::default() };
let node1 = sim.add_node(false, &config1); let node1 = sim.add_node(false, &config1).await;
let config2 = Config { beacon_load: Some(beacon_path.to_string()), ..Default::default() }; let config2 = Config { beacon_load: Some(beacon_path.to_string()), ..Default::default() };
let node2 = sim.add_node(false, &config2); let node2 = sim.add_node(false, &config2).await;
sim.set_time(100); sim.set_time(100);
sim.trigger_node_housekeep(node1); sim.trigger_node_housekeep(node1).await;
sim.trigger_node_housekeep(node2); sim.trigger_node_housekeep(node2).await;
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn reconnect_after_timeout() { async fn reconnect_after_timeout() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
sim.set_time(5000); sim.set_time(5000);
sim.trigger_housekeep(); sim.trigger_housekeep().await;
assert!(!sim.is_connected(node1, node2)); assert!(!sim.is_connected(node1, node2));
assert!(!sim.is_connected(node2, node1)); assert!(!sim.is_connected(node2, node1));
sim.simulate_all_messages(); sim.simulate_all_messages().await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn lost_init_ping() { async fn lost_init_ping() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.drop_message(); // drop init ping sim.drop_message(); // drop init ping
sim.simulate_time(120); sim.simulate_time(120).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn lost_init_pong() { async fn lost_init_pong() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_next_message(); // init ping sim.simulate_next_message().await; // init ping
sim.drop_message(); // drop init pong sim.drop_message(); // drop init pong
sim.simulate_time(120); sim.simulate_time(120).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn lost_init_peng() { async fn lost_init_peng() {
let config = Config::default(); let config = Config::default();
let mut sim = TapSimulator::new(); let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config); let node1 = sim.add_node(false, &config).await;
let node2 = sim.add_node(false, &config); let node2 = sim.add_node(false, &config).await;
sim.connect(node1, node2); sim.connect(node1, node2);
sim.simulate_next_message(); // init ping sim.simulate_next_message().await; // init ping
sim.simulate_next_message(); // init pong sim.simulate_next_message().await; // init pong
sim.drop_message(); // drop init peng sim.drop_message(); // drop init peng
sim.simulate_time(120); sim.simulate_time(120).await;
assert!(sim.is_connected(node1, node2)); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1)); assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
#[ignore] #[ignore]
fn peer_exchange() { async fn peer_exchange() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }
#[test] #[test]
#[ignore] #[ignore]
fn lost_peer_exchange() { async fn lost_peer_exchange() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }
#[test] #[test]
#[ignore] #[ignore]
fn remove_dead_peers() { async fn remove_dead_peers() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }
#[test] #[test]
#[ignore] #[ignore]
fn update_primary_address() { async fn update_primary_address() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }
#[test] #[test]
#[ignore] #[ignore]
fn automatic_peer_timeout() { async fn automatic_peer_timeout() {
// TODO Test // TODO Test
unimplemented!() unimplemented!()
} }

View File

@ -240,7 +240,7 @@ mod tests {
use std::io::Cursor; use std::io::Cursor;
#[test] #[test]
fn address_parse_fmt() { async fn address_parse_fmt() {
assert_eq!(format!("{}", Address::from_str("120.45.22.5").unwrap()), "120.45.22.5"); assert_eq!(format!("{}", Address::from_str("120.45.22.5").unwrap()), "120.45.22.5");
assert_eq!(format!("{}", Address::from_str("78:2d:16:05:01:02").unwrap()), "78:2d:16:05:01:02"); assert_eq!(format!("{}", Address::from_str("78:2d:16:05:01:02").unwrap()), "78:2d:16:05:01:02");
assert_eq!( assert_eq!(
@ -256,7 +256,7 @@ mod tests {
} }
#[test] #[test]
fn address_decode_encode() { async fn address_decode_encode() {
let mut buf = vec![]; let mut buf = vec![];
let addr = Address::from_str("120.45.22.5").unwrap(); let addr = Address::from_str("120.45.22.5").unwrap();
addr.write_to(Cursor::new(&mut buf)); addr.write_to(Cursor::new(&mut buf));
@ -277,7 +277,7 @@ mod tests {
} }
#[test] #[test]
fn address_eq() { async fn address_eq() {
assert_eq!( assert_eq!(
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap(), Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap(),
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap() Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap()
@ -297,7 +297,7 @@ mod tests {
} }
#[test] #[test]
fn address_range_decode_encode() { async fn address_range_decode_encode() {
let mut buf = vec![]; let mut buf = vec![];
let range = let range =
Range { base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }, prefix_len: 24 }; Range { base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }, prefix_len: 24 };

View File

@ -24,14 +24,14 @@ pub type Time = i64;
#[derive(Clone)] #[derive(Clone)]
pub struct MsgBuffer { pub struct MsgBuffer {
space_before: usize, space_before: usize,
buffer: [u8; 65535], buffer: Box<[u8; 65535]>,
start: usize, start: usize,
end: usize, end: usize,
} }
impl MsgBuffer { impl MsgBuffer {
pub fn new(space_before: usize) -> Self { pub fn new(space_before: usize) -> Self {
Self { buffer: [0; 65535], space_before, start: space_before, end: space_before } Self { buffer: Box::new([0; 65535]), space_before, start: space_before, end: space_before }
} }
pub fn get_start(&self) -> usize { pub fn get_start(&self) -> usize {
@ -427,7 +427,7 @@ pub fn run_cmd(mut cmd: Command) {
} }
#[test] #[test]
fn base62() { async fn base62() {
assert_eq!("", to_base62(&[0])); assert_eq!("", to_base62(&[0]));
assert_eq!("z", to_base62(&[61])); assert_eq!("z", to_base62(&[61]));
assert_eq!("10", to_base62(&[62])); assert_eq!("10", to_base62(&[62]));

View File

@ -6,20 +6,20 @@ use super::{
net::{get_ip, mapped_addr, parse_listen, Socket}, net::{get_ip, mapped_addr, parse_listen, Socket},
poll::{WaitImpl, WaitResult}, poll::{WaitImpl, WaitResult},
port_forwarding::PortForwarding, port_forwarding::PortForwarding,
util::MsgBuffer util::MsgBuffer,
}; };
use async_trait::async_trait;
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use std::{ use std::{
io::{self, Cursor, Read, Write}, io::{self, Cursor, Read, Write},
net::{Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket}, net::{Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket},
os::unix::io::{AsRawFd, RawFd}, os::unix::io::AsRawFd,
sync::Arc, sync::Arc,
thread::spawn thread::spawn,
}; };
use tungstenite::{client::AutoStream, connect, protocol::WebSocket, server::accept, Message}; use tungstenite::{client::AutoStream, connect, protocol::WebSocket, server::accept, Message};
use url::Url; use url::Url;
macro_rules! io_error { macro_rules! io_error {
($val:expr, $format:expr) => ( { ($val:expr, $format:expr) => ( {
$val.map_err(|err| io::Error::new(io::ErrorKind::Other, format!($format, err))) $val.map_err(|err| io::Error::new(io::ErrorKind::Other, format!($format, err)))
@ -36,7 +36,7 @@ fn write_addr<W: Write>(addr: SocketAddr, mut out: W) -> Result<(), io::Error> {
out.write_all(&addr.ip().octets())?; out.write_all(&addr.ip().octets())?;
out.write_u16::<NetworkEndian>(addr.port())?; out.write_u16::<NetworkEndian>(addr.port())?;
} }
_ => unreachable!() _ => unreachable!(),
} }
Ok(()) Ok(())
} }
@ -87,7 +87,7 @@ fn serve_proxy_connection(stream: TcpStream) -> Result<(), io::Error> {
WaitResult::Timeout => { WaitResult::Timeout => {
io_error!(websocket.write_message(Message::Ping(vec![])), "Failed to send ping: {}")?; io_error!(websocket.write_message(Message::Ping(vec![])), "Failed to send ping: {}")?;
} }
WaitResult::Error(err) => return Err(err) WaitResult::Error(err) => return Err(err),
} }
} }
Ok(()) Ok(())
@ -112,7 +112,7 @@ pub fn run_proxy(listen: &str) -> Result<(), io::Error> {
#[derive(Clone)] #[derive(Clone)]
pub struct ProxyConnection { pub struct ProxyConnection {
addr: SocketAddr, addr: SocketAddr,
socket: Arc<WebSocket<AutoStream>> socket: Arc<WebSocket<AutoStream>>,
} }
impl ProxyConnection { impl ProxyConnection {
@ -128,14 +128,9 @@ impl ProxyConnection {
} }
} }
impl AsRawFd for ProxyConnection { #[async_trait]
fn as_raw_fd(&self) -> RawFd {
self.socket.get_ref().as_raw_fd()
}
}
impl Socket for ProxyConnection { impl Socket for ProxyConnection {
fn listen(url: &str) -> Result<Self, io::Error> { async fn listen(url: &str) -> Result<Self, io::Error> {
let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?;
let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?;
socket.get_mut().set_nodelay(true)?; socket.get_mut().set_nodelay(true)?;
@ -146,7 +141,7 @@ impl Socket for ProxyConnection {
Ok(con) Ok(con)
} }
fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
buffer.clear(); buffer.clear();
let data = self.read_message()?; let data = self.read_message()?;
let addr = read_addr(Cursor::new(&data))?; let addr = read_addr(Cursor::new(&data))?;
@ -154,7 +149,7 @@ impl Socket for ProxyConnection {
Ok(addr) Ok(addr)
} }
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> { async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> {
let mut msg = Vec::with_capacity(data.len() + 18); let mut msg = Vec::with_capacity(data.len() + 18);
write_addr(addr, &mut msg)?; write_addr(addr, &mut msg)?;
msg.write_all(data)?; msg.write_all(data)?;
@ -165,11 +160,11 @@ impl Socket for ProxyConnection {
*/ */
} }
fn address(&self) -> Result<SocketAddr, io::Error> { async fn address(&self) -> Result<SocketAddr, io::Error> {
Ok(self.addr) Ok(self.addr)
} }
fn create_port_forwarding(&self) -> Option<PortForwarding> { async fn create_port_forwarding(&self) -> Option<PortForwarding> {
None None
} }
} }