Compare commits

..

4 Commits

Author SHA1 Message Date
Dennis Schwerdel edb841adb1 Imports 2021-05-12 23:48:13 +02:00
Dennis Schwerdel 7163412db5 Close to fix ws proxy 2021-05-12 00:06:33 +02:00
Dennis Schwerdel a6b9bad583 Fix portforwarding 2021-05-11 23:51:36 +02:00
Dennis Schwerdel a03d367293 Fix daemonize, break websocket and port-forward 2021-05-11 23:38:03 +02:00
13 changed files with 356 additions and 190 deletions

4
Cargo.lock generated
View File

@ -1143,9 +1143,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.2.0" version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8190d04c665ea9e6b6a0dc45523ade572c088d2e6566244c1122671dbf4ae3a" checksum = "83f0c8e7c0addab50b663055baf787d0af7f413a46e6e7fb9559a4e4db7137a5"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"bytes", "bytes",

View File

@ -36,7 +36,7 @@ dialoguer = { version = "0.8", optional = true }
tungstenite = { version = "0.13", optional = true, default-features = false } tungstenite = { version = "0.13", optional = true, default-features = false }
url = { version = "2.2", optional = true } url = { version = "2.2", optional = true }
igd = { version = "0.12", optional = true } igd = { version = "0.12", optional = true }
tokio = { version = "1", features = ["full"] } tokio = { version = "^1.5", features = ["full"] }
async-trait = "0.1" async-trait = "0.1"

View File

@ -49,25 +49,27 @@ class Node:
def run_cmd(self, cmd): def run_cmd(self, cmd):
return run_cmd(self.connection, cmd) return run_cmd(self.connection, cmd)
def start_vpncloud(self, ip=None, crypto=None, password="test", device_type="tun", listen="3210", mode="normal", peers=[], claims=[]): def start_vpncloud(self, ip=None, crypto=None, password="test", device_type="tun", listen="3210", mode="normal", peers=[], claims=[], logfile="/var/log/vpncloud.log", extra_args=[]):
args = [ args = [
"--daemon", "--daemon",
"--no-port-forwarding", "--no-port-forwarding",
"-t {}".format(device_type), f"-t {device_type}",
"-m {}".format(mode), f"-m {mode}",
"-l {}".format(listen), f"-l {listen}",
"--password '{}'".format(password) f"--password '{password}'",
f"--log-file '{logfile}'",
*extra_args
] ]
if ip: if ip:
args.append("--ip {}".format(ip)) args.append(f"--ip {ip}")
if crypto: if crypto:
args.append("--algo {}".format(crypto)) args.append(f"--algo {crypto}")
for p in peers: for p in peers:
args.append("-c {}".format(p)) args.append(f"-c {p}")
for c in claims: for c in claims:
args.append("--claim {}".format(c)) args.append(f"--claim {c}")
args = " ".join(args) args = " ".join(args)
self.run_cmd("sudo vpncloud {}".format(args)) self.run_cmd(f"sudo vpncloud {args}")
def stop_vpncloud(self, wait=True): def stop_vpncloud(self, wait=True):
self.run_cmd("sudo killall vpncloud") self.run_cmd("sudo killall vpncloud")
@ -75,7 +77,7 @@ class Node:
time.sleep(3.0) time.sleep(3.0)
def ping(self, dst, size=100, count=10, interval=0.001): def ping(self, dst, size=100, count=10, interval=0.001):
(out, _) = self.run_cmd('sudo ping {dst} -c {count} -i {interval} -s {size} -U -q'.format(dst=dst, size=size, count=count, interval=interval)) (out, _) = self.run_cmd(f'sudo ping {dst} -c {count} -i {interval} -s {size} -U -q')
match = re.search(r'([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)', out) match = re.search(r'([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)', out)
ping_min = float(match.group(1)) ping_min = float(match.group(1))
ping_avg = float(match.group(2)) ping_avg = float(match.group(2))
@ -97,7 +99,7 @@ class Node:
self.run_cmd('killall iperf3') self.run_cmd('killall iperf3')
def run_iperf(self, dst, duration): def run_iperf(self, dst, duration):
(out, _) = self.run_cmd('iperf3 -c {dst} -t {duration} --json'.format(dst=dst, duration=duration)) (out, _) = self.run_cmd(f'iperf3 -c {dst} -t {duration} --json')
data = json.loads(out) data = json.loads(out)
return { return {
"throughput": data['end']['streams'][0]['receiver']['bits_per_second'], "throughput": data['end']['streams'][0]['receiver']['bits_per_second'],
@ -197,7 +199,7 @@ class EC2Environment:
if self.subnet == CREATE: if self.subnet == CREATE:
self.setup_vpc() self.setup_vpc()
else: else:
eprint("\tUsing subnet {}".format(self.subnet)) eprint(f"\tUsing subnet {self.subnet}")
vpc = ec2.Subnet(self.subnet).vpc vpc = ec2.Subnet(self.subnet).vpc
@ -209,7 +211,7 @@ class EC2Environment:
sg.authorize_ingress(CidrIp='172.16.1.0/24', IpProtocol='udp', FromPort=0, ToPort=65535) sg.authorize_ingress(CidrIp='172.16.1.0/24', IpProtocol='udp', FromPort=0, ToPort=65535)
if self.keyname == CREATE: if self.keyname == CREATE:
key_pair = ec2.create_key_pair(KeyName="{}-keypair".format(self.tag)) key_pair = ec2.create_key_pair(KeyName=f"{self.tag}-keypair")
self.track_resource(key_pair) self.track_resource(key_pair)
self.keyname = key_pair.name self.keyname = key_pair.name
self.privatekey = key_pair.key_material self.privatekey = key_pair.key_material
@ -217,7 +219,7 @@ class EC2Environment:
placement = {} placement = {}
if self.cluster_nodes: if self.cluster_nodes:
placement_group = ec2.create_placement_group(GroupName="{}-placement".format(self.tag), Strategy="cluster") placement_group = ec2.create_placement_group(GroupName=f"{self.tag}-placement", Strategy="cluster")
self.track_resource(placement_group) self.track_resource(placement_group)
placement = { 'GroupName': placement_group.name } placement = { 'GroupName': placement_group.name }
@ -227,12 +229,11 @@ packages:
- socat - socat
""" """
if not self.vpncloud_file: if not self.vpncloud_file:
userdata += """ userdata += f"""
runcmd: runcmd:
- wget https://github.com/dswd/vpncloud/releases/download/v{version}/vpncloud_{version}.x86_64.rpm -O /tmp/vpncloud.rpm - wget https://github.com/dswd/vpncloud/releases/download/v{self.vpncloud_version}/vpncloud_{self.vpncloud_version}.x86_64.rpm -O /tmp/vpncloud.rpm
- yum install -y /tmp/vpncloud.rpm - yum install -y /tmp/vpncloud.rpm
""".format(version=self.vpncloud_version) """
if self.use_spot: if self.use_spot:
response = ec2client.request_spot_instances( response = ec2client.request_spot_instances(
SpotPrice = self.max_price, SpotPrice = self.max_price,
@ -368,7 +369,7 @@ runcmd:
eprint("Deleting resources...") eprint("Deleting resources...")
ec2client = boto3.client('ec2', region_name=self.region) ec2client = boto3.client('ec2', region_name=self.region)
for res in reversed(self.resources): for res in reversed(self.resources):
eprint("\t{} {}".format(res.__class__.__name__, res.id if hasattr(res, "id") else "")) eprint(f"\t{res.__class__.__name__} {res.id if hasattr(res, 'id') else ''}")
if isinstance(res, SpotInstanceRequest): if isinstance(res, SpotInstanceRequest):
ec2client.cancel_spot_instance_requests(SpotInstanceRequestIds=[res.id]) ec2client.cancel_spot_instance_requests(SpotInstanceRequestIds=[res.id])
if hasattr(res, "attachments"): if hasattr(res, "attachments"):

View File

@ -14,7 +14,7 @@ sender = setup.nodes[0]
receiver = setup.nodes[1] receiver = setup.nodes[1]
sender.start_vpncloud(ip="10.0.0.1/24") sender.start_vpncloud(ip="10.0.0.1/24")
receiver.start_vpncloud(ip="10.0.0.2/24", peers=["{}:3210".format(sender.private_ip)]) receiver.start_vpncloud(ip="10.0.0.2/24", peers=[f"{sender.private_ip}:3210"])
time.sleep(1.0) time.sleep(1.0)
sender.ping("10.0.0.2") sender.ping("10.0.0.2")

View File

@ -47,11 +47,11 @@ class PerfTest:
return cls(env.nodes[0], env.nodes[1], meta) return cls(env.nodes[0], env.nodes[1], meta)
def run_ping(self, dst, size): def run_ping(self, dst, size):
eprint("\tRunning ping {} with size {} ...".format(dst, size)) eprint(f"\tRunning ping {dst} with size {size} ...")
return self.sender.ping(dst=dst, size=size, count=30000, interval=0.001) return self.sender.ping(dst=dst, size=size, count=30000, interval=0.001)
def run_iperf(self, dst): def run_iperf(self, dst):
eprint("\tRunning iperf on {} ...".format(dst)) eprint(f"\tRunning iperf on {dst} ...")
self.receiver.start_iperf_server() self.receiver.start_iperf_server()
time.sleep(0.1) time.sleep(0.1)
result = self.sender.run_iperf(dst=dst, duration=30) result = self.sender.run_iperf(dst=dst, duration=30)
@ -68,9 +68,9 @@ class PerfTest:
def start_vpncloud(self, crypto=None): def start_vpncloud(self, crypto=None):
eprint("\tSetting up vpncloud on receiver") eprint("\tSetting up vpncloud on receiver")
self.receiver.start_vpncloud(crypto=crypto, ip="{}/24".format(self.receiver_ip_vpncloud)) self.receiver.start_vpncloud(crypto=crypto, ip=f"{self.receiver_ip_vpncloud}/24")
eprint("\tSetting up vpncloud on sender") eprint("\tSetting up vpncloud on sender")
self.sender.start_vpncloud(crypto=crypto, peers=["{}:3210".format(self.receiver.private_ip)], ip="{}/24".format(self.sender_ip_vpncloud)) self.sender.start_vpncloud(crypto=crypto, peers=[f"{self.receiver_ip_vpncloud}:3210"], ip=f"{self.sender_ip_vpncloud}/24")
time.sleep(1.0) time.sleep(1.0)
def stop_vpncloud(self): def stop_vpncloud(self):
@ -84,7 +84,7 @@ class PerfTest:
"native": self.run_suite(self.receiver.private_ip) "native": self.run_suite(self.receiver.private_ip)
} }
for crypto in CRYPTO: for crypto in CRYPTO:
eprint("Running with crypto {}".format(crypto)) eprint(f"Running with crypto {crypto}")
self.start_vpncloud(crypto=crypto) self.start_vpncloud(crypto=crypto)
res = self.run_suite(self.receiver_ip_vpncloud) res = self.run_suite(self.receiver_ip_vpncloud)
self.stop_vpncloud() self.stop_vpncloud()
@ -109,8 +109,8 @@ duration = time.time() - start
results["meta"]["duration"] = duration results["meta"]["duration"] = duration
name = "measurements/{date}_{version}_perf.json".format(date=date.today().strftime('%Y-%m-%d'), version=VERSION) name = f"measurements/{date.today().strftime('%Y-%m-%d')}_{VERSION}_perf.json"
eprint('Storing results in {}'.format(name)) eprint(f'Storing results in {name}')
with open(name, 'w') as fp: with open(name, 'w') as fp:
json.dump(results, fp, indent=2) json.dump(results, fp, indent=2)
eprint("done.") eprint("done.")

100
contrib/aws/quick_perf.py Executable file
View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
from common import EC2Environment, CREATE, eprint
import time, json, os, atexit
from datetime import date
# Note: this script will run for ~8 minutes and incur costs of about $ 0.02
FILE = "../../target/release/vpncloud"
VERSION = "2.2.0"
REGION = "eu-central-1"
env = EC2Environment(
region = REGION,
node_count = 2,
instance_type = "m5.large",
use_spot = True,
max_price = "0.08", # USD per hour per VM
vpncloud_version = VERSION,
vpncloud_file = FILE,
cluster_nodes = True,
subnet = CREATE,
keyname = CREATE
)
CRYPTO = ["plain", "aes256", "aes128", "chacha20"]
class PerfTest:
def __init__(self, sender, receiver):
self.sender = sender
self.receiver = receiver
self.sender_ip_vpncloud = "10.0.0.1"
self.receiver_ip_vpncloud = "10.0.0.2"
@classmethod
def from_ec2_env(cls, env):
return cls(env.nodes[0], env.nodes[1])
def run_ping(self, dst, size):
eprint(f"\tRunning ping {dst} with size {size} ...")
return self.sender.ping(dst=dst, size=size, count=30000, interval=0.001)
def run_iperf(self, dst):
eprint(f"\tRunning iperf on {dst} ...")
self.receiver.start_iperf_server()
time.sleep(0.1)
result = self.sender.run_iperf(dst=dst, duration=30)
self.receiver.stop_iperf_server()
return result
def start_vpncloud(self):
eprint("\tSetting up vpncloud on receiver")
self.receiver.start_vpncloud(ip=f"{self.receiver_ip_vpncloud}/24")
eprint("\tSetting up vpncloud on sender")
self.sender.start_vpncloud(peers=[f"{self.receiver.private_ip}:3210"], ip=f"{self.sender_ip_vpncloud}/24")
time.sleep(1.0)
def stop_vpncloud(self):
self.sender.stop_vpncloud(wait=False)
self.receiver.stop_vpncloud(wait=True)
def run(self):
print()
self.start_vpncloud()
throughput = self.run_iperf(self.receiver_ip_vpncloud)["throughput"]
print(f"Throughput: {throughput / 1_000_000.0} MBit/s")
native_ping_100 = self.run_ping(self.receiver.private_ip, 100)["rtt_avg"]
ping_100 = self.run_ping(self.receiver_ip_vpncloud, 100)["rtt_avg"]
print(f"Latency 100: +{(ping_100 - native_ping_100)*1000.0/2.0} µs")
native_ping_1000 = self.run_ping(self.receiver.private_ip, 1000)["rtt_avg"]
ping_1000 = self.run_ping(self.receiver_ip_vpncloud, 1000)["rtt_avg"]
print(f"Latency 1000: +{(ping_1000 - native_ping_1000)*1000.0/2.0} µs")
self.stop_vpncloud()
keyfile = "key.pem"
assert not os.path.exists(keyfile)
with open(keyfile, 'x') as fp:
fp.write(env.privatekey)
os.chmod(keyfile, 0o400)
print(f"SSH private key written to {keyfile}")
atexit.register(lambda : os.remove(keyfile))
print()
print("Nodes:")
for node in env.nodes:
print(f"\t {env.username}@{node.public_ip}\tprivate: {node.private_ip}")
print()
perf = PerfTest.from_ec2_env(env)
try:
perf.run()
except Exception as e:
eprint(f"Exception: {e}")
print("Press ENTER to shut down")
input()
eprint("done.")

View File

@ -48,13 +48,13 @@ if not args.keyname:
with open(args.keyfile, 'x') as fp: with open(args.keyfile, 'x') as fp:
fp.write(setup.privatekey) fp.write(setup.privatekey)
os.chmod(args.keyfile, 0o400) os.chmod(args.keyfile, 0o400)
print("SSH private key written to {}".format(args.keyfile)) print(f"SSH private key written to {args.keyfile}")
atexit.register(lambda : os.remove(args.keyfile)) atexit.register(lambda : os.remove(args.keyfile))
print() print()
print("Nodes:") print("Nodes:")
for node in setup.nodes: for node in setup.nodes:
print("\t {}@{}\tprivate: {}".format(setup.username, node.public_ip, node.private_ip)) print(f"\t {setup.username}@{node.public_ip}\tprivate: {node.private_ip}")
print() print()
print("Press ENTER to shut down") print("Press ENTER to shut down")

View File

@ -9,15 +9,16 @@ use std::{
collections::VecDeque, collections::VecDeque,
convert::TryInto, convert::TryInto,
fmt, fmt,
io::{self, Cursor, Error as IoError}, io::{self, Cursor, Read, Write, Error as IoError, BufReader, BufRead},
net::{Ipv4Addr, UdpSocket}, net::{Ipv4Addr, UdpSocket},
fs::{self, File},
os::unix::io::AsRawFd, os::unix::io::AsRawFd,
str, str,
str::FromStr, str::FromStr,
sync::Arc, sync::Arc,
}; };
use tokio::fs::{self, File}; use tokio::fs::{File as AsyncFile};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::{crypto, error::Error, util::MsgBuffer}; use crate::{crypto, error::Error, util::MsgBuffer};
@ -83,9 +84,6 @@ pub trait Device: Send + 'static + Sized {
/// Returns the type of this device /// Returns the type of this device
fn get_type(&self) -> Type; fn get_type(&self) -> Type;
/// Returns the interface name of this device.
fn ifname(&self) -> &str;
/// Reads a packet/frame from the device /// Reads a packet/frame from the device
/// ///
/// This method reads one packet or frame (depending on the device type) into the `buffer`. /// This method reads one packet or frame (depending on the device type) into the `buffer`.
@ -141,9 +139,9 @@ impl TunTapDevice {
/// # Panics /// # Panics
/// This method panics if the interface name is longer than 31 bytes. /// This method panics if the interface name is longer than 31 bytes.
#[allow(clippy::useless_conversion)] #[allow(clippy::useless_conversion)]
pub async fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result<Self> { pub fn new(ifname: &str, type_: Type, path: Option<&str>) -> io::Result<Self> {
let path = path.unwrap_or_else(|| Self::default_path(type_)); let path = path.unwrap_or_else(|| Self::default_path(type_));
let fd = fs::OpenOptions::new().read(true).write(true).open(path).await?; let fd = fs::OpenOptions::new().read(true).write(true).open(path)?;
let flags = match type_ { let flags = match type_ {
Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI, Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI,
Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI, Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI,
@ -155,7 +153,7 @@ impl TunTapDevice {
0 => { 0 => {
let mut ifname = String::with_capacity(32); let mut ifname = String::with_capacity(32);
let mut cursor = Cursor::new(ifreq.ifr_name); let mut cursor = Cursor::new(ifreq.ifr_name);
cursor.read_to_string(&mut ifname).await?; Read::read_to_string(&mut cursor, &mut ifname)?;
ifname = ifname.trim_end_matches('\0').to_owned(); ifname = ifname.trim_end_matches('\0').to_owned();
Ok(Self { fd, ifname, type_ }) Ok(Self { fd, ifname, type_ })
} }
@ -163,6 +161,10 @@ impl TunTapDevice {
} }
} }
pub fn ifname(&self) -> &str {
&self.ifname
}
/// Returns the default device path for a given type /// Returns the default device path for a given type
#[inline] #[inline]
pub fn default_path(type_: Type) -> &'static str { pub fn default_path(type_: Type) -> &'static str {
@ -171,6 +173,69 @@ impl TunTapDevice {
} }
} }
pub fn get_overhead(&self) -> usize {
40 /* for outer IPv6 header, can't be sure to only have IPv4 peers */
+ 8 /* for outer UDP header */
+ crypto::EXTRA_LEN + crypto::TAG_LEN /* crypto overhead */
+ 1 /* message type header */
+ match self.type_ {
Type::Tap => 14, /* inner ethernet header */
Type::Tun => 0
}
}
pub fn set_mtu(&self, value: Option<usize>) -> io::Result<()> {
let value = match value {
Some(value) => value,
None => {
let default_device = get_default_device()?;
info!("Deriving MTU from default device {}", default_device);
get_device_mtu(&default_device)? - self.get_overhead()
}
};
info!("Setting MTU {} on device {}", value, self.ifname);
set_device_mtu(&self.ifname, value)
}
pub fn configure(&self, addr: Ipv4Addr, netmask: Ipv4Addr) -> io::Result<()> {
set_device_addr(&self.ifname, addr)?;
set_device_netmask(&self.ifname, netmask)?;
set_device_enabled(&self.ifname, true)
}
pub fn get_rp_filter(&self) -> io::Result<u8> {
Ok(cmp::max(get_rp_filter("all")?, get_rp_filter(&self.ifname)?))
}
pub fn fix_rp_filter(&self) -> io::Result<()> {
if get_rp_filter("all")? > 1 {
info!("Setting net.ipv4.conf.all.rp_filter=1");
set_rp_filter("all", 1)?
}
if get_rp_filter(&self.ifname)? != 1 {
info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname);
set_rp_filter(&self.ifname, 1)?
}
Ok(())
}
}
/// Represents a tun/tap device
pub struct AsyncTunTapDevice {
fd: AsyncFile,
ifname: String,
type_: Type,
}
impl AsyncTunTapDevice {
pub fn from_sync(dev: TunTapDevice) -> Self {
Self {
fd: AsyncFile::from_std(dev.fd),
ifname: dev.ifname,
type_: dev.type_
}
}
#[cfg(any(target_os = "linux", target_os = "android"))] #[cfg(any(target_os = "linux", target_os = "android"))]
#[inline] #[inline]
fn correct_data_after_read(&mut self, _buffer: &mut MsgBuffer) {} fn correct_data_after_read(&mut self, _buffer: &mut MsgBuffer) {}
@ -219,64 +284,14 @@ impl TunTapDevice {
} }
} }
} }
pub fn get_overhead(&self) -> usize {
40 /* for outer IPv6 header, can't be sure to only have IPv4 peers */
+ 8 /* for outer UDP header */
+ crypto::EXTRA_LEN + crypto::TAG_LEN /* crypto overhead */
+ 1 /* message type header */
+ match self.type_ {
Type::Tap => 14, /* inner ethernet header */
Type::Tun => 0
}
}
pub async fn set_mtu(&self, value: Option<usize>) -> io::Result<()> {
let value = match value {
Some(value) => value,
None => {
let default_device = get_default_device().await?;
info!("Deriving MTU from default device {}", default_device);
get_device_mtu(&default_device)? - self.get_overhead()
}
};
info!("Setting MTU {} on device {}", value, self.ifname);
set_device_mtu(&self.ifname, value)
}
pub fn configure(&self, addr: Ipv4Addr, netmask: Ipv4Addr) -> io::Result<()> {
set_device_addr(&self.ifname, addr)?;
set_device_netmask(&self.ifname, netmask)?;
set_device_enabled(&self.ifname, true)
}
pub async fn get_rp_filter(&self) -> io::Result<u8> {
Ok(cmp::max(get_rp_filter("all").await?, get_rp_filter(&self.ifname).await?))
}
pub async fn fix_rp_filter(&self) -> io::Result<()> {
if get_rp_filter("all").await? > 1 {
info!("Setting net.ipv4.conf.all.rp_filter=1");
set_rp_filter("all", 1).await?
}
if get_rp_filter(&self.ifname).await? != 1 {
info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname);
set_rp_filter(&self.ifname, 1).await?
}
Ok(())
}
} }
#[async_trait] #[async_trait]
impl Device for TunTapDevice { impl Device for AsyncTunTapDevice {
fn get_type(&self) -> Type { fn get_type(&self) -> Type {
self.type_ self.type_
} }
fn ifname(&self) -> &str {
&self.ifname
}
async fn duplicate(&self) -> Result<Self, Error> { async fn duplicate(&self) -> Result<Self, Error> {
Ok(Self { Ok(Self {
fd: self.fd.try_clone().await.map_err(|e| Error::DeviceIo("Failed to clone device", e))?, fd: self.fd.try_clone().await.map_err(|e| Error::DeviceIo("Failed to clone device", e))?,
@ -340,10 +355,6 @@ impl Device for MockDevice {
Type::Tun Type::Tun
} }
fn ifname(&self) -> &str {
"mock0"
}
async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { async fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
if let Some(data) = self.inbound.lock().pop_front() { if let Some(data) = self.inbound.lock().pop_front() {
buffer.clear(); buffer.clear();
@ -477,13 +488,13 @@ fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> {
} }
} }
async fn get_default_device() -> io::Result<String> { fn get_default_device() -> io::Result<String> {
let mut fd = BufReader::new(File::open("/proc/net/route").await?); let mut fd = BufReader::new(File::open("/proc/net/route")?);
let mut best = None; let mut best = None;
let mut line = String::with_capacity(80); let mut line = String::with_capacity(80);
fd.read_line(&mut line).await?; fd.read_line(&mut line)?;
line.clear(); line.clear();
while let Ok(read) = fd.read_line(&mut line).await { while let Ok(read) = fd.read_line(&mut line) {
if read == 0 { if read == 0 {
break; break;
} }
@ -504,14 +515,14 @@ async fn get_default_device() -> io::Result<String> {
} }
} }
async fn get_rp_filter(device: &str) -> io::Result<u8> { fn get_rp_filter(device: &str) -> io::Result<u8> {
let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?; let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?;
let mut contents = String::with_capacity(10); let mut contents = String::with_capacity(10);
fd.read_to_string(&mut contents).await?; fd.read_to_string(&mut contents)?;
u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value")) u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value"))
} }
async fn set_rp_filter(device: &str, val: u8) -> io::Result<()> { fn set_rp_filter(device: &str, val: u8) -> io::Result<()> {
let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device)).await?; let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?;
fd.write_all(format!("{}", val).as_bytes()).await fd.write_all(format!("{}", val).as_bytes())
} }

View File

@ -21,6 +21,7 @@ use super::common::PeerData;
#[derive(Clone)] #[derive(Clone)]
pub struct SharedPeerCrypto { pub struct SharedPeerCrypto {
peers: Arc<Mutex<HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash>>>, peers: Arc<Mutex<HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash>>>,
//TODO: local hashmap as cache
} }
impl SharedPeerCrypto { impl SharedPeerCrypto {
@ -29,6 +30,7 @@ impl SharedPeerCrypto {
} }
pub fn encrypt_for(&self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { pub fn encrypt_for(&self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> {
//TODO: use cache first
let mut peers = self.peers.lock(); let mut peers = self.peers.lock();
match peers.get_mut(&peer) { match peers.get_mut(&peer) {
None => Err(Error::InvalidCryptoState("No crypto found for peer")), None => Err(Error::InvalidCryptoState("No crypto found for peer")),
@ -41,6 +43,7 @@ impl SharedPeerCrypto {
} }
pub fn store(&self, data: &HashMap<SocketAddr, PeerData, Hash>) { pub fn store(&self, data: &HashMap<SocketAddr, PeerData, Hash>) {
//TODO: store in shared and in cache
let mut peers = self.peers.lock(); let mut peers = self.peers.lock();
peers.clear(); peers.clear();
peers.extend(data.iter().map(|(k, v)| (*k, v.crypto.get_core()))); peers.extend(data.iter().map(|(k, v)| (*k, v.crypto.get_core())));
@ -51,6 +54,7 @@ impl SharedPeerCrypto {
} }
pub fn get_snapshot(&self) -> HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash> { pub fn get_snapshot(&self) -> HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash> {
//TODO: return local cache
self.peers.lock().clone() self.peers.lock().clone()
} }
@ -121,6 +125,8 @@ impl SharedTraffic {
#[derive(Clone)] #[derive(Clone)]
pub struct SharedTable<TS: TimeSource> { pub struct SharedTable<TS: TimeSource> {
table: Arc<Mutex<ClaimTable<TS>>>, table: Arc<Mutex<ClaimTable<TS>>>,
//TODO: local reader lookup table Addr => Option<SocketAddr>
//TODO: local writer cache Addr => SocketAddr
} }
impl<TS: TimeSource> SharedTable<TS> { impl<TS: TimeSource> SharedTable<TS> {
@ -131,21 +137,29 @@ impl<TS: TimeSource> SharedTable<TS> {
pub fn sync(&mut self) { pub fn sync(&mut self) {
// TODO sync if needed // TODO sync if needed
// once every x seconds
// fetch reader cache
// clear writer cache
} }
pub fn lookup(&mut self, addr: Address) -> Option<SocketAddr> { pub fn lookup(&mut self, addr: Address) -> Option<SocketAddr> {
// TODO: use local reader cache
// if not found, use shared table and put into cache
self.table.lock().lookup(addr) self.table.lock().lookup(addr)
} }
pub fn set_claims(&mut self, peer: SocketAddr, claims: RangeList) { pub fn set_claims(&mut self, peer: SocketAddr, claims: RangeList) {
// clear writer cache
self.table.lock().set_claims(peer, claims) self.table.lock().set_claims(peer, claims)
} }
pub fn remove_claims(&mut self, peer: SocketAddr) { pub fn remove_claims(&mut self, peer: SocketAddr) {
// clear writer cache
self.table.lock().remove_claims(peer) self.table.lock().remove_claims(peer)
} }
pub fn cache(&mut self, addr: Address, peer: SocketAddr) { pub fn cache(&mut self, addr: Address, peer: SocketAddr) {
// check writer cache and only write real updates to shared table
self.table.lock().cache(addr, peer) self.table.lock().cache(addr, peer)
} }
@ -154,14 +168,17 @@ impl<TS: TimeSource> SharedTable<TS> {
} }
pub fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error> { pub fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error> {
//TODO: stats call
self.table.lock().write_out(out) self.table.lock().write_out(out)
} }
pub fn cache_len(&self) -> usize { pub fn cache_len(&self) -> usize {
//TODO: stats call
self.table.lock().cache_len() self.table.lock().cache_len()
} }
pub fn claim_len(&self) -> usize { pub fn claim_len(&self) -> usize {
//TODO: stats call
self.table.lock().claim_len() self.table.lock().claim_len()
} }
} }

View File

@ -386,6 +386,7 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
let src = mapped_addr(src); let src = mapped_addr(src);
debug!("Received {} bytes from {}", self.buffer.len(), src); debug!("Received {} bytes from {}", self.buffer.len(), src);
let buffer = &mut self.buffer; let buffer = &mut self.buffer;
self.traffic.count_in_traffic(src, buffer.len());
if let Some(result) = self.peers.get_mut(&src).map(|peer| peer.crypto.handle_message(buffer)) { if let Some(result) = self.peers.get_mut(&src).map(|peer| peer.crypto.handle_message(buffer)) {
return self.process_message(src, result?).await; return self.process_message(src, result?).await;
} }

View File

@ -2,9 +2,12 @@
// Copyright (C) 2015-2021 Dennis Schwerdel // Copyright (C) 2015-2021 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
#[macro_use] extern crate log; #[macro_use]
#[macro_use] extern crate serde; extern crate log;
#[macro_use] extern crate tokio; #[macro_use]
extern crate serde;
#[macro_use]
extern crate tokio;
#[cfg(test)] #[cfg(test)]
extern crate tempfile; extern crate tempfile;
@ -15,10 +18,10 @@ pub mod util;
#[macro_use] #[macro_use]
mod tests; mod tests;
pub mod beacon; pub mod beacon;
pub mod engine;
pub mod config; pub mod config;
pub mod crypto; pub mod crypto;
pub mod device; pub mod device;
pub mod engine;
pub mod error; pub mod error;
#[cfg(feature = "installer")] #[cfg(feature = "installer")]
pub mod installer; pub mod installer;
@ -36,12 +39,14 @@ pub mod wizard;
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
pub mod wsproxy; pub mod wsproxy;
use net::SocketBuilder;
use structopt::StructOpt; use structopt::StructOpt;
use tokio::runtime::Runtime;
use std::{ use std::{
fs::{self, File, Permissions}, fs::{self, File, Permissions},
io::{self, Write}, io::{self, Write},
net::{Ipv4Addr}, net::Ipv4Addr,
os::unix::fs::PermissionsExt, os::unix::fs::PermissionsExt,
path::Path, path::Path,
process, process,
@ -51,11 +56,11 @@ use std::{
}; };
use crate::{ use crate::{
engine::common::GenericCloud,
config::{Args, Command, Config, DEFAULT_PORT}, config::{Args, Command, Config, DEFAULT_PORT},
crypto::Crypto, crypto::Crypto,
device::{Device, TunTapDevice, Type}, device::{AsyncTunTapDevice, TunTapDevice, Type},
net::{Socket, NetSocket}, engine::common::GenericCloud,
net::NetSocket,
oldconfig::OldConfigFile, oldconfig::OldConfigFile,
payload::Protocol, payload::Protocol,
util::SystemTimeSource, util::SystemTimeSource,
@ -139,16 +144,16 @@ fn parse_ip_netmask(addr: &str) -> Result<(Ipv4Addr, Ipv4Addr), String> {
Ok((ip, netmask)) Ok((ip, netmask))
} }
async fn setup_device(config: &Config) -> TunTapDevice { fn setup_device(config: &Config) -> TunTapDevice {
let device = try_fail!( let device = try_fail!(
TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)).await, TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)),
"Failed to open virtual {} interface {}: {}", "Failed to open virtual {} interface {}: {}",
config.device_type, config.device_type,
config.device_name config.device_name
); );
info!("Opened device {}", device.ifname()); info!("Opened device {}", device.ifname());
config.call_hook("device_setup", vec![("IFNAME", device.ifname())], true); config.call_hook("device_setup", vec![("IFNAME", device.ifname())], true);
if let Err(err) = device.set_mtu(config.device_mtu).await { if let Err(err) = device.set_mtu(config.device_mtu) {
error!("Error setting MTU on {}: {}", device.ifname(), err); error!("Error setting MTU on {}: {}", device.ifname(), err);
} }
if let Some(ip) = &config.ip { if let Some(ip) = &config.ip {
@ -160,9 +165,9 @@ async fn setup_device(config: &Config) -> TunTapDevice {
run_script(script, device.ifname()); run_script(script, device.ifname());
} }
if config.fix_rp_filter { if config.fix_rp_filter {
try_fail!(device.fix_rp_filter().await, "Failed to change rp_filter settings: {}"); try_fail!(device.fix_rp_filter(), "Failed to change rp_filter settings: {}");
} }
if let Ok(val) = device.get_rp_filter().await { if let Ok(val) = device.get_rp_filter() {
if val != 1 { if val != 1 {
warn!("Your networking configuration might be affected by a vulnerability (https://vpncloud.ddswd.de/docs/security/cve-2019-14899/), please change your rp_filter setting to 1 (currently {}).", val); warn!("Your networking configuration might be affected by a vulnerability (https://vpncloud.ddswd.de/docs/security/cve-2019-14899/), please change your rp_filter setting to 1 (currently {}).", val);
} }
@ -172,9 +177,9 @@ async fn setup_device(config: &Config) -> TunTapDevice {
} }
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
async fn run<P: Protocol, S: Socket>(config: Config, socket: S) { fn run<P: Protocol, S: SocketBuilder>(config: Config, socket: S) {
let device = setup_device(&config).await; let device = setup_device(&config);
let port_forwarding = if config.port_forwarding { socket.create_port_forwarding().await } else { None }; let port_forwarding = if config.port_forwarding { socket.create_port_forwarding() } else { None };
let stats_file = match config.stats_file { let stats_file = match config.stats_file {
None => None, None => None,
Some(ref name) => { Some(ref name) => {
@ -191,25 +196,16 @@ async fn run<P: Protocol, S: Socket>(config: Config, socket: S) {
} }
}; };
let ifname = device.ifname().to_string(); let ifname = device.ifname().to_string();
let mut cloud =
try_fail!(GenericCloud::<TunTapDevice, P, S, SystemTimeSource>::new(&config, socket, device, port_forwarding, stats_file).await, "Failed to create engine: {}");
for mut addr in config.peers {
if addr.find(':').unwrap_or(0) <= addr.find(']').unwrap_or(0) {
// : not present or only in IPv6 address
addr = format!("{}:{}", addr, DEFAULT_PORT)
}
try_fail!(cloud.add_peer(addr.clone()), "Failed to send message to {}: {}", &addr);
}
if config.daemonize { if config.daemonize {
info!("Running process as daemon"); info!("Running process as daemon");
let mut daemonize = daemonize::Daemonize::new(); let mut daemonize = daemonize::Daemonize::new();
if let Some(user) = config.user { if let Some(user) = &config.user {
daemonize = daemonize.user(&user as &str); daemonize = daemonize.user(user as &str);
} }
if let Some(group) = config.group { if let Some(group) = &config.group {
daemonize = daemonize.group(&group as &str); daemonize = daemonize.group(group as &str);
} }
if let Some(pid_file) = config.pid_file { if let Some(pid_file) = &config.pid_file {
daemonize = daemonize.pid_file(pid_file).chown_pid_file(true); daemonize = daemonize.pid_file(pid_file).chown_pid_file(true);
// Give child process some time to write PID file // Give child process some time to write PID file
daemonize = daemonize.exit_action(|| thread::sleep(std::time::Duration::from_millis(10))); daemonize = daemonize.exit_action(|| thread::sleep(std::time::Duration::from_millis(10)));
@ -218,22 +214,47 @@ async fn run<P: Protocol, S: Socket>(config: Config, socket: S) {
} else if config.user.is_some() || config.group.is_some() { } else if config.user.is_some() || config.group.is_some() {
info!("Dropping privileges"); info!("Dropping privileges");
let mut pd = privdrop::PrivDrop::default(); let mut pd = privdrop::PrivDrop::default();
if let Some(user) = config.user { if let Some(user) = &config.user {
pd = pd.user(user); pd = pd.user(user);
} }
if let Some(group) = config.group { if let Some(group) = &config.group {
pd = pd.group(group); pd = pd.group(group);
} }
try_fail!(pd.apply(), "Failed to drop privileges: {}"); try_fail!(pd.apply(), "Failed to drop privileges: {}");
} }
cloud.run().await; let rt = Runtime::new().unwrap();
if let Some(script) = config.ifdown { let ifdown = config.ifdown.clone();
rt.block_on(async move {
// Warning: no async code outside this block, or it will break on daemonize
let device = AsyncTunTapDevice::from_sync(device);
let socket = try_fail!(socket.build(), "Failed to create async socket: {}");
let mut cloud = try_fail!(
GenericCloud::<AsyncTunTapDevice, P, S::SocketType, SystemTimeSource>::new(
&config,
socket,
device,
port_forwarding,
stats_file
)
.await,
"Failed to create engine: {}"
);
for mut addr in config.peers {
if addr.find(':').unwrap_or(0) <= addr.find(']').unwrap_or(0) {
// : not present or only in IPv6 address
addr = format!("{}:{}", addr, DEFAULT_PORT)
}
try_fail!(cloud.add_peer(addr.clone()), "Failed to send message to {}: {}", &addr);
}
cloud.run().await
});
if let Some(script) = ifdown {
run_script(&script, &ifname); run_script(&script, &ifname);
} }
std::process::exit(0)
} }
#[tokio::main] fn main() {
async fn main() {
let args: Args = Args::from_args(); let args: Args = Args::from_args();
if args.version { if args.version {
println!("VpnCloud v{}", env!("CARGO_PKG_VERSION")); println!("VpnCloud v{}", env!("CARGO_PKG_VERSION"));
@ -327,17 +348,16 @@ async fn main() {
} }
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
if config.listen.starts_with("ws://") { if config.listen.starts_with("ws://") {
let socket = try_fail!(ProxyConnection::listen(&config.listen).await, "Failed to open socket {}: {}", config.listen); let socket = try_fail!(ProxyConnection::listen(&config.listen), "Failed to open socket {}: {}", config.listen);
match config.device_type { match config.device_type {
Type::Tap => run::<payload::Frame, _>(config, socket).await, Type::Tap => run::<payload::Frame, _>(config, socket),
Type::Tun => run::<payload::Packet, _>(config, socket).await Type::Tun => run::<payload::Packet, _>(config, socket),
} }
return; return;
} }
let socket = try_fail!(NetSocket::listen(&config.listen).await, "Failed to open socket {}: {}", config.listen); let socket = try_fail!(NetSocket::listen(&config.listen), "Failed to open socket {}: {}", config.listen);
match config.device_type { match config.device_type {
Type::Tap => run::<payload::Frame, _>(config, socket).await, Type::Tap => run::<payload::Frame, _>(config, socket),
Type::Tun => run::<payload::Packet, _>(config, socket).await Type::Tun => run::<payload::Packet, _>(config, socket),
} }
std::process::exit(0)
} }

View File

@ -7,11 +7,11 @@ use crate::port_forwarding::PortForwarding;
use crate::util::{MockTimeSource, MsgBuffer, Time, TimeSource}; use crate::util::{MockTimeSource, MsgBuffer, Time, TimeSource};
use async_trait::async_trait; use async_trait::async_trait;
use parking_lot::Mutex; use parking_lot::Mutex;
use tokio::net::UdpSocket; use tokio::net::UdpSocket as AsyncUdpSocket;
use std::{ use std::{
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
io::{self, ErrorKind}, io::{self, ErrorKind},
net::{IpAddr, Ipv6Addr, SocketAddr}, net::{UdpSocket, IpAddr, Ipv6Addr, SocketAddr},
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -32,13 +32,17 @@ pub fn get_ip() -> IpAddr {
s.local_addr().unwrap().ip() s.local_addr().unwrap().ip()
} }
pub trait SocketBuilder {
type SocketType: Socket;
fn build(self) -> Result<Self::SocketType, io::Error>;
fn create_port_forwarding(&self) -> Option<PortForwarding>;
}
#[async_trait] #[async_trait]
pub trait Socket: Sized + Clone + Send + Sync + 'static { pub trait Socket: Sized + Clone + Send + Sync + 'static {
async fn listen(addr: &str) -> Result<Self, io::Error>;
async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error>; async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error>;
async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>; async fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>;
async fn address(&self) -> Result<SocketAddr, io::Error>; async fn address(&self) -> Result<SocketAddr, io::Error>;
async fn create_port_forwarding(&self) -> Option<PortForwarding>;
} }
pub fn parse_listen(addr: &str, default_port: u16) -> SocketAddr { pub fn parse_listen(addr: &str, default_port: u16) -> SocketAddr {
@ -55,21 +59,37 @@ pub fn parse_listen(addr: &str, default_port: u16) -> SocketAddr {
} }
} }
pub struct NetSocket(Arc<UdpSocket>); pub struct NetSocket(UdpSocket);
impl Clone for NetSocket { impl NetSocket {
pub fn listen(addr: &str) -> Result<Self, io::Error> {
let addr = parse_listen(addr, DEFAULT_PORT);
Ok(Self(UdpSocket::bind(addr)?))
}
}
impl SocketBuilder for NetSocket {
type SocketType = AsyncNetSocket;
fn create_port_forwarding(&self) -> Option<PortForwarding> {
PortForwarding::new(self.0.local_addr().unwrap().port())
}
fn build(self) -> Result<Self::SocketType, io::Error> {
Ok(AsyncNetSocket(Arc::new(AsyncUdpSocket::from_std(self.0)?)))
}
}
pub struct AsyncNetSocket(Arc<AsyncUdpSocket>);
impl Clone for AsyncNetSocket {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self(self.0.clone()) Self(self.0.clone())
} }
} }
#[async_trait] #[async_trait]
impl Socket for NetSocket { impl Socket for AsyncNetSocket {
async fn listen(addr: &str) -> Result<Self, io::Error> {
let addr = parse_listen(addr, DEFAULT_PORT);
Ok(NetSocket(Arc::new(UdpSocket::bind(addr).await?)))
}
async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
buffer.clear(); buffer.clear();
let (size, addr) = self.0.recv_from(buffer.buffer()).await?; let (size, addr) = self.0.recv_from(buffer.buffer()).await?;
@ -86,10 +106,6 @@ impl Socket for NetSocket {
addr.set_ip(get_ip()); addr.set_ip(get_ip());
Ok(addr) Ok(addr)
} }
async fn create_port_forwarding(&self) -> Option<PortForwarding> {
PortForwarding::new(self.address().await.unwrap().port())
}
} }
thread_local! { thread_local! {
@ -146,10 +162,6 @@ impl MockSocket {
#[async_trait] #[async_trait]
impl Socket for MockSocket { impl Socket for MockSocket {
async fn listen(addr: &str) -> Result<Self, io::Error> {
Ok(Self::new(parse_listen(addr, DEFAULT_PORT)))
}
async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
if let Some((addr, data)) = self.inbound.lock().pop_front() { if let Some((addr, data)) = self.inbound.lock().pop_front() {
buffer.clear(); buffer.clear();
@ -172,10 +184,6 @@ impl Socket for MockSocket {
async fn address(&self) -> Result<SocketAddr, io::Error> { async fn address(&self) -> Result<SocketAddr, io::Error> {
Ok(self.address) Ok(self.address)
} }
async fn create_port_forwarding(&self) -> Option<PortForwarding> {
None
}
} }
#[cfg(feature = "bench")] #[cfg(feature = "bench")]

View File

@ -3,7 +3,7 @@
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
use super::{ use super::{
net::{get_ip, mapped_addr, parse_listen, Socket}, net::{get_ip, mapped_addr, parse_listen, Socket, SocketBuilder},
poll::{WaitImpl, WaitResult}, poll::{WaitImpl, WaitResult},
port_forwarding::PortForwarding, port_forwarding::PortForwarding,
util::MsgBuffer, util::MsgBuffer,
@ -115,6 +115,17 @@ pub struct ProxyConnection {
} }
impl ProxyConnection { impl ProxyConnection {
pub fn listen(url: &str) -> Result<Self, io::Error> {
let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?;
let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?;
socket.get_mut().set_nodelay(true)?;
let addr = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
let mut con = ProxyConnection { addr, socket: Arc::new(socket) };
let addr_data = con.read_message()?;
con.addr = read_addr(Cursor::new(&addr_data))?;
Ok(con)
}
fn read_message(&mut self) -> Result<Vec<u8>, io::Error> { fn read_message(&mut self) -> Result<Vec<u8>, io::Error> {
loop { loop {
unimplemented!(); unimplemented!();
@ -127,19 +138,20 @@ impl ProxyConnection {
} }
} }
#[async_trait] impl SocketBuilder for ProxyConnection {
impl Socket for ProxyConnection { type SocketType = ProxyConnection;
async fn listen(url: &str) -> Result<Self, io::Error> {
let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; fn build(self) -> Result<Self::SocketType, io::Error> {
let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; Ok(self)
socket.get_mut().set_nodelay(true)?;
let addr = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
let mut con = ProxyConnection { addr, socket: Arc::new(socket) };
let addr_data = con.read_message()?;
con.addr = read_addr(Cursor::new(&addr_data))?;
Ok(con)
} }
fn create_port_forwarding(&self) -> Option<PortForwarding> {
None
}
}
#[async_trait]
impl Socket for ProxyConnection {
async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
buffer.clear(); buffer.clear();
let data = self.read_message()?; let data = self.read_message()?;
@ -162,8 +174,4 @@ impl Socket for ProxyConnection {
async fn address(&self) -> Result<SocketAddr, io::Error> { async fn address(&self) -> Result<SocketAddr, io::Error> {
Ok(self.addr) Ok(self.addr)
} }
async fn create_port_forwarding(&self) -> Option<PortForwarding> {
None
}
} }