Major refactoring, version 2

This commit is contained in:
Dennis Schwerdel 2020-09-24 19:48:13 +02:00
parent 176e1956e6
commit 923269d057
35 changed files with 4650 additions and 3226 deletions

688
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -24,10 +24,12 @@ rand = "0.7"
fnv = "1" fnv = "1"
yaml-rust = "0.4" yaml-rust = "0.4"
igd = { version = "0.11", optional = true } igd = { version = "0.11", optional = true }
siphasher = "0.3"
daemonize = "0.4" daemonize = "0.4"
ring = "0.16" ring = "0.16"
privdrop = "0.3" privdrop = "0.3"
byteorder = "1.3"
thiserror = "1.0"
smallvec = "1.4"
[dev-dependencies] [dev-dependencies]
tempfile = "3" tempfile = "3"

141
Strong Crypto.md Normal file
View File

@ -0,0 +1,141 @@
The following cryptographic communication protocol is inspired by the noise protocol but contains adaptations
* for accepting a set of public keys
* and to tolerate message reordering and loss during
* and for automatic rekeying
## Key pairs and trust
Each node has a key pair consisting of a public key P and a private key S.
Also each node has a set of trusted public keys of other nodes.
There are two ways configure the key pair and trusted keys:
shared secret mode) The key pair is derived from a secret string and the only trusted public key is the public key of the key pair itself.
In this mode, all nodes must be configured with the same shared secret and then derive the identical key pair and trust each others via the common public key.
explicit trust mode) The key pair is generated randomly on each node and stored in the config file of that node. The public keys of the other nodes are exchanged beforehand (not part of VpnCloud) and listed in the config file.
The key pair is only used during connection setup, all other messages are encrypted using short-lived keys derived from an ephemeral master key.
## Initial handshake
Upon establishing a connection, there will be a 3-way handshake during which a X25519 ECDH exchange happens. Here is some pseudo code to illustrate this:
fn connect(peer) {
ecdh_keypair = create_new_ecdh_keypair()
peer.ecdh_private_key = ecdh_keypair.private_key
msg = Message (
stage = 1,
node_id = self.node_id,
public_key = self.key_pair.public_key,
ecdh_public_key = ecdh_keypair.public_key,
)
msg_signed = self.key_pair.sign(msg)
peer.next_stage = 2
send_to(peer, msg_signed)
}
fn handle_init_message(peer, msg) {
if msg.node_id == self.node_id {
return
}
if !is_trusted(msg.public_key) {
return
}
if !signature_valid(msg, msg.signature, msg.public_key) {
return
}
if peer.next_stage == null {
peer.next_stage = 1
}
if msg.stage != peer.next_stage && peer.last_msg != null {
repeat_last_message_to(peer)
return
}
if msg.stage == 1 {
ecdh_keypair = create_new_ecdh_keypair()
peer.master_key = ecdh_keypair.private_key.agree(msg.ecdh_public_key)
peer.current_subkey = peer.master_key.derive_subkey(0)
payload = { non-crypto init message }
encrypted_payload = peer.current_subkey.encrypt(payload)
msg = Message (
stage = 2,
node_id = self.node_id,
public_key = self.key_pair.public_key,
ecdh_public_key = ecdh_keypair.public_key,
payload = encrypted_payload
)
msg_signed = self.key_pair.sign(msg)
peer.next_stage = 3
send_to(peer, msg_signed)
}
if msg.stage == 2 {
peer.master_key = peer.ecdh_private_key.agree(msg.ecdh_public_key)
peer.current_subkey = peer.master_key.derive_subkey(0)
peer_payload = peer.current_subkey.decrypt(msg.payload)
{ use non-crypto payload }
payload = { non-crypto init message }
encrypted_payload = peer.current_subkey.encrypt(payload)
msg = Message (
stage = 3,
node_id = self.node_id,
public_key = self.key_pair.public_key,
payload = encrypted_payload
)
msg_signed = self.key_pair.sign(msg)
peer.next_stage = 1
send_to(peer, msg_signed)
{ add peer to list, replace if needed }
}
if msg.stage == 3 {
peer_payload = peer.current_subkey.decrypt(msg.payload)
peer.next_stage = 1
{ use non-crypto payload }
{ add peer to list, replace if needed }
}
}
In order to prevent denial-of-service attacks, established connections will only be affected when a message in stage 2 or 3 is received. This stage can not be reached by replaying messages.
If an initialization message at stage 1 is received without there being an active initialization state, a new state is created to handle the request. Messages with other stages are ignored.
If a message with an unexpected stage is received, the last sent initialization message is repeated.
After a timeout, e.g. 5 seconds, the last message is repeated if the connection is still in the initialization phase.
## Subkeys and key rotation
All messages except for the initialization handshake are encrypted using short-lived keys derived from the master key that is created during initialization.
The master key is never used for encryption. Instead it is used to derive subkeys that in turn are used for encryption. The subkeys are only temporary and are rotated periodically.
The challenge is to swap subkeys and still accept the old subkey for a grace period (to allow packet losses and reordering but reject replay attacks).
The main idea is that both peers maintain a subkey id that is incremented on every subkey rotation and is used to derive the subkey from the master key. This way, no subkeys will ever be exchanged between the peers and both peers can calculate all possible subkeys if they know the key id.
The current subkey id is attached to every encrypted message so that the receiver knows which subkey to use. There is a simple logic based on a current subkey id:
- If the key id is more than N (e.g. 5) key ids behind the current subkey id, drop the message (this is interpreted as a replay attack)
- If the key id equals the current subkey id, decrypt and process it
- If the key id is higher than the current subkey id, try to decrypt the message. If it succeeds, set the current subkey id to the received key id and process the message, else drop the message.
The current subkey id is used to limit the accepted subkey ids and reject replay attacks. Also this current subkey id is used to encrypt messages when sending them to the peer.
In a message, only the least significant 2 bytes of the subkey id are embedded in the message and the rest of the id is taken from the current subkey id (making sure to select add or subtract 1 from the upper bytes if the resulting subkey id is closer to the current one).
The key rotation happens on the following conditions:
* The last key rotation was over X (5) minutes ago
* A huge amount of messages have been sent using that subkey (see below)
On the key rotation, the sender increments its current subkey id and uses it to send future messages.
Both the sender and the receiver will issue subkey rotations, so every X minutes 2 rotations happen.
So the accepted subkey ids cover a time of X * N / 2 minutes. After that time a subkey can become known without harming the confidentiality.
Subkeys are derived from the master key by encryption a fixed dummy message with the master key and the subkey id as nonce.
## Nonces
Since both peers use the same subkey for encryption, they have to make sure to not use the same nonces. In order to guarantee that, the nonce space of 96 bits is split in two halves. The lower half belongs to the node with the smaller node_id and the upper half belongs to the node with the greater node_id (if the node_ids are identical, no connection is established).
Both peers start with a randomized nonce. The most significant 8 bits are fixed (0x00 or 0x80 depending on node_id) and the lower 11 bytes are randomized.
Whenever the upper byte reaches 0x40 or 0xc0, an extra subkey rotation is performed.
In order to reject replay attacks before the subkeys are invalidated by rotation, the nonces are invalidated too.
Both nodes track the highest seen nonce sent by the remote end. Once every T seconds that value is stored into a buffer and the last value from the buffer is taken as a minimum nonce value that will be accepted. This will be done for all subkey ids that are still valid.
The result of this is that messages have at least T seconds to reach their destination, but after that time the nonces will not be accepted anymore to prevent replay attacks.

View File

@ -26,7 +26,7 @@
# Switch table entry timeout in seconds. This parameter is only used in switch # Switch table entry timeout in seconds. This parameter is only used in switch
# mode. Addresses that have not been seen for the given period of time will # mode. Addresses that have not been seen for the given period of time will
# be forgot. # be forgot.
#dst_timeout: 300 #switch_timeout: 300
# An optional token that identifies the network and helps to distinguish it # An optional token that identifies the network and helps to distinguish it
# from other networks. # from other networks.

View File

@ -241,7 +241,7 @@ in 3 different modes:
In this mode, the VPN will dynamically learn addresses In this mode, the VPN will dynamically learn addresses
as they are used as source addresses and use them to forward data to its as they are used as source addresses and use them to forward data to its
destination. Addresses that have not been seen for some time destination. Addresses that have not been seen for some time
(option \fBdst_timeout\fP) will be forgotten. Data for unknown addresses will be (option \fBswitch_timeout\fP) will be forgotten. Data for unknown addresses will be
broadcast to all peers. This mode is the default mode for TAP devices that broadcast to all peers. This mode is the default mode for TAP devices that
process Ethernet frames but it can also be used with TUN devices and IP process Ethernet frames but it can also be used with TUN devices and IP
packets. packets.
@ -492,7 +492,7 @@ Interval for loading and storing beacons in seconds. Same as \fB\-\-beacon\-inte
The mode of the VPN. Same as \fB\-\-mode\fP The mode of the VPN. Same as \fB\-\-mode\fP
.RE .RE
.sp .sp
\fBdst_timeout\fP \fBswitch_timeout\fP
.RS 4 .RS 4
Switch table entry timeout in seconds. Same as \fB\-\-dst\-timeout\fP Switch table entry timeout in seconds. Same as \fB\-\-dst\-timeout\fP
.RE .RE

View File

@ -3,9 +3,23 @@
Due to semantic versioning, any breaking change after 1.0 requires a new major version number. Due to semantic versioning, any breaking change after 1.0 requires a new major version number.
This is a list of breaking changes to do in such a case: This is a list of breaking changes to do in such a case:
- Add strong crypto, change network protocol [x] Add strong crypto, change network protocol
- Negotiate crypto method per peer [x] Negotiate crypto method per peer
- Make encryption the default, --unencrypted for no encryption [x] Make encryption the default, no encryption must be stated explicitly
- Remove network-id parameter [x] Changed default device type to TUN
- Remove port config option [x] Remove network-id parameter
- Rename peers to connect [x] Remove port config option
[x] Rename subnet to claim
[x] Automatically claim addresses based on interface addresses (disable with --no-auto-claim)
[x] Set peer exchange interval to 5min
[x] Periodically send claims with peer list
[x] Allow to give --ip instead of ifup cmd
[x] Automatically set optimal MTU on interface
[x] Warning for disabled or loose rp_filter https://seclists.org/oss-sec/2019/q4/122
## TODOs
- Send keepalive messages on NAT every 10secs
- https://github.com/tokio-rs/mio
- https://docs.rs/tun/0.5.0/tun/

View File

@ -8,6 +8,7 @@ import re
import json import json
import base64 import base64
import sys import sys
import os
from datetime import date from datetime import date
MAX_WAIT = 300 MAX_WAIT = 300
@ -26,6 +27,10 @@ def run_cmd(connection, cmd):
else: else:
return out, err return out, err
def upload(connection, local, remote):
ftp_client=connection.open_sftp()
ftp_client.put(local, remote)
ftp_client.close()
class SpotInstanceRequest: class SpotInstanceRequest:
def __init__(self, id): def __init__(self, id):
@ -44,24 +49,23 @@ class Node:
def run_cmd(self, cmd): def run_cmd(self, cmd):
return run_cmd(self.connection, cmd) return run_cmd(self.connection, cmd)
def start_vpncloud(self, ifup=None, mtu=1400, ip=None, crypto=None, shared_key="test", device_type="tap", listen="3210", mode="normal", peers=[], subnets=[]): def start_vpncloud(self, ip=None, crypto=None, password="test", device_type="tun", listen="3210", mode="normal", peers=[], claims=[]):
args = [ args = [
"--daemon", "--daemon",
"--no-port-forwarding", "--no-port-forwarding",
"-t {}".format(device_type), "-t {}".format(device_type),
"-m {}".format(mode), "-m {}".format(mode),
"-l {}".format(listen) "-l {}".format(listen),
"--password '{}'".format(password)
] ]
if ifup: if ip:
args.append("--ifup {}".format(ifup)) args.append("--ip {}".format(ip))
else:
args.append("--ifup 'ifconfig $IFNAME {ip} mtu {mtu} up'".format(mtu=mtu, ip=ip))
if crypto: if crypto:
args.append("--shared-key '{}' --crypto {}".format(shared_key, crypto)) args.append("--algo {}".format(crypto))
for p in peers: for p in peers:
args.append("-c {}".format(p)) args.append("-c {}".format(p))
for s in subnets: for c in claims:
args.append("-s {}".format(s)) args.append("--claim {}".format(c))
args = " ".join(args) args = " ".join(args)
self.run_cmd("sudo vpncloud {}".format(args)) self.run_cmd("sudo vpncloud {}".format(args))
@ -116,7 +120,7 @@ def find_ami(region, owner, name_pattern, arch='x86_64'):
class EC2Environment: class EC2Environment:
def __init__(self, vpncloud_version, region, node_count, instance_type, use_spot=True, max_price=0.1, ami=('amazon', 'amzn2-ami-hvm-*'), username="ec2-user", subnet=CREATE, keyname=CREATE, privatekey=CREATE, tag="vpncloud", cluster_nodes=False): def __init__(self, vpncloud_version, region, node_count, instance_type, vpncloud_file=None, use_spot=True, max_price=0.1, ami=('amazon', 'amzn2-ami-hvm-*'), username="ec2-user", subnet=CREATE, keyname=CREATE, privatekey=CREATE, tag="vpncloud", cluster_nodes=False):
self.region = region self.region = region
self.node_count = node_count self.node_count = node_count
self.instance_type = instance_type self.instance_type = instance_type
@ -130,6 +134,7 @@ class EC2Environment:
self.ami = ami self.ami = ami
self.username = username self.username = username
self.vpncloud_version = vpncloud_version self.vpncloud_version = vpncloud_version
self.vpncloud_file = vpncloud_file
self.cluster_nodes = cluster_nodes self.cluster_nodes = cluster_nodes
self.resources = [] self.resources = []
self.instances = [] self.instances = []
@ -220,6 +225,9 @@ class EC2Environment:
packages: packages:
- iperf3 - iperf3
- socat - socat
"""
if not self.vpncloud_file:
userdata += """
runcmd: runcmd:
- wget https://github.com/dswd/vpncloud/releases/download/v{version}/vpncloud_{version}.x86_64.rpm -O /tmp/vpncloud.rpm - wget https://github.com/dswd/vpncloud/releases/download/v{version}/vpncloud_{version}.x86_64.rpm -O /tmp/vpncloud.rpm
- yum install -y /tmp/vpncloud.rpm - yum install -y /tmp/vpncloud.rpm
@ -334,6 +342,13 @@ runcmd:
waited += 1 waited += 1
if waited >= MAX_WAIT: if waited >= MAX_WAIT:
raise Exception("Waited too long") raise Exception("Waited too long")
if self.vpncloud_file:
eprint("Uploading vpncloud binary")
for con in self.connections:
upload(con, self.vpncloud_file, 'vpncloud')
run_cmd(con, 'chmod +x vpncloud')
run_cmd(con, 'sudo mv vpncloud /usr/bin/vpncloud')
def terminate(self): def terminate(self):
if not self.resources: if not self.resources:

View File

@ -7,7 +7,8 @@ from datetime import date
# Note: this script will run for ~8 minutes and incur costs of about $ 0.02 # Note: this script will run for ~8 minutes and incur costs of about $ 0.02
VERSION = "1.4.0" FILE = "../target/release/vpncloud"
VERSION = "2.0-pre"
REGION = "eu-central-1" REGION = "eu-central-1"
env = EC2Environment( env = EC2Environment(
@ -16,14 +17,15 @@ env = EC2Environment(
instance_type = "m5.large", instance_type = "m5.large",
use_spot = True, use_spot = True,
max_price = "0.08", # USD per hour per VM max_price = "0.08", # USD per hour per VM
vpncloud_version = VERSION, vpncloud_version = VERSION,
vpncloud_file = FILE,
cluster_nodes = True, cluster_nodes = True,
subnet = CREATE, subnet = CREATE,
keyname = CREATE keyname = CREATE
) )
CRYPTO = ["aes256", "aes128", "chacha20"] if VERSION >= "1.5.0" else ["aes256", "chacha20"] CRYPTO = ["plain", "aes256", "aes128", "chacha20"]
class PerfTest: class PerfTest:
@ -40,7 +42,7 @@ class PerfTest:
"region": env.region, "region": env.region,
"instance_type": env.instance_type, "instance_type": env.instance_type,
"ami": env.ami, "ami": env.ami,
"version": env.version "version": env.vpncloud_version
} }
return cls(env.nodes[0], env.nodes[1], meta) return cls(env.nodes[0], env.nodes[1], meta)
@ -64,11 +66,11 @@ class PerfTest:
"ping_1000": self.run_ping(dst, 1000), "ping_1000": self.run_ping(dst, 1000),
} }
def start_vpncloud(self, mtu=8800, crypto=None): def start_vpncloud(self, crypto=None):
eprint("\tSetting up vpncloud on receiver") eprint("\tSetting up vpncloud on receiver")
self.receiver.start_vpncloud(crypto=crypto, ip="{}/24".format(self.receiver_ip_vpncloud), mtu=mtu) self.receiver.start_vpncloud(crypto=crypto, ip="{}/24".format(self.receiver_ip_vpncloud))
eprint("\tSetting up vpncloud on sender") eprint("\tSetting up vpncloud on sender")
self.sender.start_vpncloud(crypto=crypto, peers=["{}:3210".format(self.receiver.private_ip)], ip="{}/24".format(self.sender_ip_vpncloud), mtu=mtu) self.sender.start_vpncloud(crypto=crypto, peers=["{}:3210".format(self.receiver.private_ip)], ip="{}/24".format(self.sender_ip_vpncloud))
time.sleep(1.0) time.sleep(1.0)
def stop_vpncloud(self): def stop_vpncloud(self):
@ -81,20 +83,20 @@ class PerfTest:
"meta": self.meta, "meta": self.meta,
"native": self.run_suite(self.receiver.private_ip) "native": self.run_suite(self.receiver.private_ip)
} }
for crypto in [None] + CRYPTO: for crypto in CRYPTO:
eprint("Running with crypto {}".format(crypto or "plain")) eprint("Running with crypto {}".format(crypto))
self.start_vpncloud(crypto=crypto) self.start_vpncloud(crypto=crypto)
res = self.run_suite(self.receiver_ip_vpncloud) res = self.run_suite(self.receiver_ip_vpncloud)
self.stop_vpncloud() self.stop_vpncloud()
results[str(crypto or "plain")] = res results[str(crypto)] = res
results['results'] = { results['results'] = {
"throughput_mbits": dict([ "throughput_mbits": dict([
(k, results[k]["iperf"]["throughput"] / 1000000.0) for k in ["native", "plain"] + CRYPTO (k, results[k]["iperf"]["throughput"] / 1000000.0) for k in ["native"] + CRYPTO
]), ]),
"latency_us": dict([ "latency_us": dict([
(k, dict([ (k, dict([
(str(s), (results[k]["ping_%s" % s]["rtt_avg"] - results["native"]["ping_%s" % s]["rtt_avg"])*1000.0/2.0) for s in [100, 500, 1000] (str(s), (results[k]["ping_%s" % s]["rtt_avg"] - results["native"]["ping_%s" % s]["rtt_avg"])*1000.0/2.0) for s in [100, 500, 1000]
])) for k in ["plain"] + CRYPTO ])) for k in CRYPTO
]) ])
} }
return results return results

View File

@ -44,16 +44,14 @@ struct FutureResult<T> {
#[derive(Clone)] #[derive(Clone)]
pub struct BeaconSerializer<TS> { pub struct BeaconSerializer<TS> {
magic: Vec<u8>,
shared_key: Vec<u8>, shared_key: Vec<u8>,
future_peers: Arc<FutureResult<Vec<SocketAddr>>>, future_peers: Arc<FutureResult<Vec<SocketAddr>>>,
_dummy_ts: PhantomData<TS> _dummy_ts: PhantomData<TS>
} }
impl<TS: TimeSource> BeaconSerializer<TS> { impl<TS: TimeSource> BeaconSerializer<TS> {
pub fn new(magic: &[u8], shared_key: &[u8]) -> Self { pub fn new(shared_key: &[u8]) -> Self {
Self { Self {
magic: magic.to_owned(),
shared_key: shared_key.to_owned(), shared_key: shared_key.to_owned(),
future_peers: Arc::new(FutureResult { has_result: AtomicBool::new(false), result: Mutex::new(Vec::new()) }), future_peers: Arc::new(FutureResult { has_result: AtomicBool::new(false), result: Mutex::new(Vec::new()) }),
_dummy_ts: PhantomData _dummy_ts: PhantomData
@ -67,7 +65,6 @@ impl<TS: TimeSource> BeaconSerializer<TS> {
fn get_keystream(&self, type_: u8, seed: u8, iter: u8) -> Vec<u8> { fn get_keystream(&self, type_: u8, seed: u8, iter: u8) -> Vec<u8> {
let mut data = Vec::new(); let mut data = Vec::new();
data.extend_from_slice(&[type_, seed, iter]); data.extend_from_slice(&[type_, seed, iter]);
data.extend_from_slice(&self.magic);
data.extend_from_slice(&self.shared_key); data.extend_from_slice(&self.shared_key);
sha512(&data) sha512(&data)
} }
@ -325,122 +322,122 @@ impl<TS: TimeSource> BeaconSerializer<TS> {
#[test] #[test]
fn encode() { fn encode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
assert_eq!("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", ser.encode(&peers)); assert_eq!("WsHI31EWDMBYxvITiILIrm2k9gEik22E", ser.encode(&peers));
peers.push(SocketAddr::from_str("[::1]:5678").unwrap()); peers.push(SocketAddr::from_str("[::1]:5678").unwrap());
assert_eq!("3hRD8BKvg7jotek0FGLeYtIc1zj7jzPRyQscQAe9tCqnFJ0vyVfIxYMB", ser.encode(&peers)); assert_eq!("WsHI3GXKaXCveo6uejmZizZ72kR6Y0L9T7h49TXONp1ugfKvvvEik22E", ser.encode(&peers));
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:54").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:54").unwrap());
assert_eq!("3hRD86NwMC5dPp8bh5idzhMal4AIxYMB", ser.encode(&peers)); assert_eq!("WsHI32gm9eMSHP3Lm1GXcdP7rD3ik22E", ser.encode(&peers));
} }
#[test] #[test]
fn decode() { fn decode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
assert_eq!(format!("{:?}", peers), format!("{:?}", ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", None))); assert_eq!(format!("{:?}", peers), format!("{:?}", ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", None)));
peers.push(SocketAddr::from_str("[::1]:5678").unwrap()); peers.push(SocketAddr::from_str("[::1]:5678").unwrap());
assert_eq!( assert_eq!(
format!("{:?}", peers), format!("{:?}", peers),
format!("{:?}", ser.decode("3hRD8BKvg7jotek0FGLeYtIc1zj7jzPRyQscQAe9tCqnFJ0vyVfIxYMB", None)) format!("{:?}", ser.decode("WsHI3GXKaXCveo6uejmZizZ72kR6Y0L9T7h49TXONp1ugfKvvvEik22E", None))
); );
} }
#[test] #[test]
fn decode_split() { fn decode_split() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
assert_eq!( assert_eq!(
format!("{:?}", peers), format!("{:?}", peers),
format!("{:?}", ser.decode("3hRD8-5V.3h:1P 0g\t5U\nn9(ZW)no[qR]Doü7ZäIxYMB", None)) format!("{:?}", ser.decode("WsHI3-1E.WD:MB Yx\tvI\nTi(IL)Ir[m2]k9ügEäik22E", None))
); );
assert_eq!( assert_eq!(
format!("{:?}", peers), format!("{:?}", peers),
format!("{:?}", ser.decode("3 -, \nhR--D85V3h1P0g5Un9ZWnoqRDo7ZI(x}YÖÄÜ\nMB", None)) format!("{:?}", ser.decode("W -, \nsH--I31EWDMBYxvITiILIrm2k9gEi(k)2ÖÄÜ\n2E", None))
); );
} }
#[test] #[test]
fn decode_offset() { fn decode_offset() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
assert_eq!( assert_eq!(
format!("{:?}", peers), format!("{:?}", peers),
format!("{:?}", ser.decode("Hello World: 3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB! End of the World", None)) format!("{:?}", ser.decode("Hello World: WsHI31EWDMBYxvITiILIrm2k9gEik22E! End of the World", None))
); );
} }
#[test] #[test]
fn decode_multiple() { fn decode_multiple() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
assert_eq!( assert_eq!(
format!("{:?}", peers), format!("{:?}", peers),
format!("{:?}", ser.decode("3hRD850fTOmqFffvcJEIxYMB 3hRD823uwTS47pupeONIxYMB", None)) format!("{:?}", ser.decode("WsHI31HVpqxFNMNSPrvik22E WsHI34yOBcZIulKdtn2ik22E", None))
); );
} }
#[test] #[test]
fn decode_ttl() { fn decode_ttl() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", None).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", None).len());
MockTimeSource::set_time(2100 * 3600); MockTimeSource::set_time(2100 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", None).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", None).len());
MockTimeSource::set_time(2005 * 3600); MockTimeSource::set_time(2005 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", None).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", None).len());
MockTimeSource::set_time(1995 * 3600); MockTimeSource::set_time(1995 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", None).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", None).len());
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", Some(24)).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", Some(24)).len());
MockTimeSource::set_time(1995 * 3600); MockTimeSource::set_time(1995 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", Some(24)).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", Some(24)).len());
MockTimeSource::set_time(2005 * 3600); MockTimeSource::set_time(2005 * 3600);
assert_eq!(2, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", Some(24)).len()); assert_eq!(2, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", Some(24)).len());
MockTimeSource::set_time(2100 * 3600); MockTimeSource::set_time(2100 * 3600);
assert_eq!(0, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", Some(24)).len()); assert_eq!(0, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", Some(24)).len());
MockTimeSource::set_time(1900 * 3600); MockTimeSource::set_time(1900 * 3600);
assert_eq!(0, ser.decode("3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB", Some(24)).len()); assert_eq!(0, ser.decode("WsHI31EWDMBYxvITiILIrm2k9gEik22E", Some(24)).len());
} }
#[test] #[test]
fn decode_invalid() { fn decode_invalid() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
assert_eq!(0, ser.decode("", None).len()); assert_eq!(0, ser.decode("", None).len());
assert_eq!(0, ser.decode("3hRD8IxYMB", None).len()); assert_eq!(0, ser.decode("WsHI3ik22E", None).len());
assert_eq!(0, ser.decode("3hRD8--", None).len()); assert_eq!(0, ser.decode("WsHI3--", None).len());
assert_eq!(0, ser.decode("--IxYMB", None).len()); assert_eq!(0, ser.decode("--ik22E", None).len());
assert_eq!(0, ser.decode("3hRD85V3h1P0g5Un8ZWnoqRDo7ZIxYMB", None).len()); assert_eq!(0, ser.decode("WsHI32EWDMBYxvITiILIrm2k9gEik22E", None).len());
assert_eq!(2, ser.decode("IxYMB3hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMB3hRD8", None).len()); assert_eq!(2, ser.decode("ik22EWsHI31EWDMBYxvITiILIrm2k9gEik22EWsHI3", None).len());
assert_eq!(2, ser.decode("3hRD83hRD85V3h1P0g5Un9ZWnoqRDo7ZIxYMBIxYMB", None).len()); assert_eq!(2, ser.decode("WsHI3WsHI31EWDMBYxvITiILIrm2k9gEik22Eik22E", None).len());
} }
#[test] #[test]
fn encode_decode() { fn encode_decode() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
@ -452,7 +449,7 @@ fn encode_decode() {
#[test] #[test]
fn encode_decode_file() { fn encode_decode_file() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());
@ -466,7 +463,7 @@ fn encode_decode_file() {
#[test] #[test]
fn encode_decode_cmd() { fn encode_decode_cmd() {
MockTimeSource::set_time(2000 * 3600); MockTimeSource::set_time(2000 * 3600);
let ser = BeaconSerializer::<MockTimeSource>::new(b"vpnc", b"mysecretkey"); let ser = BeaconSerializer::<MockTimeSource>::new(b"mysecretkey");
let mut peers = Vec::new(); let mut peers = Vec::new();
peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap()); peers.push(SocketAddr::from_str("1.2.3.4:5678").unwrap());
peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap()); peers.push(SocketAddr::from_str("6.6.6.6:53").unwrap());

View File

@ -11,22 +11,21 @@ use std::{
use super::{ use super::{
cloud::GenericCloud, cloud::GenericCloud,
config::Config, config::{Config, CryptoConfig},
crypto::{Crypto, CryptoMethod},
device::{TunTapDevice, Type}, device::{TunTapDevice, Type},
ethernet::{self, SwitchTable}, ethernet::{self, SwitchTable},
ip::Packet, ip::Packet,
net::MockSocket, net::MockSocket,
old_crypto::{CryptoMethod, OldCrypto},
poll::WaitImpl, poll::WaitImpl,
types::{Address, Protocol, Table}, types::{Address, Protocol, Table},
udpmessage::{decode, encode, Message}, udpmessage::{decode, encode, Message},
util::{MockTimeSource, SystemTimeSource, TimeSource}, util::{MockTimeSource, SystemTimeSource, TimeSource}
MAGIC
}; };
#[bench] #[bench]
fn crypto_chacha20(b: &mut Bencher) { fn crypto_chacha20(b: &mut Bencher) {
let mut crypto = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test"); let mut crypto = OldCrypto::from_shared_key(CryptoMethod::ChaCha20, "test");
let mut payload = [0; 1500]; let mut payload = [0; 1500];
let header = [0; 8]; let header = [0; 8];
let mut nonce_bytes = [0; 12]; let mut nonce_bytes = [0; 12];
@ -39,7 +38,7 @@ fn crypto_chacha20(b: &mut Bencher) {
#[bench] #[bench]
fn crypto_aes256(b: &mut Bencher) { fn crypto_aes256(b: &mut Bencher) {
let mut crypto = Crypto::from_shared_key(CryptoMethod::AES256, "test"); let mut crypto = OldCrypto::from_shared_key(CryptoMethod::AES256, "test");
let mut payload = [0; 1500]; let mut payload = [0; 1500];
let header = [0; 8]; let header = [0; 8];
let mut nonce_bytes = [0; 12]; let mut nonce_bytes = [0; 12];
@ -52,7 +51,7 @@ fn crypto_aes256(b: &mut Bencher) {
#[bench] #[bench]
fn crypto_aes128(b: &mut Bencher) { fn crypto_aes128(b: &mut Bencher) {
let mut crypto = Crypto::from_shared_key(CryptoMethod::AES128, "test"); let mut crypto = OldCrypto::from_shared_key(CryptoMethod::AES128, "test");
let mut payload = [0; 1500]; let mut payload = [0; 1500];
let header = [0; 8]; let header = [0; 8];
let mut nonce_bytes = [0; 12]; let mut nonce_bytes = [0; 12];
@ -65,25 +64,25 @@ fn crypto_aes128(b: &mut Bencher) {
#[bench] #[bench]
fn message_encode(b: &mut Bencher) { fn message_encode(b: &mut Bencher) {
let mut crypto = Crypto::None; let mut crypto = OldCrypto::None;
let mut payload = [0; 1600]; let mut payload = [0; 1600];
let mut msg = Message::Data(&mut payload, 64, 1464); let mut msg = Message::Data(&mut payload, 64, 1464);
let mut buf = [0; 1600]; let mut buf = [0; 1600];
b.iter(|| { b.iter(|| {
encode(&mut msg, &mut buf[..], MAGIC, &mut crypto); encode(&mut msg, &mut buf[..], &mut crypto);
}); });
b.bytes = 1400; b.bytes = 1400;
} }
#[bench] #[bench]
fn message_decode(b: &mut Bencher) { fn message_decode(b: &mut Bencher) {
let mut crypto = Crypto::None; let mut crypto = OldCrypto::None;
let mut payload = [0; 1600]; let mut payload = [0; 1600];
let mut msg = Message::Data(&mut payload, 64, 1464); let mut msg = Message::Data(&mut payload, 64, 1464);
let mut buf = [0; 1600]; let mut buf = [0; 1600];
let mut res = encode(&mut msg, &mut buf[..], MAGIC, &mut crypto); let mut res = encode(&mut msg, &mut buf[..], &mut crypto);
b.iter(|| { b.iter(|| {
decode(&mut res, MAGIC, &mut crypto).unwrap(); decode(&mut res, &mut crypto).unwrap();
}); });
b.bytes = 1400; b.bytes = 1400;
} }
@ -154,13 +153,16 @@ type TestNode = GenericCloud<TunTapDevice, ethernet::Frame, SwitchTable<MockTime
fn create_test_node() -> TestNode { fn create_test_node() -> TestNode {
TestNode::new( TestNode::new(
&Config::default(), &Config {
crypto: CryptoConfig { password: Some("test".to_string()), ..CryptoConfig::default() },
..Config::default()
},
TunTapDevice::dummy("dummy", "/dev/null", Type::Tap).unwrap(), TunTapDevice::dummy("dummy", "/dev/null", Type::Tap).unwrap(),
SwitchTable::new(1800, 10), SwitchTable::new(1800, 10),
true, true,
true, true,
vec![], vec![],
Crypto::None, OldCrypto::None,
None, None,
None None
) )

File diff suppressed because it is too large Load Diff

View File

@ -2,25 +2,17 @@
// Copyright (C) 2015-2020 Dennis Schwerdel // Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
use super::{Args, MAGIC}; use super::{device::Type, types::Mode, util::Duration};
pub use crate::crypto::Config as CryptoConfig;
use super::{
crypto::CryptoMethod,
device::Type,
types::{HeaderMagic, Mode},
util::{Duration, Encoder}
};
use siphasher::sip::SipHasher24;
use std::{ use std::{
cmp::max, cmp::max,
hash::{Hash, Hasher},
net::{IpAddr, Ipv6Addr, SocketAddr} net::{IpAddr, Ipv6Addr, SocketAddr}
}; };
use structopt::StructOpt;
const HASH_PREFIX: &str = "hash:"; pub const DEFAULT_PEER_TIMEOUT: u16 = 300;
pub const DEFAULT_PEER_TIMEOUT: u16 = 600;
pub const DEFAULT_PORT: u16 = 3210; pub const DEFAULT_PORT: u16 = 3210;
@ -41,11 +33,14 @@ pub struct Config {
pub device_type: Type, pub device_type: Type,
pub device_name: String, pub device_name: String,
pub device_path: Option<String>, pub device_path: Option<String>,
pub fix_rp_filter: bool,
pub ip: Option<String>,
pub ifup: Option<String>, pub ifup: Option<String>,
pub ifdown: Option<String>, pub ifdown: Option<String>,
pub crypto: CryptoMethod,
pub shared_key: Option<String>, pub crypto: CryptoConfig,
pub magic: Option<String>,
pub listen: SocketAddr, pub listen: SocketAddr,
pub peers: Vec<String>, pub peers: Vec<String>,
pub peer_timeout: Duration, pub peer_timeout: Duration,
@ -53,9 +48,11 @@ pub struct Config {
pub beacon_store: Option<String>, pub beacon_store: Option<String>,
pub beacon_load: Option<String>, pub beacon_load: Option<String>,
pub beacon_interval: Duration, pub beacon_interval: Duration,
pub beacon_password: Option<String>,
pub mode: Mode, pub mode: Mode,
pub dst_timeout: Duration, pub switch_timeout: Duration,
pub subnets: Vec<String>, pub claims: Vec<String>,
pub auto_claim: bool,
pub port_forwarding: bool, pub port_forwarding: bool,
pub daemonize: bool, pub daemonize: bool,
pub pid_file: Option<String>, pub pid_file: Option<String>,
@ -69,14 +66,14 @@ pub struct Config {
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
Config { Config {
device_type: Type::Tap, device_type: Type::Tun,
device_name: "vpncloud%d".to_string(), device_name: "vpncloud%d".to_string(),
device_path: None, device_path: None,
fix_rp_filter: false,
ip: None,
ifup: None, ifup: None,
ifdown: None, ifdown: None,
crypto: CryptoMethod::ChaCha20, crypto: CryptoConfig::default(),
shared_key: None,
magic: None,
listen: "[::]:3210".parse::<SocketAddr>().unwrap(), listen: "[::]:3210".parse::<SocketAddr>().unwrap(),
peers: vec![], peers: vec![],
peer_timeout: DEFAULT_PEER_TIMEOUT as Duration, peer_timeout: DEFAULT_PEER_TIMEOUT as Duration,
@ -84,9 +81,11 @@ impl Default for Config {
beacon_store: None, beacon_store: None,
beacon_load: None, beacon_load: None,
beacon_interval: 3600, beacon_interval: 3600,
beacon_password: None,
mode: Mode::Normal, mode: Mode::Normal,
dst_timeout: 300, switch_timeout: 300,
subnets: vec![], claims: vec![],
auto_claim: true,
port_forwarding: true, port_forwarding: true,
daemonize: false, daemonize: false,
pid_file: None, pid_file: None,
@ -101,7 +100,7 @@ impl Default for Config {
impl Config { impl Config {
#[allow(clippy::cognitive_complexity)] #[allow(clippy::cognitive_complexity)]
pub fn merge_file(&mut self, file: ConfigFile) { pub fn merge_file(&mut self, mut file: ConfigFile) {
if let Some(val) = file.device_type { if let Some(val) = file.device_type {
self.device_type = val; self.device_type = val;
} }
@ -111,25 +110,18 @@ impl Config {
if let Some(val) = file.device_path { if let Some(val) = file.device_path {
self.device_path = Some(val); self.device_path = Some(val);
} }
if let Some(val) = file.fix_rp_filter {
self.fix_rp_filter = val;
}
if let Some(val) = file.ip {
self.ip = Some(val);
}
if let Some(val) = file.ifup { if let Some(val) = file.ifup {
self.ifup = Some(val); self.ifup = Some(val);
} }
if let Some(val) = file.ifdown { if let Some(val) = file.ifdown {
self.ifdown = Some(val); self.ifdown = Some(val);
} }
if let Some(val) = file.crypto {
self.crypto = val;
}
if let Some(val) = file.shared_key {
self.shared_key = Some(val);
}
if let Some(val) = file.magic {
self.magic = Some(val);
}
if let Some(val) = file.port {
self.listen = parse_listen(&format!("{}", &val));
warn!("The config option 'port' is deprecated, use 'listen' instead.");
}
if let Some(val) = file.listen { if let Some(val) = file.listen {
self.listen = parse_listen(&val); self.listen = parse_listen(&val);
} }
@ -151,14 +143,20 @@ impl Config {
if let Some(val) = file.beacon_interval { if let Some(val) = file.beacon_interval {
self.beacon_interval = val; self.beacon_interval = val;
} }
if let Some(val) = file.beacon_password {
self.beacon_password = Some(val);
}
if let Some(val) = file.mode { if let Some(val) = file.mode {
self.mode = val; self.mode = val;
} }
if let Some(val) = file.dst_timeout { if let Some(val) = file.switch_timeout {
self.dst_timeout = val; self.switch_timeout = val;
} }
if let Some(mut val) = file.subnets { if let Some(mut val) = file.claims {
self.subnets.append(&mut val); self.claims.append(&mut val);
}
if let Some(val) = file.auto_claim {
self.auto_claim = val;
} }
if let Some(val) = file.port_forwarding { if let Some(val) = file.port_forwarding {
self.port_forwarding = val; self.port_forwarding = val;
@ -181,6 +179,19 @@ impl Config {
if let Some(val) = file.group { if let Some(val) = file.group {
self.group = Some(val); self.group = Some(val);
} }
if let Some(val) = file.crypto.password {
self.crypto.password = Some(val)
}
if let Some(val) = file.crypto.public_key {
self.crypto.public_key = Some(val)
}
if let Some(val) = file.crypto.private_key {
self.crypto.private_key = Some(val)
}
self.crypto.trusted_keys.append(&mut file.crypto.trusted_keys);
if !file.crypto.algorithms.is_empty() {
self.crypto.algorithms = file.crypto.algorithms.clone();
}
} }
pub fn merge_args(&mut self, mut args: Args) { pub fn merge_args(&mut self, mut args: Args) {
@ -193,29 +204,22 @@ impl Config {
if let Some(val) = args.device_path { if let Some(val) = args.device_path {
self.device_path = Some(val); self.device_path = Some(val);
} }
if args.fix_rp_filter {
self.fix_rp_filter = true;
}
if let Some(val) = args.ip {
self.ip = Some(val);
}
if let Some(val) = args.ifup { if let Some(val) = args.ifup {
self.ifup = Some(val); self.ifup = Some(val);
} }
if let Some(val) = args.ifdown { if let Some(val) = args.ifdown {
self.ifdown = Some(val); self.ifdown = Some(val);
} }
if let Some(val) = args.crypto {
self.crypto = val;
}
if let Some(val) = args.key {
self.shared_key = Some(val);
}
if let Some(val) = args.network_id {
warn!("The --network-id argument is deprecated, please use --magic instead.");
self.magic = Some(val);
}
if let Some(val) = args.magic {
self.magic = Some(val);
}
if let Some(val) = args.listen { if let Some(val) = args.listen {
self.listen = parse_listen(&val); self.listen = parse_listen(&val);
} }
self.peers.append(&mut args.connect); self.peers.append(&mut args.peers);
if let Some(val) = args.peer_timeout { if let Some(val) = args.peer_timeout {
self.peer_timeout = val; self.peer_timeout = val;
} }
@ -231,13 +235,19 @@ impl Config {
if let Some(val) = args.beacon_interval { if let Some(val) = args.beacon_interval {
self.beacon_interval = val; self.beacon_interval = val;
} }
if let Some(val) = args.beacon_password {
self.beacon_password = Some(val);
}
if let Some(val) = args.mode { if let Some(val) = args.mode {
self.mode = val; self.mode = val;
} }
if let Some(val) = args.dst_timeout { if let Some(val) = args.switch_timeout {
self.dst_timeout = val; self.switch_timeout = val;
}
self.claims.append(&mut args.claims);
if args.no_auto_claim {
self.auto_claim = false;
} }
self.subnets.append(&mut args.subnets);
if args.no_port_forwarding { if args.no_port_forwarding {
self.port_forwarding = false; self.port_forwarding = false;
} }
@ -262,24 +272,18 @@ impl Config {
if let Some(val) = args.group { if let Some(val) = args.group {
self.group = Some(val); self.group = Some(val);
} }
} if let Some(val) = args.password {
self.crypto.password = Some(val)
pub fn get_magic(&self) -> HeaderMagic { }
if let Some(ref name) = self.magic { if let Some(val) = args.public_key {
if name.starts_with(HASH_PREFIX) { self.crypto.public_key = Some(val)
let mut s = SipHasher24::new(); }
name[HASH_PREFIX.len()..].hash(&mut s); if let Some(val) = args.private_key {
let mut data = [0; 4]; self.crypto.private_key = Some(val)
Encoder::write_u32((s.finish() & 0xffff_ffff) as u32, &mut data); }
data self.crypto.trusted_keys.append(&mut args.trusted_keys);
} else { if !args.algorithms.is_empty() {
let num = try_fail!(u32::from_str_radix(name, 16), "Failed to parse header magic: {}"); self.crypto.algorithms = args.algorithms.clone();
let mut data = [0; 4];
Encoder::write_u32(num, &mut data);
data
}
} else {
MAGIC
} }
} }
@ -292,45 +296,190 @@ impl Config {
} }
#[derive(Serialize, Deserialize, Debug, PartialEq, Default)] #[derive(StructOpt, Debug, Default)]
pub struct ConfigFile { pub struct Args {
#[serde(alias = "device-type")] /// Read configuration options from the specified file.
pub device_type: Option<Type>, #[structopt(long)]
#[serde(alias = "device-name")] pub config: Option<String>,
pub device_name: Option<String>,
#[serde(alias = "device-path")] /// Set the type of network
#[structopt(name = "type", short, long, possible_values=&["tun", "tap"])]
pub type_: Option<Type>,
/// Set the path of the base device
#[structopt(long)]
pub device_path: Option<String>, pub device_path: Option<String>,
/// Fix the rp_filter settings on the host
#[structopt(long)]
pub fix_rp_filter: bool,
/// The mode of the VPN
#[structopt(short, long, possible_values=&["normal", "router", "switch", "hub"])]
pub mode: Option<Mode>,
/// The shared password to encrypt all traffic
#[structopt(short, long, required_unless = "private-key", env)]
pub password: Option<String>,
/// The private key to use
#[structopt(long, alias = "key", conflicts_with = "password", env)]
pub private_key: Option<String>,
/// The public key to use
#[structopt(long)]
pub public_key: Option<String>,
/// Other public keys to trust
#[structopt(long, name = "trusted-key", alias = "trust", use_delimiter = true)]
pub trusted_keys: Vec<String>,
/// Algorithms to allow
#[structopt(long, name = "algorithm", alias = "algo", use_delimiter=true, case_insensitive = true, possible_values=&["plain", "aes128", "aes256", "chacha20"])]
pub algorithms: Vec<String>,
/// The local subnets to claim (IP or IP/prefix)
#[structopt(long, name = "claim", use_delimiter = true)]
pub claims: Vec<String>,
/// Do not automatically claim the device ip
#[structopt(long)]
pub no_auto_claim: bool,
/// Name of the virtual device
#[structopt(short, long)]
pub device: Option<String>,
/// The port number (or ip:port) on which to listen for data
#[structopt(short, long)]
pub listen: Option<String>,
/// Address of a peer to connect to
#[structopt(short = "c", name = "peer", long, alias = "connect")]
pub peers: Vec<String>,
/// Peer timeout in seconds
#[structopt(long)]
pub peer_timeout: Option<Duration>,
/// Periodically send message to keep connections alive
#[structopt(long)]
pub keepalive: Option<Duration>,
/// Switch table entry timeout in seconds
#[structopt(long)]
pub switch_timeout: Option<Duration>,
/// The file path or |command to store the beacon
#[structopt(long)]
pub beacon_store: Option<String>,
/// The file path or |command to load the beacon
#[structopt(long)]
pub beacon_load: Option<String>,
/// Beacon store/load interval in seconds
#[structopt(long)]
pub beacon_interval: Option<Duration>,
/// Password to encrypt the beacon with
#[structopt(long)]
pub beacon_password: Option<String>,
/// Print debug information
#[structopt(short, long, conflicts_with = "quiet")]
pub verbose: bool,
/// Only print errors and warnings
#[structopt(short, long)]
pub quiet: bool,
/// An IP address (plus optional prefix length) for the interface
#[structopt(long)]
pub ip: Option<String>,
/// A command to setup the network interface
#[structopt(long)]
pub ifup: Option<String>,
/// A command to bring down the network interface
#[structopt(long)]
pub ifdown: Option<String>,
/// Print the version and exit
#[structopt(long)]
pub version: bool,
/// Generate and print a key-pair and exit
#[structopt(long, conflicts_with_all=&["password", "private_key"])]
pub genkey: bool,
/// Disable automatic port forwarding
#[structopt(long)]
pub no_port_forwarding: bool,
/// Run the process in the background
#[structopt(long)]
pub daemon: bool,
/// Store the process id in this file when daemonizing
#[structopt(long)]
pub pid_file: Option<String>,
/// Print statistics to this file
#[structopt(long)]
pub stats_file: Option<String>,
/// Send statistics to this statsd server
#[structopt(long)]
pub statsd_server: Option<String>,
/// Use the given prefix for statsd records
#[structopt(long, requires = "statsd-server")]
pub statsd_prefix: Option<String>,
/// Run as other user
#[structopt(long)]
pub user: Option<String>,
/// Run as other group
#[structopt(long)]
pub group: Option<String>,
/// Print logs also to this file
#[structopt(long)]
pub log_file: Option<String>
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Default)]
#[serde(rename_all = "kebab-case", deny_unknown_fields, default)]
pub struct ConfigFile {
pub device_type: Option<Type>,
pub device_name: Option<String>,
pub device_path: Option<String>,
pub fix_rp_filter: Option<bool>,
pub ip: Option<String>,
pub ifup: Option<String>, pub ifup: Option<String>,
pub ifdown: Option<String>, pub ifdown: Option<String>,
pub crypto: Option<CryptoMethod>,
#[serde(alias = "shared-key")] pub crypto: CryptoConfig,
pub shared_key: Option<String>,
pub magic: Option<String>,
pub port: Option<u16>,
pub listen: Option<String>, pub listen: Option<String>,
pub peers: Option<Vec<String>>, pub peers: Option<Vec<String>>,
#[serde(alias = "peer-timeout")]
pub peer_timeout: Option<Duration>, pub peer_timeout: Option<Duration>,
pub keepalive: Option<Duration>, pub keepalive: Option<Duration>,
#[serde(alias = "beacon-store")]
pub beacon_store: Option<String>, pub beacon_store: Option<String>,
#[serde(alias = "beacon-load")]
pub beacon_load: Option<String>, pub beacon_load: Option<String>,
#[serde(alias = "beacon-interval")]
pub beacon_interval: Option<Duration>, pub beacon_interval: Option<Duration>,
pub beacon_password: Option<String>,
pub mode: Option<Mode>, pub mode: Option<Mode>,
#[serde(alias = "dst-timeout")] pub switch_timeout: Option<Duration>,
pub dst_timeout: Option<Duration>, pub claims: Option<Vec<String>>,
pub subnets: Option<Vec<String>>, pub auto_claim: Option<bool>,
#[serde(alias = "port-forwarding")]
pub port_forwarding: Option<bool>, pub port_forwarding: Option<bool>,
#[serde(alias = "pid-file")]
pub pid_file: Option<String>, pub pid_file: Option<String>,
#[serde(alias = "stats-file")]
pub stats_file: Option<String>, pub stats_file: Option<String>,
#[serde(alias = "statsd-server")]
pub statsd_server: Option<String>, pub statsd_server: Option<String>,
#[serde(alias = "statsd-prefix")]
pub statsd_prefix: Option<String>, pub statsd_prefix: Option<String>,
pub user: Option<String>, pub user: Option<String>,
pub group: Option<String> pub group: Option<String>
@ -340,45 +489,42 @@ pub struct ConfigFile {
#[test] #[test]
fn config_file() { fn config_file() {
let config_file = " let config_file = "
device_type: tun device-type: tun
device_name: vpncloud%d device-name: vpncloud%d
device_path: /dev/net/tun device-path: /dev/net/tun
magic: 0123ABCD ip: 10.0.1.1/16
ifup: ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up ifup: ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up
ifdown: 'true' ifdown: 'true'
crypto: aes256
shared_key: mysecret
port: 3210
peers: peers:
- remote.machine.foo:3210 - remote.machine.foo:3210
- remote.machine.bar:3210 - remote.machine.bar:3210
peer_timeout: 600 peer-timeout: 600
keepalive: 840 keepalive: 840
dst_timeout: 300 switch-timeout: 300
beacon_store: /run/vpncloud.beacon.out beacon-store: /run/vpncloud.beacon.out
beacon_load: /run/vpncloud.beacon.in beacon-load: /run/vpncloud.beacon.in
beacon_interval: 3600 beacon-interval: 3600
beacon-password: test123
mode: normal mode: normal
subnets: claims:
- 10.0.1.0/24 - 10.0.1.0/24
port_forwarding: true port-forwarding: true
user: nobody user: nobody
group: nogroup group: nogroup
pid_file: /run/vpncloud.run pid-file: /run/vpncloud.run
stats_file: /var/log/vpncloud.stats stats-file: /var/log/vpncloud.stats
statsd_server: example.com:1234 statsd-server: example.com:1234
statsd_prefix: prefix statsd-prefix: prefix
"; ";
assert_eq!(serde_yaml::from_str::<ConfigFile>(config_file).unwrap(), ConfigFile { assert_eq!(serde_yaml::from_str::<ConfigFile>(config_file).unwrap(), ConfigFile {
device_type: Some(Type::Tun), device_type: Some(Type::Tun),
device_name: Some("vpncloud%d".to_string()), device_name: Some("vpncloud%d".to_string()),
device_path: Some("/dev/net/tun".to_string()), device_path: Some("/dev/net/tun".to_string()),
fix_rp_filter: None,
ip: Some("10.0.1.1/16".to_string()),
ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()), ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()),
ifdown: Some("true".to_string()), ifdown: Some("true".to_string()),
crypto: Some(CryptoMethod::AES256), crypto: CryptoConfig::default(),
shared_key: Some("mysecret".to_string()),
magic: Some("0123ABCD".to_string()),
port: Some(3210),
listen: None, listen: None,
peers: Some(vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()]), peers: Some(vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()]),
peer_timeout: Some(600), peer_timeout: Some(600),
@ -386,9 +532,11 @@ statsd_prefix: prefix
beacon_store: Some("/run/vpncloud.beacon.out".to_string()), beacon_store: Some("/run/vpncloud.beacon.out".to_string()),
beacon_load: Some("/run/vpncloud.beacon.in".to_string()), beacon_load: Some("/run/vpncloud.beacon.in".to_string()),
beacon_interval: Some(3600), beacon_interval: Some(3600),
beacon_password: Some("test123".to_string()),
mode: Some(Mode::Normal), mode: Some(Mode::Normal),
dst_timeout: Some(300), switch_timeout: Some(300),
subnets: Some(vec!["10.0.1.0/24".to_string()]), claims: Some(vec!["10.0.1.0/24".to_string()]),
auto_claim: None,
port_forwarding: Some(true), port_forwarding: Some(true),
user: Some("nobody".to_string()), user: Some("nobody".to_string()),
group: Some("nogroup".to_string()), group: Some("nogroup".to_string()),
@ -406,12 +554,11 @@ fn config_merge() {
device_type: Some(Type::Tun), device_type: Some(Type::Tun),
device_name: Some("vpncloud%d".to_string()), device_name: Some("vpncloud%d".to_string()),
device_path: None, device_path: None,
fix_rp_filter: None,
ip: None,
ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()), ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()),
ifdown: Some("true".to_string()), ifdown: Some("true".to_string()),
crypto: Some(CryptoMethod::AES256), crypto: CryptoConfig::default(),
shared_key: Some("mysecret".to_string()),
magic: Some("0123ABCD".to_string()),
port: Some(3210),
listen: None, listen: None,
peers: Some(vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()]), peers: Some(vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()]),
peer_timeout: Some(600), peer_timeout: Some(600),
@ -419,9 +566,11 @@ fn config_merge() {
beacon_store: Some("/run/vpncloud.beacon.out".to_string()), beacon_store: Some("/run/vpncloud.beacon.out".to_string()),
beacon_load: Some("/run/vpncloud.beacon.in".to_string()), beacon_load: Some("/run/vpncloud.beacon.in".to_string()),
beacon_interval: Some(7200), beacon_interval: Some(7200),
beacon_password: Some("test123".to_string()),
mode: Some(Mode::Normal), mode: Some(Mode::Normal),
dst_timeout: Some(300), switch_timeout: Some(300),
subnets: Some(vec!["10.0.1.0/24".to_string()]), claims: Some(vec!["10.0.1.0/24".to_string()]),
auto_claim: Some(true),
port_forwarding: Some(true), port_forwarding: Some(true),
user: Some("nobody".to_string()), user: Some("nobody".to_string()),
group: Some("nogroup".to_string()), group: Some("nogroup".to_string()),
@ -434,22 +583,21 @@ fn config_merge() {
device_type: Type::Tun, device_type: Type::Tun,
device_name: "vpncloud%d".to_string(), device_name: "vpncloud%d".to_string(),
device_path: None, device_path: None,
ip: None,
ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()), ifup: Some("ifconfig $IFNAME 10.0.1.1/16 mtu 1400 up".to_string()),
ifdown: Some("true".to_string()), ifdown: Some("true".to_string()),
magic: Some("0123ABCD".to_string()),
crypto: CryptoMethod::AES256,
shared_key: Some("mysecret".to_string()),
listen: "[::]:3210".parse::<SocketAddr>().unwrap(), listen: "[::]:3210".parse::<SocketAddr>().unwrap(),
peers: vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()], peers: vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string()],
peer_timeout: 600, peer_timeout: 600,
keepalive: Some(840), keepalive: Some(840),
dst_timeout: 300, switch_timeout: 300,
beacon_store: Some("/run/vpncloud.beacon.out".to_string()), beacon_store: Some("/run/vpncloud.beacon.out".to_string()),
beacon_load: Some("/run/vpncloud.beacon.in".to_string()), beacon_load: Some("/run/vpncloud.beacon.in".to_string()),
beacon_interval: 7200, beacon_interval: 7200,
beacon_password: Some("test123".to_string()),
mode: Mode::Normal, mode: Mode::Normal,
port_forwarding: true, port_forwarding: true,
subnets: vec!["10.0.1.0/24".to_string()], claims: vec!["10.0.1.0/24".to_string()],
user: Some("nobody".to_string()), user: Some("nobody".to_string()),
group: Some("nogroup".to_string()), group: Some("nogroup".to_string()),
pid_file: Some("/run/vpncloud.run".to_string()), pid_file: Some("/run/vpncloud.run".to_string()),
@ -464,19 +612,18 @@ fn config_merge() {
device_path: Some("/dev/null".to_string()), device_path: Some("/dev/null".to_string()),
ifup: Some("ifconfig $IFNAME 10.0.1.2/16 mtu 1400 up".to_string()), ifup: Some("ifconfig $IFNAME 10.0.1.2/16 mtu 1400 up".to_string()),
ifdown: Some("ifconfig $IFNAME down".to_string()), ifdown: Some("ifconfig $IFNAME down".to_string()),
crypto: Some(CryptoMethod::ChaCha20), password: Some("anothersecret".to_string()),
key: Some("anothersecret".to_string()),
magic: Some("hash:mynet".to_string()),
listen: Some("3211".to_string()), listen: Some("3211".to_string()),
peer_timeout: Some(1801), peer_timeout: Some(1801),
keepalive: Some(850), keepalive: Some(850),
dst_timeout: Some(301), switch_timeout: Some(301),
beacon_store: Some("/run/vpncloud.beacon.out2".to_string()), beacon_store: Some("/run/vpncloud.beacon.out2".to_string()),
beacon_load: Some("/run/vpncloud.beacon.in2".to_string()), beacon_load: Some("/run/vpncloud.beacon.in2".to_string()),
beacon_interval: Some(3600), beacon_interval: Some(3600),
beacon_password: Some("test1234".to_string()),
mode: Some(Mode::Switch), mode: Some(Mode::Switch),
subnets: vec![], claims: vec![],
connect: vec!["another:3210".to_string()], peers: vec!["another:3210".to_string()],
no_port_forwarding: true, no_port_forwarding: true,
daemon: true, daemon: true,
pid_file: Some("/run/vpncloud-mynet.run".to_string()), pid_file: Some("/run/vpncloud-mynet.run".to_string()),
@ -491,11 +638,11 @@ fn config_merge() {
device_type: Type::Tap, device_type: Type::Tap,
device_name: "vpncloud0".to_string(), device_name: "vpncloud0".to_string(),
device_path: Some("/dev/null".to_string()), device_path: Some("/dev/null".to_string()),
fix_rp_filter: false,
ip: None,
ifup: Some("ifconfig $IFNAME 10.0.1.2/16 mtu 1400 up".to_string()), ifup: Some("ifconfig $IFNAME 10.0.1.2/16 mtu 1400 up".to_string()),
ifdown: Some("ifconfig $IFNAME down".to_string()), ifdown: Some("ifconfig $IFNAME down".to_string()),
magic: Some("hash:mynet".to_string()), crypto: CryptoConfig { password: Some("anothersecret".to_string()), ..CryptoConfig::default() },
crypto: CryptoMethod::ChaCha20,
shared_key: Some("anothersecret".to_string()),
listen: "[::]:3211".parse::<SocketAddr>().unwrap(), listen: "[::]:3211".parse::<SocketAddr>().unwrap(),
peers: vec![ peers: vec![
"remote.machine.foo:3210".to_string(), "remote.machine.foo:3210".to_string(),
@ -504,13 +651,15 @@ fn config_merge() {
], ],
peer_timeout: 1801, peer_timeout: 1801,
keepalive: Some(850), keepalive: Some(850),
dst_timeout: 301, switch_timeout: 301,
beacon_store: Some("/run/vpncloud.beacon.out2".to_string()), beacon_store: Some("/run/vpncloud.beacon.out2".to_string()),
beacon_load: Some("/run/vpncloud.beacon.in2".to_string()), beacon_load: Some("/run/vpncloud.beacon.in2".to_string()),
beacon_interval: 3600, beacon_interval: 3600,
beacon_password: Some("test1234".to_string()),
mode: Mode::Switch, mode: Mode::Switch,
port_forwarding: false, port_forwarding: false,
subnets: vec!["10.0.1.0/24".to_string()], claims: vec!["10.0.1.0/24".to_string()],
auto_claim: true,
user: Some("root".to_string()), user: Some("root".to_string()),
group: Some("root".to_string()), group: Some("root".to_string()),
pid_file: Some("/run/vpncloud-mynet.run".to_string()), pid_file: Some("/run/vpncloud-mynet.run".to_string()),

View File

@ -1,254 +0,0 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use std::{num::NonZeroU32, str::FromStr};
use ring::{aead::*, pbkdf2, rand::*};
use super::types::Error;
const SALT: &[u8; 32] = b"vpncloudVPNCLOUDvpncl0udVpnCloud";
const HEX_PREFIX: &str = "hex:";
const HASH_PREFIX: &str = "hash:";
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Copy)]
pub enum CryptoMethod {
#[serde(rename = "chacha20")]
ChaCha20,
#[serde(rename = "aes256")]
AES256,
#[serde(rename = "aes128")]
AES128
}
impl FromStr for CryptoMethod {
type Err = &'static str;
fn from_str(text: &str) -> Result<Self, Self::Err> {
Ok(match &text.to_lowercase() as &str {
"chacha20" | "chacha" => Self::ChaCha20,
"aes256" => Self::AES256,
"aes128" | "aes" => Self::AES128,
_ => return Err("Unknown method")
})
}
}
pub struct CryptoData {
crypto_key: LessSafeKey,
nonce: Vec<u8>,
key: Vec<u8>
}
#[allow(unknown_lints, clippy::large_enum_variant)]
pub enum Crypto {
None,
ChaCha20Poly1305(CryptoData),
AES256GCM(CryptoData),
AES128GCM(CryptoData)
}
fn inc_nonce(nonce: &mut [u8]) {
let l = nonce.len();
for i in (0..l).rev() {
let mut num = nonce[i];
num = num.wrapping_add(1);
nonce[i] = num;
if num > 0 {
return
}
}
warn!("Nonce overflowed");
}
impl Crypto {
#[inline]
pub fn method(&self) -> u8 {
match *self {
Crypto::None => 0,
Crypto::ChaCha20Poly1305 { .. } => 1,
Crypto::AES256GCM { .. } => 2,
Crypto::AES128GCM { .. } => 3
}
}
#[inline]
pub fn nonce_bytes(&self) -> usize {
match *self {
Crypto::None => 0,
Crypto::ChaCha20Poly1305(ref data) | Crypto::AES256GCM(ref data) | Crypto::AES128GCM(ref data) => {
data.crypto_key.algorithm().nonce_len()
}
}
}
#[inline]
pub fn get_key(&self) -> &[u8] {
match *self {
Crypto::None => &[],
Crypto::ChaCha20Poly1305(ref data) | Crypto::AES256GCM(ref data) | Crypto::AES128GCM(ref data) => &data.key
}
}
#[inline]
#[allow(unknown_lints, clippy::match_same_arms)]
pub fn additional_bytes(&self) -> usize {
match *self {
Crypto::None => 0,
Crypto::ChaCha20Poly1305(ref data) | Crypto::AES256GCM(ref data) | Crypto::AES128GCM(ref data) => {
data.crypto_key.algorithm().tag_len()
}
}
}
pub fn from_shared_key(method: CryptoMethod, password: &str) -> Self {
let algo = match method {
CryptoMethod::ChaCha20 => &CHACHA20_POLY1305,
CryptoMethod::AES256 => &AES_256_GCM,
CryptoMethod::AES128 => &AES_128_GCM
};
let mut key: Vec<u8> = Vec::with_capacity(algo.key_len());
for _ in 0..algo.key_len() {
key.push(0);
}
if password.starts_with(HEX_PREFIX) {
let password = &password[HEX_PREFIX.len()..];
if password.len() != 2 * algo.key_len() {
fail!("Raw secret key must be exactly {} bytes long", algo.key_len());
}
for i in 0..algo.key_len() {
key[i] = try_fail!(
u8::from_str_radix(&password[2 * i..=2 * i + 1], 16),
"Failed to parse raw secret key: {}"
);
}
} else {
let password = if password.starts_with(HASH_PREFIX) { &password[HASH_PREFIX.len()..] } else { password };
pbkdf2::derive(
pbkdf2::PBKDF2_HMAC_SHA256,
NonZeroU32::new(4096).unwrap(),
SALT,
password.as_bytes(),
&mut key
);
}
let crypto_key = LessSafeKey::new(UnboundKey::new(algo, &key[..algo.key_len()]).expect("Failed to create key"));
let mut nonce: Vec<u8> = Vec::with_capacity(algo.nonce_len());
for _ in 0..algo.nonce_len() {
nonce.push(0);
}
if SystemRandom::new().fill(&mut nonce).is_err() {
fail!("Randomizing nonce failed");
}
// make sure the nonce will not overflow
if nonce[0] == 0xff {
nonce[0] = 0
}
let data = CryptoData { crypto_key, nonce, key };
match method {
CryptoMethod::ChaCha20 => Crypto::ChaCha20Poly1305(data),
CryptoMethod::AES256 => Crypto::AES256GCM(data),
CryptoMethod::AES128 => Crypto::AES128GCM(data)
}
}
pub fn decrypt(&self, buf: &mut [u8], nonce: &[u8], header: &[u8]) -> Result<usize, Error> {
match *self {
Crypto::None => Ok(buf.len()),
Crypto::ChaCha20Poly1305(ref data) | Crypto::AES256GCM(ref data) | Crypto::AES128GCM(ref data) => {
let nonce = Nonce::try_assume_unique_for_key(nonce).unwrap();
match data.crypto_key.open_in_place(nonce, Aad::from(header), buf) {
Ok(plaintext) => Ok(plaintext.len()),
Err(_) => Err(Error::Crypto("Failed to decrypt"))
}
}
}
}
pub fn encrypt(&mut self, buf: &mut [u8], mlen: usize, nonce_bytes: &mut [u8], header: &[u8]) -> usize {
let tag_len = self.additional_bytes();
match *self {
Crypto::None => mlen,
Crypto::ChaCha20Poly1305(ref mut data)
| Crypto::AES256GCM(ref mut data)
| Crypto::AES128GCM(ref mut data) => {
inc_nonce(&mut data.nonce);
assert!(buf.len() - mlen >= tag_len);
let nonce = Nonce::try_assume_unique_for_key(&data.nonce).unwrap();
let tag = data
.crypto_key
.seal_in_place_separate_tag(nonce, Aad::from(header), &mut buf[..mlen])
.expect("Failed to encrypt");
buf[mlen..mlen + tag_len].copy_from_slice(tag.as_ref());
nonce_bytes.clone_from_slice(&data.nonce);
mlen + tag_len
}
}
}
}
#[test]
fn encrypt_decrypt_chacha20poly1305() {
let mut sender = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test");
let receiver = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test");
let msg = "HelloWorld0123456789";
let msg_bytes = msg.as_bytes();
let mut buffer = [0u8; 1024];
let header = [0u8; 8];
buffer[..msg_bytes.len()].clone_from_slice(&msg_bytes);
let mut nonce1 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce1, &header);
assert_eq!(size, msg_bytes.len() + sender.additional_bytes());
assert!(msg_bytes != &buffer[..msg_bytes.len()] as &[u8]);
receiver.decrypt(&mut buffer[..size], &nonce1, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
let mut nonce2 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce2, &header);
assert!(nonce1 != nonce2);
receiver.decrypt(&mut buffer[..size], &nonce2, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
}
#[test]
fn encrypt_decrypt_aes256() {
let mut sender = Crypto::from_shared_key(CryptoMethod::AES256, "test");
let receiver = Crypto::from_shared_key(CryptoMethod::AES256, "test");
let msg = "HelloWorld0123456789";
let msg_bytes = msg.as_bytes();
let mut buffer = [0u8; 1024];
let header = [0u8; 8];
buffer[..msg_bytes.len()].clone_from_slice(&msg_bytes);
let mut nonce1 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce1, &header);
assert_eq!(size, msg_bytes.len() + sender.additional_bytes());
assert!(msg_bytes != &buffer[..msg_bytes.len()] as &[u8]);
receiver.decrypt(&mut buffer[..size], &nonce1, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
let mut nonce2 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce2, &header);
assert!(nonce1 != nonce2);
receiver.decrypt(&mut buffer[..size], &nonce2, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
}
#[test]
fn encrypt_decrypt_aes128() {
let mut sender = Crypto::from_shared_key(CryptoMethod::AES128, "test");
let receiver = Crypto::from_shared_key(CryptoMethod::AES128, "test");
let msg = "HelloWorld0123456789";
let msg_bytes = msg.as_bytes();
let mut buffer = [0u8; 1024];
let header = [0u8; 8];
buffer[..msg_bytes.len()].clone_from_slice(&msg_bytes);
let mut nonce1 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce1, &header);
assert_eq!(size, msg_bytes.len() + sender.additional_bytes());
assert!(msg_bytes != &buffer[..msg_bytes.len()] as &[u8]);
receiver.decrypt(&mut buffer[..size], &nonce1, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
let mut nonce2 = [0u8; 12];
let size = sender.encrypt(&mut buffer, msg_bytes.len(), &mut nonce2, &header);
assert!(nonce1 != nonce2);
receiver.decrypt(&mut buffer[..size], &nonce2, &header).unwrap();
assert_eq!(msg_bytes, &buffer[..msg_bytes.len()] as &[u8]);
}

494
src/crypto/core.rs Normal file
View File

@ -0,0 +1,494 @@
//! This module implements a crypto core for encrypting and decrypting message streams
//!
//! The crypto core only encrypts and decrypts messages, using given keys. Negotiating and rotating the keys is out of
//! scope of the crypto core. The crypto core assumes that the remote node will always have the necessary key to decrypt
//! the message.
//!
//! The crypto core encrypts messages in place, writes some extra data (key id and nonce) into a given space and
//! includes the given header data in the authentication tag. When decrypting messages, the crypto core reads the extra
//! data, uses the key id to find the right key to decrypting the message and then decrypts the message, using the given
//! nonce and including the given header data in the verification of the authentication tag.
//!
//! While the core only uses a single key at a time for encrypting messages, it is ready to decrypt messages based on
//! one of 4 stored keys (the encryption key being one of them). An external key rotation is responsible for adding the
//! key to the remote peer before switching to the key on the local peer for encryption.
//!
//! As mentioned, the encryption and decryption works in place. Therefore the parameter payload_and_tag contains (when
//! decrypting) or provides space for (when encrypting) the payload and the authentication tag. When encrypting, that
//! means, that the last TAG_LEN bytes of payload_and_tag must be reserved for the tag and must not contain payload
//! bytes.
//!
//! The nonce is a value of 12 bytes (192 bits). Since both nodes can use the same key for encryption, the most
//! significant byte (msb) of the nonce is initialized differently on both peers: one peer uses the value 0x00 and the
//! other one 0x80. That means that the nonce space is essentially divided in two halves, one for each node.
//!
//! To save space and keep the encrypted data aligned to 64 bits, not all bytes of the nonce are transferred. Instead,
//! only 7 bytes are included in messages (another byte is used for the key id, hence 64 bit alignment). The rest of the
//! nonce is deduced by the nodes: All other bytes are assumed to be 0x00, except for the most significant byte, which
//! is assumed to be the opposite ones own msb. This has two nice effects:
//! 1) Long before the nonce could theoretically repeat, the messages can no longer be decrypted by the peer as the
//! higher bytes are no longer zero as assumed.
//! 2) By deducing the msb to be the opposite of ones own msb, it is no longer possible for an attacker to redirect a
//! message back to the sender because then the assumed nonce will be wrong and the message fails to decrypt. Otherwise,
//! this could lead to problems as nodes would be able to accidentally decrypt their own messages.
//!
//! In order to be resistent against replay attacks but allow for reordering of messages, the crypto core uses nonce
//! pinning. For every active key, the biggest nonce seen so far is being tracked. Every second, the biggest nonce seen
//! one second ago plus 1 becomes the minimum nonce that is accepted for that key. That means, that reordering can
//! happen within one second but after a second, old messages will not be accepted anymore.
use byteorder::{ReadBytesExt, WriteBytesExt};
use ring::{
aead::{self, LessSafeKey, UnboundKey},
rand::{SecureRandom, SystemRandom}
};
use std::{
io::{Cursor, Read, Write},
mem,
time::{Duration, Instant}
};
use super::{Error, MsgBuffer};
const NONCE_LEN: usize = 12;
pub const TAG_LEN: usize = 16;
pub const EXTRA_LEN: usize = 8;
fn random_data(size: usize) -> Vec<u8> {
let rand = SystemRandom::new();
let mut data = vec![0; size];
rand.fill(&mut data).expect("Failed to obtain random bytes");
data
}
#[derive(PartialOrd, Ord, PartialEq, Debug, Eq, Clone)]
struct Nonce([u8; NONCE_LEN]);
impl Nonce {
fn zero() -> Self {
Nonce([0; NONCE_LEN])
}
fn random(rand: &SystemRandom) -> Self {
let mut nonce = Nonce::zero();
rand.fill(&mut nonce.0[6..]).expect("Failed to obtain random bytes");
nonce
}
fn set_msb(&mut self, val: u8) {
self.0[0] = val
}
fn as_bytes(&self) -> &[u8; NONCE_LEN] {
&self.0
}
fn increment(&mut self) {
for i in (0..NONCE_LEN).rev() {
let mut num = self.0[i];
num = num.wrapping_add(1);
self.0[i] = num;
if num > 0 {
return
}
}
}
}
struct CryptoKey {
key: LessSafeKey,
send_nonce: Nonce,
min_nonce: Nonce,
next_min_nonce: Nonce,
seen_nonce: Nonce
}
impl CryptoKey {
fn new(rand: &SystemRandom, key: LessSafeKey, nonce_half: bool) -> Self {
let mut send_nonce = Nonce::random(&rand);
send_nonce.set_msb(if nonce_half { 0x80 } else { 0x00 });
CryptoKey {
key,
send_nonce,
min_nonce: Nonce::zero(),
next_min_nonce: Nonce::zero(),
seen_nonce: Nonce::zero()
}
}
fn update_min_nonce(&mut self) {
mem::swap(&mut self.min_nonce, &mut self.next_min_nonce);
self.next_min_nonce = self.seen_nonce.clone();
self.next_min_nonce.increment();
}
}
pub struct CryptoCore {
rand: SystemRandom,
keys: [CryptoKey; 4],
current_key: usize,
nonce_half: bool
}
impl CryptoCore {
pub fn new(key: LessSafeKey, nonce_half: bool) -> Self {
let rand = SystemRandom::new();
let dummy_key_data = random_data(key.algorithm().key_len());
let dummy_key1 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap());
let dummy_key2 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap());
let dummy_key3 = LessSafeKey::new(UnboundKey::new(key.algorithm(), &dummy_key_data).unwrap());
Self {
keys: [
CryptoKey::new(&rand, key, nonce_half),
CryptoKey::new(&rand, dummy_key1, nonce_half),
CryptoKey::new(&rand, dummy_key2, nonce_half),
CryptoKey::new(&rand, dummy_key3, nonce_half)
],
current_key: 0,
nonce_half,
rand
}
}
pub fn encrypt(&mut self, buffer: &mut MsgBuffer) {
let data_start = buffer.get_start();
let data_length = buffer.len();
assert!(buffer.get_start() >= EXTRA_LEN);
buffer.set_start(data_start - EXTRA_LEN);
buffer.set_length(data_length + EXTRA_LEN + TAG_LEN);
let (extra, data_and_tag) = buffer.message_mut().split_at_mut(EXTRA_LEN);
let (data, tag_space) = data_and_tag.split_at_mut(data_length);
let key = &mut self.keys[self.current_key];
key.send_nonce.increment();
{
let mut extra = Cursor::new(extra);
extra.write_u8(self.current_key as u8).unwrap();
extra.write_all(&key.send_nonce.as_bytes()[5..]).unwrap();
}
let nonce = aead::Nonce::assume_unique_for_key(*key.send_nonce.as_bytes());
let tag = key.key.seal_in_place_separate_tag(nonce, aead::Aad::empty(), data).expect("Failed to encrypt");
tag_space.clone_from_slice(tag.as_ref());
}
fn decrypt_with_key<'a>(key: &mut CryptoKey, nonce: Nonce, data_and_tag: &'a mut [u8]) -> Result<(), Error> {
if nonce < key.min_nonce {
return Err(Error::Unauthorized("Old nonce rejected"))
}
// decrypt
let crypto_nonce = aead::Nonce::assume_unique_for_key(*nonce.as_bytes());
key.key
.open_in_place(crypto_nonce, aead::Aad::empty(), data_and_tag)
.map_err(|_| Error::Unauthorized("Failed to decrypt data"))?;
// last seen nonce
if key.seen_nonce < nonce {
key.seen_nonce = nonce;
}
Ok(())
}
pub fn decrypt(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
assert!(buffer.len() >= EXTRA_LEN + TAG_LEN);
let (extra, data_and_tag) = buffer.message_mut().split_at_mut(EXTRA_LEN);
let key_id;
let mut nonce;
{
let mut extra = Cursor::new(extra);
key_id = extra.read_u8().map_err(|_| Error::Crypto("Input data too short"))? % 4;
nonce = Nonce::zero();
extra.read_exact(&mut nonce.0[5..]).map_err(|_| Error::Crypto("Input data too short"))?;
nonce.set_msb(if self.nonce_half { 0x00 } else { 0x80 });
}
let key = &mut self.keys[key_id as usize];
let result = Self::decrypt_with_key(key, nonce, data_and_tag);
buffer.set_start(buffer.get_start() + EXTRA_LEN);
buffer.set_length(buffer.len() - TAG_LEN);
result
}
pub fn rotate_key(&mut self, key: LessSafeKey, id: u64, use_for_sending: bool) {
let id = (id % 4) as usize;
self.keys[id] = CryptoKey::new(&self.rand, key, self.nonce_half);
if use_for_sending {
self.current_key = id
}
}
pub fn algorithm(&self) -> &'static aead::Algorithm {
self.keys[self.current_key].key.algorithm()
}
pub fn every_second(&mut self) {
// Set min nonce on all keys
for k in &mut self.keys {
k.update_min_nonce();
}
}
}
pub fn test_speed(algo: &'static aead::Algorithm, max_time: &Duration) -> f64 {
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.set_length(1000);
let key_data = random_data(algo.key_len());
let mut sender = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key_data).unwrap()), true);
let mut receiver = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key_data).unwrap()), false);
let mut iterations = 0;
let start = Instant::now();
while (Instant::now() - start).as_nanos() < max_time.as_nanos() {
for _ in 0..1000 {
sender.encrypt(&mut buffer);
receiver.decrypt(&mut buffer).unwrap();
}
iterations += 1000;
}
let duration = (Instant::now() - start).as_secs_f64();
let data = iterations * 1000 * 2;
data as f64 / duration / 1_000_000.0
}
#[cfg(test)]
mod tests {
use super::*;
use ring::aead::{self, LessSafeKey, UnboundKey};
fn setup_pair(algo: &'static aead::Algorithm) -> (CryptoCore, CryptoCore) {
let key = random_data(algo.key_len());
let crypto1 = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key).unwrap()), false);
let crypto2 = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key).unwrap()), true);
(crypto1, crypto2)
}
#[test]
fn test_nonce() {
let mut nonce = Nonce::zero();
assert_eq!(nonce.as_bytes(), &[0; 12]);
nonce.increment();
assert_eq!(nonce.as_bytes(), &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
nonce.increment();
assert_eq!(nonce.as_bytes(), &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]);
}
fn test_encrypt_decrypt(algo: &'static aead::Algorithm) {
let (mut sender, mut receiver) = setup_pair(algo);
let plain = random_data(1000);
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.clone_from(&plain);
assert_eq!(&plain[..], buffer.message());
sender.encrypt(&mut buffer);
assert_ne!(&plain[..], buffer.message());
receiver.decrypt(&mut buffer).unwrap();
assert_eq!(&plain[..], buffer.message());
}
#[test]
fn test_encrypt_decrypt_aes128() {
test_encrypt_decrypt(&aead::AES_128_GCM)
}
#[test]
fn test_encrypt_decrypt_aes256() {
test_encrypt_decrypt(&aead::AES_256_GCM)
}
#[test]
fn test_encrypt_decrypt_chacha() {
test_encrypt_decrypt(&aead::CHACHA20_POLY1305)
}
fn test_tampering(algo: &'static aead::Algorithm) {
let (mut sender, mut receiver) = setup_pair(algo);
let plain = random_data(1000);
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.clone_from(&plain);
sender.encrypt(&mut buffer);
let mut d = buffer.clone();
assert!(receiver.decrypt(&mut d,).is_ok());
// Tamper with extra data byte 1 (subkey id)
d = buffer.clone();
d.message_mut()[0] ^= 1;
assert!(receiver.decrypt(&mut d).is_err());
// Tamper with extra data byte 2 (nonce)
d = buffer.clone();
d.message_mut()[1] ^= 1;
assert!(receiver.decrypt(&mut d).is_err());
// Tamper with data itself
d = buffer.clone();
d.message_mut()[EXTRA_LEN] ^= 1;
assert!(receiver.decrypt(&mut d).is_err());
// Check everything still works
d = buffer;
assert!(receiver.decrypt(&mut d).is_ok());
}
#[test]
fn test_tampering_aes128() {
test_tampering(&aead::AES_128_GCM)
}
#[test]
fn test_tampering_aes256() {
test_tampering(&aead::AES_256_GCM)
}
#[test]
fn test_tampering_chacha() {
test_tampering(&aead::CHACHA20_POLY1305)
}
fn test_nonce_pinning(algo: &'static aead::Algorithm) {
let (mut sender, mut receiver) = setup_pair(algo);
let plain = random_data(1000);
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.clone_from(&plain);
sender.encrypt(&mut buffer);
{
let mut d = buffer.clone();
assert!(receiver.decrypt(&mut d).is_ok());
}
receiver.every_second();
{
let mut d = buffer.clone();
assert!(receiver.decrypt(&mut d).is_ok());
}
receiver.every_second();
{
let mut d = buffer;
assert!(receiver.decrypt(&mut d).is_err());
}
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.clone_from(&plain);
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_ok());
}
#[test]
fn test_nonce_pinning_aes128() {
test_nonce_pinning(&aead::AES_128_GCM)
}
#[test]
fn test_nonce_pinning_aes256() {
test_nonce_pinning(&aead::AES_256_GCM)
}
#[test]
fn test_nonce_pinning_chacha() {
test_nonce_pinning(&aead::CHACHA20_POLY1305)
}
fn test_key_rotation(algo: &'static aead::Algorithm) {
let (mut sender, mut receiver) = setup_pair(algo);
let plain = random_data(1000);
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.clone_from(&plain);
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_ok());
let new_key = random_data(algo.key_len());
receiver.rotate_key(LessSafeKey::new(UnboundKey::new(algo, &new_key).unwrap()), 1, false);
receiver.encrypt(&mut buffer);
assert!(sender.decrypt(&mut buffer).is_ok());
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_ok());
sender.rotate_key(LessSafeKey::new(UnboundKey::new(algo, &new_key).unwrap()), 1, true);
receiver.encrypt(&mut buffer);
assert!(sender.decrypt(&mut buffer).is_ok());
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_ok());
let new_key = random_data(algo.key_len());
sender.rotate_key(LessSafeKey::new(UnboundKey::new(algo, &new_key).unwrap()), 2, true);
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_err());
receiver.encrypt(&mut buffer);
assert!(sender.decrypt(&mut buffer).is_ok());
receiver.rotate_key(LessSafeKey::new(UnboundKey::new(algo, &new_key).unwrap()), 2, false);
receiver.encrypt(&mut buffer);
assert!(sender.decrypt(&mut buffer).is_ok());
sender.encrypt(&mut buffer);
assert!(receiver.decrypt(&mut buffer).is_ok());
}
#[test]
fn test_key_rotation_aes128() {
test_key_rotation(&aead::AES_128_GCM);
}
#[test]
fn test_key_rotation_aes256() {
test_key_rotation(&aead::AES_256_GCM);
}
#[test]
fn test_key_rotation_chacha() {
test_key_rotation(&aead::CHACHA20_POLY1305);
}
#[test]
fn test_core_size() {
assert_eq!(2384, mem::size_of::<CryptoCore>());
}
#[test]
fn test_speed_aes128() {
let speed = test_speed(&aead::AES_128_GCM, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0);
}
#[test]
fn test_speed_aes256() {
let speed = test_speed(&aead::AES_256_GCM, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0);
}
#[test]
fn test_speed_chacha() {
let speed = test_speed(&aead::CHACHA20_POLY1305, &Duration::from_secs_f32(0.2));
assert!(speed > 10.0);
}
}
#[cfg(feature = "bench")]
mod benches {
use super::*;
use test::Bencher;
fn crypto_bench(b: &mut Bencher, algo: &'static aead::Algorithm) {
let mut buffer = MsgBuffer::new(EXTRA_LEN);
buffer.set_length(1400);
let key_data = random_data(algo.key_len());
let mut sender = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key_data).unwrap()), true);
let mut receiver = CryptoCore::new(LessSafeKey::new(UnboundKey::new(algo, &key_data).unwrap()), false);
b.iter(|| {
sender.encrypt(&mut buffer);
receiver.decrypt(&mut buffer).unwrap();
});
b.bytes = 1400;
}
#[bench]
fn crypto_chacha20(b: &mut Bencher) {
crypto_bench(b, &aead::CHACHA20_POLY1305)
}
#[bench]
fn crypto_aes128(b: &mut Bencher) {
crypto_bench(b, &aead::AES_128_GCM)
}
#[bench]
fn crypto_aes256(b: &mut Bencher) {
crypto_bench(b, &aead::AES_256_GCM)
}
}

690
src/crypto/init.rs Normal file
View File

@ -0,0 +1,690 @@
//! This module implements a 3-way handshake to initialize an authenticated and encrypted connection.
//!
//! The handshake assumes that each node has a asymmetric Curve 25519 key pair as well as a list of trusted public keys
//! and a set of supported crypto algorithms as well as the expected speed when using them. If successful, the handshake
//! will negotiate a crypto algorithm to use and a common ephemeral symmetric key and exchange a given payload between
//! the nodes.
//!
//! The handshake consists of 3 stages, "ping", "pong" and "peng". In the following description, the node that initiates
//! the connection is named "A" and the other node is named "B". Since a lot of things are going on in parallel in the
//! handshake, those aspects are described separately in the following paragraphs.
//!
//! Every message contains the node id of the sender. If a node receives a message with its own node id, it just ignores
//! it and closes the connection. This is the way nodes avoid to connect to themselves as it is not trivial for a node
//! to know its own addresses (especially in the case of NAT).
//!
//! All initialization messages are signed by the asymmetric key of the sender. Also the messages indicate the public
//! key being used, so the receiver can use the correct public key to verify the signature. The public key itself is not
//! attached to the message for privacy reasons (the public key is stable over multiple restarts while the node id is
//! only valid for a single run). Instead, a 2 byte salt value as well as the last 2 bytes of the salted sha 2 hash of
//! the public key are used to identify the public key. This way, a receiver that trusts this public key can identify
//! it but a random observer can't. If the public key is unknown or the signature can't be verified, the message is
//! ignored.
//!
//! Every message contains a byte that specifies the stage (ping = 1, pong = 2, peng = 3). If a message with an
//! unexpected stage is received, it is ignored and the last message that has been sent is repeated. There is only one
//! exception to this rule: if a "pong" message is expected, but a "ping" message is received instead AND the node id of
//! the sender is greater than the node id of the receiver, the receiving node will reset its state and assume the role
//! of a receiver of the initialization (i.e. "B"). This is used to "negotiate" the roles A and B when both nodes
//! initiate the connection in parallel and think they are A.
//!
//! Upon connection creation, both nodes create a random ephemeral ECDH key pair and exchange the public keys in the
//! ping and pong messages. A sends the ping message to B containing A's public key and B replies with a pong message
//! containing B's public key. That means, that after receiving the ping message B can calculate the shared key material
//! and after receiving the pong message A can calculate the shared key material.
//!
//! The ping message and the pong message contain a set of supported crypto algorithms together with the estimated
//! speeds of the algorithms. When B receives a ping message, or A receives a pong message, it can combine this
//! information with its own algorithm list and select the algorithm with the best expected speed for the crypto core.
//!
//! The pong and peng message contain the payload that the nodes want to exchange in the initialization phase apart from
//! the cryptographic initialization. This payload is encoded according to the application and encrypted using the key
//! material and the crypto algorithm that have been negotiated via the ping and pong messages. The pong message,
//! therefore contains information to set up symmetric encryption as well as a part that is already encrypted.
//!
//! The handshake ends for A after sending the peng message and for B after receiving this message. At this time both
//! nodes initialize the connection using the payload and enter normal operation. The negotiated crypto core is used for
//! future communication and the key rotation is started. Since the peng message can be lost, A needs to keep the
//! initialization state in order to repeat a lost peng message. After one second, A removes that state.
//!
//! Once every second, both nodes check whether they have already finished the initialization. If not, they repeat their
//! last message. After 5 seconds, the initialization is aborted as failed.
use super::{
core::{CryptoCore, EXTRA_LEN},
Algorithms, EcdhPrivateKey, EcdhPublicKey, Ed25519PublicKey, Error, MsgBuffer, Payload
};
use crate::types::{NodeId, NODE_ID_BYTES};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use ring::{
aead::{Algorithm, LessSafeKey, UnboundKey, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305},
agreement::{agree_ephemeral, X25519},
digest,
rand::{SecureRandom, SystemRandom},
signature::{self, Ed25519KeyPair, KeyPair, ED25519, ED25519_PUBLIC_KEY_LEN}
};
use smallvec::{smallvec, SmallVec};
use std::{
cmp, f32,
fmt::Debug,
io::{self, Cursor, Read, Write},
sync::Arc
};
pub const STAGE_PING: u8 = 1;
pub const STAGE_PONG: u8 = 2;
pub const STAGE_PENG: u8 = 3;
pub const WAITING_TO_CLOSE: u8 = 4;
pub const CLOSING: u8 = 5;
#[allow(clippy::large_enum_variant)]
pub enum InitMsg {
Ping { node_id: NodeId, ecdh_public_key: EcdhPublicKey, algorithms: Algorithms },
Pong { node_id: NodeId, ecdh_public_key: EcdhPublicKey, algorithms: Algorithms, encrypted_payload: MsgBuffer },
Peng { node_id: NodeId, encrypted_payload: MsgBuffer }
}
impl InitMsg {
const PART_ALGORITHMS: u8 = 4;
const PART_ECDH_PUBLIC_KEY: u8 = 3;
const PART_END: u8 = 0;
const PART_NODE_ID: u8 = 2;
const PART_PAYLOAD: u8 = 5;
const PART_STAGE: u8 = 1;
fn stage(&self) -> u8 {
match self {
InitMsg::Ping { .. } => STAGE_PING,
InitMsg::Pong { .. } => STAGE_PONG,
InitMsg::Peng { .. } => STAGE_PENG
}
}
fn node_id(&self) -> NodeId {
match self {
InitMsg::Ping { node_id, .. } | InitMsg::Pong { node_id, .. } | InitMsg::Peng { node_id, .. } => *node_id
}
}
fn calculate_hash(key: &Ed25519PublicKey, salt: &[u8; 4]) -> [u8; 4] {
let mut data = [0; ED25519_PUBLIC_KEY_LEN + 4];
data[..ED25519_PUBLIC_KEY_LEN].clone_from_slice(key);
data[ED25519_PUBLIC_KEY_LEN..].clone_from_slice(salt);
let hash = digest::digest(&digest::SHA256, &data);
let mut short_hash = [0; 4];
short_hash.clone_from_slice(&hash.as_ref()[..4]);
short_hash
}
fn read_from(buffer: &[u8], trusted_keys: &[Ed25519PublicKey]) -> Result<(Self, Ed25519PublicKey), Error> {
let mut r = Cursor::new(buffer);
let mut public_key_salt = [0; 4];
r.read_exact(&mut public_key_salt).map_err(|_| Error::Parse("Init message too short"))?;
let mut public_key_hash = [0; 4];
r.read_exact(&mut public_key_hash).map_err(|_| Error::Parse("Init message too short"))?;
let mut public_key_data = [0; ED25519_PUBLIC_KEY_LEN];
let mut found_key = false;
for tk in trusted_keys {
if Self::calculate_hash(tk, &public_key_salt) == public_key_hash {
public_key_data.clone_from_slice(tk);
found_key = true;
break
}
}
if !found_key {
return Err(Error::Unauthorized("untrusted peer"))
}
let mut stage = None;
let mut node_id = None;
let mut ecdh_public_key = None;
let mut encrypted_payload = None;
let mut algorithms = None;
loop {
let field = r.read_u8().map_err(|_| Error::Parse("Init message too short"))?;
if field == Self::PART_END {
break
}
let field_len = r.read_u16::<NetworkEndian>().map_err(|_| Error::Parse("Init message too short"))? as usize;
match field {
Self::PART_STAGE => {
if field_len != 1 {
return Err(Error::CryptoInit("Invalid size for stage field"))
}
stage = Some(r.read_u8().map_err(|_| Error::Parse("Init message too short"))?)
}
Self::PART_NODE_ID => {
if field_len != NODE_ID_BYTES {
return Err(Error::CryptoInit("Invalid size for node id field"))
}
let mut id = [0; NODE_ID_BYTES];
r.read_exact(&mut id).map_err(|_| Error::Parse("Init message too short"))?;
node_id = Some(id)
}
Self::PART_ECDH_PUBLIC_KEY => {
let mut pub_key_data = smallvec![0; field_len];
r.read_exact(&mut pub_key_data).map_err(|_| Error::Parse("Init message too short"))?;
ecdh_public_key = Some(EcdhPublicKey::new(&X25519, pub_key_data));
}
Self::PART_PAYLOAD => {
let mut payload = MsgBuffer::new(0);
payload.set_length(field_len);
r.read_exact(payload.message_mut()).map_err(|_| Error::Parse("Init message too short"))?;
encrypted_payload = Some(payload);
}
Self::PART_ALGORITHMS => {
let count = field_len / 5;
let mut algos = SmallVec::with_capacity(count);
let mut allow_unencrypted = false;
for _ in 0..count {
let algo = match r.read_u8().map_err(|_| Error::Parse("Init message too short"))? {
0 => {
allow_unencrypted = true;
None
}
1 => Some(&AES_128_GCM),
2 => Some(&AES_256_GCM),
3 => Some(&CHACHA20_POLY1305),
_ => None
};
let speed =
r.read_f32::<NetworkEndian>().map_err(|_| Error::Parse("Init message too short"))?;
if let Some(algo) = algo {
algos.push((algo, speed));
}
}
algorithms = Some(Algorithms { algorithm_speeds: algos, allow_unencrypted });
}
_ => {
let mut data = vec![0; field_len];
r.read_exact(&mut data).map_err(|_| Error::Parse("Init message too short"))?;
}
}
}
let pos = r.position() as usize;
let signature_len = r.read_u8().map_err(|_| Error::Parse("Init message too short"))? as usize;
let mut signature = vec![0; signature_len];
r.read_exact(&mut signature).map_err(|_| Error::Parse("Init message too short"))?;
let signed_data = &r.into_inner()[0..pos];
let public_key = signature::UnparsedPublicKey::new(&ED25519, &public_key_data);
if public_key.verify(&signed_data, &signature).is_err() {
return Err(Error::Unauthorized("invalid signature"))
}
let stage = match stage {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without stage"))
};
let node_id = match node_id {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without node id"))
};
let msg = match stage {
STAGE_PING => {
let ecdh_public_key = match ecdh_public_key {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without ecdh public key"))
};
let algorithms = match algorithms {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without algorithms"))
};
Self::Ping { node_id, ecdh_public_key, algorithms }
}
STAGE_PONG => {
let ecdh_public_key = match ecdh_public_key {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without ecdh public key"))
};
let algorithms = match algorithms {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without algorithms"))
};
let encrypted_payload = match encrypted_payload {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without payload"))
};
Self::Pong { node_id, ecdh_public_key, algorithms, encrypted_payload }
}
STAGE_PENG => {
let encrypted_payload = match encrypted_payload {
Some(val) => val,
None => return Err(Error::CryptoInit("Init message without payload"))
};
Self::Peng { node_id, encrypted_payload }
}
_ => return Err(Error::CryptoInit("Invalid stage"))
};
Ok((msg, public_key_data))
}
fn write_to(&self, buffer: &mut [u8], key: &Ed25519KeyPair) -> Result<usize, io::Error> {
let mut w = Cursor::new(buffer);
let rand = SystemRandom::new();
let mut salt = [0; 4];
rand.fill(&mut salt).unwrap();
let mut public_key = [0; ED25519_PUBLIC_KEY_LEN];
public_key.clone_from_slice(key.public_key().as_ref());
let hash = Self::calculate_hash(&public_key, &salt);
w.write_all(&salt)?;
w.write_all(&hash)?;
w.write_u8(Self::PART_STAGE)?;
w.write_u16::<NetworkEndian>(1)?;
w.write_u8(self.stage())?;
match &self {
Self::Ping { node_id, .. } | Self::Pong { node_id, .. } | Self::Peng { node_id, .. } => {
w.write_u8(Self::PART_NODE_ID)?;
w.write_u16::<NetworkEndian>(NODE_ID_BYTES as u16)?;
w.write_all(node_id)?;
}
}
match &self {
Self::Ping { ecdh_public_key, .. } | Self::Pong { ecdh_public_key, .. } => {
w.write_u8(Self::PART_ECDH_PUBLIC_KEY)?;
let key_bytes = ecdh_public_key.bytes();
w.write_u16::<NetworkEndian>(key_bytes.len() as u16)?;
w.write_all(&key_bytes)?;
}
_ => ()
}
match &self {
Self::Ping { algorithms, .. } | Self::Pong { algorithms, .. } => {
w.write_u8(Self::PART_ALGORITHMS)?;
let mut len = algorithms.algorithm_speeds.len() * 5;
if algorithms.allow_unencrypted {
len += 5;
}
w.write_u16::<NetworkEndian>(len as u16)?;
if algorithms.allow_unencrypted {
w.write_u8(0)?;
w.write_f32::<NetworkEndian>(f32::INFINITY)?;
}
for (algo, speed) in &algorithms.algorithm_speeds {
if *algo == &AES_128_GCM {
w.write_u8(1)?;
} else if *algo == &AES_256_GCM {
w.write_u8(2)?;
} else if *algo == &CHACHA20_POLY1305 {
w.write_u8(3)?;
} else {
unreachable!();
}
w.write_f32::<NetworkEndian>(*speed)?;
}
}
_ => ()
}
match &self {
Self::Pong { encrypted_payload, .. } | Self::Peng { encrypted_payload, .. } => {
w.write_u8(Self::PART_PAYLOAD)?;
w.write_u16::<NetworkEndian>(encrypted_payload.len() as u16)?;
w.write_all(encrypted_payload.message())?;
}
_ => ()
}
w.write_u8(Self::PART_END)?;
let pos = w.position() as usize;
let signature = key.sign(&w.get_ref()[0..pos]);
w.write_u8(signature.as_ref().len() as u8)?;
w.write_all(signature.as_ref())?;
Ok(w.position() as usize)
}
}
#[derive(PartialEq, Debug)]
pub enum InitResult<P: Payload> {
Continue,
Success { peer_payload: P, node_id: NodeId, is_initiator: bool }
}
pub struct InitState<P: Payload> {
node_id: NodeId,
payload: P,
key_pair: Arc<Ed25519KeyPair>,
trusted_keys: Arc<Vec<Ed25519PublicKey>>,
ecdh_private_key: Option<EcdhPrivateKey>,
next_stage: u8,
close_time: usize,
last_message: Option<Vec<u8>>,
crypto: Option<CryptoCore>,
algorithms: Algorithms,
failed_retries: usize
}
impl<P: Payload> InitState<P> {
pub fn new(
node_id: NodeId, payload: P, key_pair: Arc<Ed25519KeyPair>, trusted_keys: Arc<Vec<Ed25519PublicKey>>,
algorithms: Algorithms
) -> Self
{
Self {
node_id,
payload,
key_pair,
trusted_keys,
next_stage: STAGE_PING,
last_message: None,
crypto: None,
ecdh_private_key: None,
algorithms,
failed_retries: 0,
close_time: 60
}
}
pub fn send_ping(&mut self, out: &mut MsgBuffer) -> Result<(), Error> {
// create ecdh ephemeral key
let (ecdh_private_key, ecdh_public_key) = self.create_ecdh_keypair();
self.ecdh_private_key = Some(ecdh_private_key);
// create stage 1 msg
self.send_message(STAGE_PING, Some(ecdh_public_key), out)?;
self.next_stage = STAGE_PONG;
Ok(())
}
pub fn stage(&self) -> u8 {
self.next_stage
}
pub fn every_second(&mut self, out: &mut MsgBuffer) -> Result<(), Error> {
if self.next_stage == WAITING_TO_CLOSE {
if self.close_time == 0 {
self.next_stage = CLOSING;
} else {
self.close_time -= 1;
}
Ok(())
} else if self.next_stage == CLOSING {
Ok(())
} else if self.failed_retries < 5 {
self.failed_retries += 1;
self.repeat_last_message(out)?;
Ok(())
} else {
Err(Error::CryptoInit("Initialization timeout"))
}
}
fn derive_master_key(&self, algo: &'static Algorithm, privk: EcdhPrivateKey, pubk: &EcdhPublicKey) -> LessSafeKey {
agree_ephemeral(privk, pubk, (), |k| {
UnboundKey::new(algo, &k[..algo.key_len()]).map(LessSafeKey::new).map_err(|_| ())
})
.unwrap()
}
fn create_ecdh_keypair(&self) -> (EcdhPrivateKey, EcdhPublicKey) {
let rand = SystemRandom::new();
let ecdh_private_key = EcdhPrivateKey::generate(&X25519, &rand).unwrap();
let public_key = ecdh_private_key.compute_public_key().unwrap();
let mut vec = SmallVec::<[u8; 96]>::new();
vec.extend_from_slice(public_key.as_ref());
let ecdh_public_key = EcdhPublicKey::new(&X25519, vec);
(ecdh_private_key, ecdh_public_key)
}
fn encrypt_payload(&mut self) -> Result<MsgBuffer, Error> {
let mut buffer = MsgBuffer::new(EXTRA_LEN);
self.payload.write_to(&mut buffer);
if let Some(crypto) = &mut self.crypto {
crypto.encrypt(&mut buffer);
}
Ok(buffer)
}
fn decrypt(&mut self, data: &mut MsgBuffer) -> Result<P, Error> {
if let Some(crypto) = &mut self.crypto {
crypto.decrypt(data)?;
}
Ok(P::read_from(Cursor::new(data.message()))?)
}
fn send_message(
&mut self, stage: u8, ecdh_public_key: Option<EcdhPublicKey>, out: &mut MsgBuffer
) -> Result<(), Error> {
debug!("Sending init with stage={}", stage);
assert!(out.is_empty());
let mut public_key = [0; ED25519_PUBLIC_KEY_LEN];
public_key.clone_from_slice(self.key_pair.as_ref().public_key().as_ref());
let msg = match stage {
STAGE_PING => {
InitMsg::Ping {
node_id: self.node_id,
ecdh_public_key: ecdh_public_key.unwrap(),
algorithms: self.algorithms.clone()
}
}
STAGE_PONG => {
InitMsg::Pong {
node_id: self.node_id,
ecdh_public_key: ecdh_public_key.unwrap(),
algorithms: self.algorithms.clone(),
encrypted_payload: self.encrypt_payload()?
}
}
STAGE_PENG => InitMsg::Peng { node_id: self.node_id, encrypted_payload: self.encrypt_payload()? },
_ => unreachable!()
};
let mut bytes = out.buffer();
let len = msg.write_to(&mut bytes, &self.key_pair).expect("Buffer too small");
self.last_message = Some(bytes[0..len].to_vec());
out.set_length(len);
Ok(())
}
fn repeat_last_message(&self, out: &mut MsgBuffer) -> Result<(), Error> {
if let Some(ref bytes) = self.last_message {
debug!("Repeating last init message");
let buffer = out.buffer();
buffer[0..bytes.len()].copy_from_slice(bytes);
out.set_length(bytes.len());
}
Ok(())
}
fn select_algorithm(&self, peer_algos: &Algorithms) -> Result<Option<(&'static Algorithm, f32)>, Error> {
if self.algorithms.allow_unencrypted && peer_algos.allow_unencrypted {
return Ok(None)
}
// For each supported algorithm, find the algorithm in the list of the peer (ignore algorithm if not found).
// Take the minimal speed reported by either us or the peer.
// Select the algorithm with the greatest minimal speed.
let algo = self
.algorithms
.algorithm_speeds
.iter()
.filter_map(|(a1, s1)| {
peer_algos
.algorithm_speeds
.iter()
.find(|(a2, _)| a1 == a2)
.map(|(_, s2)| (*a1, if s1 < s2 { *s1 } else { *s2 }))
})
.max_by(|(_, s1), (_, s2)| if s1 < s2 { cmp::Ordering::Less } else { cmp::Ordering::Greater });
if let Some(algo) = algo {
Ok(Some(algo))
} else {
Err(Error::CryptoInit("No common algorithms"))
}
}
pub fn handle_init(&mut self, out: &mut MsgBuffer) -> Result<InitResult<P>, Error> {
let (msg, _peer_key) = InitMsg::read_from(out.buffer(), &self.trusted_keys)?;
out.clear();
let stage = msg.stage();
let node_id = msg.node_id();
debug!("Received init with stage={}, expected stage={}", stage, self.next_stage);
if self.node_id == node_id {
return Err(Error::CryptoInit("Connected to self"))
}
if stage != self.next_stage {
if self.next_stage == STAGE_PONG && stage == STAGE_PING {
// special case for concurrent init messages in both directions
// the node with the higher node_id "wins" and gets to initialize the connection
if node_id > self.node_id {
// reset to initial state
self.next_stage = STAGE_PING;
self.last_message = None;
self.ecdh_private_key = None;
} else {
return Ok(InitResult::Continue)
}
} else if self.next_stage == CLOSING {
return Ok(InitResult::Continue)
} else if self.last_message.is_some() {
self.repeat_last_message(out)?;
return Ok(InitResult::Continue)
} else {
return Err(Error::CryptoInit("Received invalid stage as first message"))
}
}
self.failed_retries = 0;
match msg {
InitMsg::Ping { ecdh_public_key, algorithms, .. } => {
// create ecdh ephemeral key
let (my_ecdh_private_key, my_ecdh_public_key) = self.create_ecdh_keypair();
// do ecdh agreement and derive master key
let algorithm = self.select_algorithm(&algorithms)?;
if let Some((algorithm, _speed)) = algorithm {
let master_key = self.derive_master_key(algorithm, my_ecdh_private_key, &ecdh_public_key);
self.crypto = Some(CryptoCore::new(master_key, self.node_id > node_id));
}
// create and send stage 2 reply
self.send_message(STAGE_PONG, Some(my_ecdh_public_key), out)?;
self.next_stage = STAGE_PENG;
Ok(InitResult::Continue)
}
InitMsg::Pong { ecdh_public_key, algorithms, mut encrypted_payload, .. } => {
// do ecdh agreement and derive master key
let ecdh_private_key = self.ecdh_private_key.take().unwrap();
let algorithm = self.select_algorithm(&algorithms)?;
if let Some((algorithm, _speed)) = algorithm {
let master_key = self.derive_master_key(algorithm, ecdh_private_key, &ecdh_public_key);
self.crypto = Some(CryptoCore::new(master_key, self.node_id > node_id));
}
// decrypt the payload
let peer_payload =
self.decrypt(&mut encrypted_payload).map_err(|_| Error::CryptoInit("Failed to decrypt payload"))?;
// create and send stage 3 reply
self.send_message(STAGE_PENG, None, out)?;
self.next_stage = WAITING_TO_CLOSE;
self.close_time = 60;
Ok(InitResult::Success { peer_payload, node_id, is_initiator: true })
}
InitMsg::Peng { mut encrypted_payload, .. } => {
// decrypt the payload
let peer_payload =
self.decrypt(&mut encrypted_payload).map_err(|_| Error::CryptoInit("Failed to decrypt payload"))?;
self.next_stage = CLOSING; // force resend when receiving any message
Ok(InitResult::Success { peer_payload, node_id, is_initiator: false })
}
}
}
pub fn take_core(&mut self) -> Option<CryptoCore> {
self.crypto.take()
}
}
#[cfg(test)]
mod tests {
use super::*;
impl Payload for Vec<u8> {
fn write_to(&self, buffer: &mut MsgBuffer) {
buffer.buffer().write_all(&self).expect("Buffer too small");
buffer.set_length(self.len())
}
fn read_from<R: Read>(mut r: R) -> Result<Self, Error> {
let mut data = Vec::new();
r.read_to_end(&mut data).map_err(|_| Error::Parse("Buffer too small"))?;
Ok(data)
}
}
fn create_pair() -> (InitState<Vec<u8>>, InitState<Vec<u8>>) {
let rng = SystemRandom::new();
let pkcs8_bytes = Ed25519KeyPair::generate_pkcs8(&rng).unwrap();
let key_pair = Arc::new(Ed25519KeyPair::from_pkcs8(pkcs8_bytes.as_ref()).unwrap());
let mut public_key = [0; ED25519_PUBLIC_KEY_LEN];
public_key.clone_from_slice(key_pair.public_key().as_ref());
let trusted_nodes = Arc::new(vec![public_key]);
let mut node1 = [0; NODE_ID_BYTES];
rng.fill(&mut node1).unwrap();
let mut node2 = [0; NODE_ID_BYTES];
rng.fill(&mut node2).unwrap();
let algorithms = Algorithms {
algorithm_speeds: smallvec![(&AES_128_GCM, 600.0), (&AES_256_GCM, 500.0), (&CHACHA20_POLY1305, 400.0)],
allow_unencrypted: false
};
let sender = InitState::new(node1, vec![1], key_pair.clone(), trusted_nodes.clone(), algorithms.clone());
let receiver = InitState::new(node2, vec![2], key_pair, trusted_nodes, algorithms);
(sender, receiver)
}
#[test]
fn normal_init() {
let (mut sender, mut receiver) = create_pair();
let mut out = MsgBuffer::new(8);
sender.send_ping(&mut out).unwrap();
assert_eq!(sender.stage(), STAGE_PONG);
let result = receiver.handle_init(&mut out).unwrap();
assert_eq!(receiver.stage(), STAGE_PENG);
assert_eq!(result, InitResult::Continue);
let result = sender.handle_init(&mut out).unwrap();
assert_eq!(sender.stage(), WAITING_TO_CLOSE);
let result = match result {
InitResult::Success { .. } => receiver.handle_init(&mut out).unwrap(),
InitResult::Continue => unreachable!()
};
assert_eq!(receiver.stage(), CLOSING);
match result {
InitResult::Success { .. } => assert!(out.is_empty()),
InitResult::Continue => unreachable!()
}
}
// TODO Test: last message repeated when message is lost
// TODO Test: timeout after 5 retries
// TODO Test: duplicated message or replay attacks
// TODO Test: untrusted peers
// TODO Test: manipulated message
// TODO Test: algorithm negotiation
}

458
src/crypto/mod.rs Normal file
View File

@ -0,0 +1,458 @@
mod core;
mod init;
mod rotate;
pub use self::core::{EXTRA_LEN, TAG_LEN};
use self::{
core::{test_speed, CryptoCore},
init::{InitResult, InitState, CLOSING},
rotate::RotationState
};
use crate::{
error::Error,
types::NodeId,
util::{from_base62, to_base62, MsgBuffer}
};
use ring::{
aead::{self, Algorithm, LessSafeKey, UnboundKey},
agreement::{EphemeralPrivateKey, UnparsedPublicKey},
pbkdf2,
rand::{SecureRandom, SystemRandom},
signature::{Ed25519KeyPair, KeyPair, ED25519_PUBLIC_KEY_LEN}
};
use smallvec::{smallvec, SmallVec};
use std::{fmt::Debug, io::Read, num::NonZeroU32, sync::Arc, time::Duration};
use thiserror::Error;
const SALT: &[u8; 32] = b"vpncloudVPNCLOUDvpncl0udVpnCloud";
const INIT_MESSAGE_FIRST_BYTE: u8 = 0xff;
const MESSAGE_TYPE_ROTATION: u8 = 0x10;
pub type Ed25519PublicKey = [u8; ED25519_PUBLIC_KEY_LEN];
pub type EcdhPublicKey = UnparsedPublicKey<SmallVec<[u8; 96]>>;
pub type EcdhPrivateKey = EphemeralPrivateKey;
pub type Key = SmallVec<[u8; 32]>;
const DEFAULT_ALGORITHMS: [&str; 3] = ["AES128", "AES256", "CHACHA20"];
#[cfg(test)]
const SPEED_TEST_TIME: f32 = 0.02;
#[cfg(not(test))]
const SPEED_TEST_TIME: f32 = 0.1;
const ROTATE_INTERVAL: usize = 120;
pub trait Payload: Debug + PartialEq + Sized {
fn write_to(&self, buffer: &mut MsgBuffer);
fn read_from<R: Read>(r: R) -> Result<Self, Error>;
}
#[derive(Clone)]
pub struct Algorithms {
pub algorithm_speeds: SmallVec<[(&'static Algorithm, f32); 3]>,
pub allow_unencrypted: bool
}
#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "kebab-case", deny_unknown_fields, default)]
pub struct Config {
pub password: Option<String>,
pub private_key: Option<String>,
pub public_key: Option<String>,
pub trusted_keys: Vec<String>,
pub algorithms: Vec<String>
}
pub struct Crypto {
node_id: NodeId,
key_pair: Arc<Ed25519KeyPair>,
trusted_keys: Arc<Vec<Ed25519PublicKey>>,
algorithms: Algorithms
}
impl Crypto {
pub fn new(node_id: NodeId, config: &Config) -> Result<Self, Error> {
let key_pair = if let Some(priv_key) = &config.private_key {
if let Some(pub_key) = &config.public_key {
Self::parse_keypair(priv_key, pub_key)?
} else {
Self::parse_private_key(priv_key)?
}
} else if let Some(password) = &config.password {
Self::keypair_from_password(password)
} else {
return Err(Error::InvalidConfig("Either private_key or password must be set"))
};
let mut trusted_keys = vec![];
for tn in &config.trusted_keys {
trusted_keys.push(Self::parse_public_key(tn)?);
}
if trusted_keys.is_empty() {
info!("Trusted keys not set, trusting only own public key");
let mut key = [0; ED25519_PUBLIC_KEY_LEN];
key.clone_from_slice(key_pair.public_key().as_ref());
trusted_keys.push(key);
}
let mut algos = Algorithms { algorithm_speeds: smallvec![], allow_unencrypted: false };
let algorithms = config.algorithms.iter().map(|a| a as &str).collect::<Vec<_>>();
let allowed = if algorithms.is_empty() { &DEFAULT_ALGORITHMS } else { &algorithms as &[&str] };
let duration = Duration::from_secs_f32(SPEED_TEST_TIME);
let mut speeds = Vec::new();
for name in allowed {
let algo = match &name.to_uppercase() as &str {
"UNENCRYPTED" | "NONE" | "PLAIN" => {
algos.allow_unencrypted = true;
warn!("Crypto settings allow unencrypted connections");
continue
}
"AES128" | "AES128_GCM" | "AES_128" | "AES_128_GCM" => &aead::AES_128_GCM,
"AES256" | "AES256_GCM" | "AES_256" | "AES_256_GCM" => &aead::AES_256_GCM,
"CHACHA" | "CHACHA20" | "CHACHA20_POLY1305" => &aead::CHACHA20_POLY1305,
_ => return Err(Error::InvalidConfig("Unknown crypto method"))
};
let speed = test_speed(algo, &duration);
algos.algorithm_speeds.push((algo, speed as f32));
speeds.push((name, speed as f32));
}
if !speeds.is_empty() {
info!(
"Crypto speeds: {}",
speeds.into_iter().map(|(a, s)| format!("{}: {:.1} MiB/s", a, s)).collect::<Vec<_>>().join(", ")
);
}
Ok(Self { node_id, key_pair: Arc::new(key_pair), trusted_keys: Arc::new(trusted_keys), algorithms: algos })
}
pub fn generate_keypair() -> (String, String) {
let rng = SystemRandom::new();
let mut bytes = [0; 32];
rng.fill(&mut bytes).unwrap();
let keypair = Ed25519KeyPair::from_seed_unchecked(&bytes).unwrap();
let privkey = to_base62(&bytes);
let pubkey = to_base62(keypair.public_key().as_ref());
(privkey, pubkey)
}
fn keypair_from_password(password: &str) -> Ed25519KeyPair {
let mut key = [0; 32];
pbkdf2::derive(pbkdf2::PBKDF2_HMAC_SHA256, NonZeroU32::new(4096).unwrap(), SALT, password.as_bytes(), &mut key);
Ed25519KeyPair::from_seed_unchecked(&key).unwrap()
}
fn parse_keypair(privkey: &str, pubkey: &str) -> Result<Ed25519KeyPair, Error> {
let privkey = from_base62(privkey).map_err(|_| Error::InvalidConfig("Failed to parse private key"))?;
let pubkey = from_base62(pubkey).map_err(|_| Error::InvalidConfig("Failed to parse public key"))?;
let keypair = Ed25519KeyPair::from_seed_and_public_key(&privkey, &pubkey)
.map_err(|_| Error::InvalidConfig("Keys rejected by crypto library"))?;
Ok(keypair)
}
fn parse_private_key(privkey: &str) -> Result<Ed25519KeyPair, Error> {
let privkey = from_base62(privkey).map_err(|_| Error::InvalidConfig("Failed to parse private key"))?;
let keypair = Ed25519KeyPair::from_seed_unchecked(&privkey)
.map_err(|_| Error::InvalidConfig("Key rejected by crypto library"))?;
Ok(keypair)
}
fn parse_public_key(pubkey: &str) -> Result<Ed25519PublicKey, Error> {
let pubkey = from_base62(pubkey).map_err(|_| Error::InvalidConfig("Failed to parse public key"))?;
if pubkey.len() != ED25519_PUBLIC_KEY_LEN {
return Err(Error::InvalidConfig("Failed to parse public key"))
}
let mut result = [0; ED25519_PUBLIC_KEY_LEN];
result.clone_from_slice(&pubkey);
Ok(result)
}
pub fn peer_instance<P: Payload>(&self, payload: P) -> PeerCrypto<P> {
PeerCrypto::new(
self.node_id,
payload,
self.key_pair.clone(),
self.trusted_keys.clone(),
self.algorithms.clone()
)
}
}
#[derive(Debug, PartialEq)]
pub enum MessageResult<P: Payload> {
Message(u8),
Initialized(NodeId, P),
InitializedWithReply(NodeId, P),
Reply,
None
}
pub struct PeerCrypto<P: Payload> {
#[allow(dead_code)]
node_id: NodeId,
init: Option<InitState<P>>,
rotation: Option<RotationState>,
unencrypted: bool,
core: Option<CryptoCore>,
rotate_counter: usize
}
impl<P: Payload> PeerCrypto<P> {
pub fn new(
node_id: NodeId, init_payload: P, key_pair: Arc<Ed25519KeyPair>, trusted_keys: Arc<Vec<Ed25519PublicKey>>,
algorithms: Algorithms
) -> Self
{
Self {
node_id,
init: Some(InitState::new(node_id, init_payload, key_pair, trusted_keys, algorithms)),
rotation: None,
unencrypted: false,
core: None,
rotate_counter: 0
}
}
fn get_init(&mut self) -> Result<&mut InitState<P>, Error> {
if let Some(init) = &mut self.init {
Ok(init)
} else {
Err(Error::InvalidCryptoState("Initialization already finished"))
}
}
fn get_core(&mut self) -> Result<&mut CryptoCore, Error> {
if let Some(core) = &mut self.core {
Ok(core)
} else {
Err(Error::InvalidCryptoState("Crypto core not ready yet"))
}
}
fn get_rotation(&mut self) -> Result<&mut RotationState, Error> {
if let Some(rotation) = &mut self.rotation {
Ok(rotation)
} else {
Err(Error::InvalidCryptoState("Key rotation not initialized"))
}
}
pub fn initialize(&mut self, out: &mut MsgBuffer) -> Result<(), Error> {
let init = self.get_init()?;
if init.stage() != init::STAGE_PING {
Err(Error::InvalidCryptoState("Initialization already ongoing"))
} else {
init.send_ping(out)?;
out.prepend_byte(INIT_MESSAGE_FIRST_BYTE);
Ok(())
}
}
pub fn is_ready(&self) -> bool {
self.core.is_some()
}
fn handle_init_message(&mut self, buffer: &mut MsgBuffer) -> Result<MessageResult<P>, Error> {
let result = self.get_init()?.handle_init(buffer)?;
if !buffer.is_empty() {
buffer.prepend_byte(INIT_MESSAGE_FIRST_BYTE);
}
match result {
InitResult::Continue => Ok(MessageResult::Reply),
InitResult::Success { peer_payload, node_id, is_initiator } => {
self.core = self.get_init()?.take_core();
if self.core.is_none() {
self.unencrypted = true;
}
if self.get_init()?.stage() == init::CLOSING {
self.init = None
}
if self.core.is_some() {
self.rotation = Some(RotationState::new(!is_initiator, buffer)?);
}
if !is_initiator {
if self.unencrypted {
return Ok(MessageResult::Initialized(node_id, peer_payload))
}
assert!(!buffer.is_empty());
buffer.prepend_byte(MESSAGE_TYPE_ROTATION);
self.encrypt_message(buffer)?;
}
Ok(MessageResult::InitializedWithReply(node_id, peer_payload))
}
}
}
fn handle_rotate_message(&mut self, data: &[u8]) -> Result<(), Error> {
if self.unencrypted {
return Ok(())
}
if let Some(rot) = self.get_rotation()?.handle_message(data)? {
let core = self.get_core()?;
let algo = core.algorithm();
let key = LessSafeKey::new(UnboundKey::new(algo, &rot.key[..algo.key_len()]).unwrap());
core.rotate_key(key, rot.id, rot.use_for_sending);
}
Ok(())
}
fn encrypt_message(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
if self.unencrypted {
return Ok(())
}
self.get_core()?.encrypt(buffer);
Ok(())
}
fn decrypt_message(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
if self.unencrypted {
return Ok(())
}
self.get_core()?.decrypt(buffer)
}
pub fn handle_message(&mut self, buffer: &mut MsgBuffer) -> Result<MessageResult<P>, Error> {
if buffer.is_empty() {
return Err(Error::InvalidCryptoState("No message in buffer"))
}
if is_init_message(buffer.buffer()) {
debug!("Received init message");
buffer.take_prefix();
self.handle_init_message(buffer)
} else {
debug!("Received encrypted message");
self.decrypt_message(buffer)?;
let msg_type = buffer.take_prefix();
if msg_type == MESSAGE_TYPE_ROTATION {
debug!("Received rotation message");
self.handle_rotate_message(buffer.buffer())?;
buffer.clear();
Ok(MessageResult::None)
} else {
Ok(MessageResult::Message(msg_type))
}
}
}
pub fn send_message(&mut self, type_: u8, buffer: &mut MsgBuffer) -> Result<(), Error> {
assert_ne!(type_, MESSAGE_TYPE_ROTATION);
buffer.prepend_byte(type_);
self.encrypt_message(buffer)
}
pub fn every_second(&mut self, out: &mut MsgBuffer) -> Result<MessageResult<P>, Error> {
out.clear();
if let Some(ref mut core) = self.core {
core.every_second()
}
if let Some(ref mut init) = self.init {
init.every_second(out)?;
}
if self.init.as_ref().map(|i| i.stage()).unwrap_or(CLOSING) == CLOSING {
self.init = None
}
if !out.is_empty() {
out.prepend_byte(INIT_MESSAGE_FIRST_BYTE);
return Ok(MessageResult::Reply)
}
if let Some(ref mut rotate) = self.rotation {
self.rotate_counter += 1;
if self.rotate_counter >= ROTATE_INTERVAL {
self.rotate_counter = 0;
if let Some(rot) = rotate.cycle(out)? {
let core = self.get_core()?;
let algo = core.algorithm();
let key = LessSafeKey::new(UnboundKey::new(algo, &rot.key[..algo.key_len()]).unwrap());
core.rotate_key(key, rot.id, rot.use_for_sending);
}
if !out.is_empty() {
out.prepend_byte(MESSAGE_TYPE_ROTATION);
self.encrypt_message(out)?;
return Ok(MessageResult::Reply)
}
}
}
Ok(MessageResult::None)
}
}
pub fn is_init_message(msg: &[u8]) -> bool {
!msg.is_empty() && msg[0] == INIT_MESSAGE_FIRST_BYTE
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::NODE_ID_BYTES;
fn create_node(config: &Config) -> PeerCrypto<Vec<u8>> {
let rng = SystemRandom::new();
let mut node_id = [0; NODE_ID_BYTES];
rng.fill(&mut node_id).unwrap();
let crypto = Crypto::new(node_id, config).unwrap();
crypto.peer_instance(vec![])
}
#[test]
fn normal() {
let config = Config { password: Some("test".to_string()), ..Default::default() };
let mut node1 = create_node(&config);
let mut node2 = create_node(&config);
let mut msg = MsgBuffer::new(16);
node1.initialize(&mut msg).unwrap();
assert!(!msg.is_empty());
debug!("Node1 -> Node2");
let res = node2.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::Reply);
assert!(!msg.is_empty());
debug!("Node1 <- Node2");
let res = node1.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::InitializedWithReply(node2.node_id, vec![]));
assert!(!msg.is_empty());
debug!("Node1 -> Node2");
let res = node2.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::InitializedWithReply(node1.node_id, vec![]));
assert!(!msg.is_empty());
debug!("Node1 <- Node2");
let res = node1.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::None);
assert!(msg.is_empty());
let mut buffer = MsgBuffer::new(16);
let rng = SystemRandom::new();
buffer.set_length(1000);
rng.fill(buffer.message_mut()).unwrap();
for _ in 0..1000 {
node1.send_message(1, &mut buffer).unwrap();
let res = node2.handle_message(&mut buffer).unwrap();
assert_eq!(res, MessageResult::Message(1));
match node1.every_second(&mut msg).unwrap() {
MessageResult::None => (),
MessageResult::Reply => {
let res = node2.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::None);
}
other => assert_eq!(other, MessageResult::None)
}
match node2.every_second(&mut msg).unwrap() {
MessageResult::None => (),
MessageResult::Reply => {
let res = node1.handle_message(&mut msg).unwrap();
assert_eq!(res, MessageResult::None);
}
other => assert_eq!(other, MessageResult::None)
}
}
}
}

326
src/crypto/rotate.rs Normal file
View File

@ -0,0 +1,326 @@
//! This module implements a turn based key rotation.
//!
//! The main idea is that both peers periodically create ecdh key pairs and exchange their public keys to create
//! common key material. There are always two separate ecdh handshakes going on: one initiated by each peer.
//! However, one handshake is always one step ahead of the other. That means that every message being sent contains a
//! public key from step 1 of the handshake "proposed key" and a public key from step 2 of the handshake "confirmed
//! key" (all messages except first message).
//!
//! When receiving a message from the peer, the node will create a new ecdh key pair and perform the key
//! calculation for the proposed key. The peer will store the public key for the confirmation as pending to be
//! confirmed in the next cycle. Also, if the message contains a confirmation (all but the very first message do),
//! the node will use the stored private key to perform the ecdh key calculation and emit that key to be used in
//! the crypto stream.
//!
//! Upon each cycle, a node first checks if it still has a proposed key that has not been confirmed by the remote
//! peer. If so, a message must have been lost and the whole last message including the proposed key as well as the
//! last confirmed key is being resent. If no proposed key is stored, the node will create a new ecdh key pair, and
//! store the private key as proposed key. It then sends out a message containing the public key as proposal, as
//! well as confirming the pending key. This key is also emitted to be added to the crypto stream but not to be
//! used for encrypting.
//!
//! Monotonically increasing message ids guard the communication from message duplication and also serve as
//! identifiers for the keys to be used in the crypto stream. Since the keys are rotating, the last 2 bits of the
//! id are enough to identify the key.
//!
//! The whole communication is sent via the crypto stream and is therefore encrypted and protected against tampering.
use super::{Error, Key, MsgBuffer};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use ring::{
agreement::{agree_ephemeral, EphemeralPrivateKey, UnparsedPublicKey, X25519},
rand::SystemRandom
};
use smallvec::{smallvec, SmallVec};
use std::io::{self, Cursor, Read, Write};
type EcdhPublicKey = UnparsedPublicKey<SmallVec<[u8; 96]>>;
type EcdhPrivateKey = EphemeralPrivateKey;
pub struct RotationMessage {
message_id: u64,
propose: EcdhPublicKey,
confirm: Option<EcdhPublicKey>
}
impl RotationMessage {
#[allow(dead_code)]
pub fn read_from<R: Read>(mut r: R) -> Result<Self, io::Error> {
let message_id = r.read_u64::<NetworkEndian>()?;
let key_len = r.read_u8()? as usize;
let mut key_data = smallvec![0; key_len];
r.read_exact(&mut key_data)?;
let propose = EcdhPublicKey::new(&X25519, key_data);
let key_len = r.read_u8()? as usize;
let confirm = if key_len > 0 {
let mut key_data = smallvec![0; key_len];
r.read_exact(&mut key_data)?;
Some(EcdhPublicKey::new(&X25519, key_data))
} else {
None
};
Ok(RotationMessage { message_id, propose, confirm })
}
#[allow(dead_code)]
pub fn write_to<W: Write>(&self, mut w: W) -> Result<(), io::Error> {
w.write_u64::<NetworkEndian>(self.message_id)?;
let key_bytes = self.propose.bytes();
w.write_u8(key_bytes.len() as u8)?;
w.write_all(key_bytes)?;
if let Some(ref key) = self.confirm {
let key_bytes = key.bytes();
w.write_u8(key_bytes.len() as u8)?;
w.write_all(key_bytes)?;
} else {
w.write_u8(0)?;
}
Ok(())
}
}
pub struct RotationState {
confirmed: Option<(EcdhPublicKey, u64)>, // sent by remote, already confirmed
pending: Option<(Key, EcdhPublicKey)>, // sent by remote, to be confirmed
proposed: Option<EcdhPrivateKey>, // my own, proposed but not confirmed
message_id: u64,
timeout: bool
}
pub struct RotatedKey {
pub key: Key,
pub id: u64,
pub use_for_sending: bool
}
impl RotationState {
#[allow(dead_code)]
pub fn new(initiator: bool, out: &mut MsgBuffer) -> Result<Self, Error> {
if initiator {
let (private_key, public_key) = Self::create_key();
Self::send(&RotationMessage { message_id: 1, confirm: None, propose: public_key }, out)?;
Ok(Self { confirmed: None, pending: None, proposed: Some(private_key), message_id: 1, timeout: false })
} else {
Ok(Self { confirmed: None, pending: None, proposed: None, message_id: 0, timeout: false })
}
}
fn send(msg: &RotationMessage, out: &mut MsgBuffer) -> Result<(), Error> {
assert!(out.is_empty());
let len;
{
let mut cursor = Cursor::new(out.buffer());
msg.write_to(&mut cursor).expect("Buffer too small");
len = cursor.position() as usize;
}
out.set_length(len);
Ok(())
}
fn create_key() -> (EcdhPrivateKey, EcdhPublicKey) {
let rand = SystemRandom::new();
let private_key = EcdhPrivateKey::generate(&X25519, &rand).unwrap();
let public_key = Self::compute_public_key(&private_key);
(private_key, public_key)
}
fn compute_public_key(private_key: &EcdhPrivateKey) -> EcdhPublicKey {
let public_key = private_key.compute_public_key().unwrap();
let mut vec = SmallVec::<[u8; 96]>::new();
vec.extend_from_slice(public_key.as_ref());
EcdhPublicKey::new(&X25519, vec)
}
fn derive_key(private_key: EcdhPrivateKey, public_key: EcdhPublicKey) -> Key {
agree_ephemeral(private_key, &public_key, (), |k| {
let mut vec = Key::new();
vec.extend_from_slice(k);
Ok(vec)
})
.unwrap()
}
pub fn handle_message(&mut self, msg: &[u8]) -> Result<Option<RotatedKey>, Error> {
let msg =
RotationMessage::read_from(Cursor::new(msg)).map_err(|_| Error::Crypto("Rotation message too short"))?;
Ok(self.process_message(msg))
}
pub fn process_message(&mut self, msg: RotationMessage) -> Option<RotatedKey> {
if msg.message_id <= self.message_id {
return None
}
self.timeout = false;
// Create key from proposal and store reply as pending
let (private_key, public_key) = Self::create_key();
let key = Self::derive_key(private_key, msg.propose);
self.pending = Some((key, public_key));
// If proposed key has been confirmed, derive and use key
if let Some(peer_key) = msg.confirm {
if let Some(private_key) = self.proposed.take() {
let key = Self::derive_key(private_key, peer_key);
return Some(RotatedKey { key, id: msg.message_id, use_for_sending: true })
}
}
None
}
#[allow(dead_code)]
pub fn cycle(&mut self, out: &mut MsgBuffer) -> Result<Option<RotatedKey>, Error> {
if let Some(ref private_key) = self.proposed {
// Still a proposed key that has not been confirmed, proposal must have been lost
if self.timeout {
let proposed_key = Self::compute_public_key(&private_key);
if let Some((ref confirmed_key, message_id)) = self.confirmed {
// Reconfirm last confirmed key
Self::send(
&RotationMessage { confirm: Some(confirmed_key.clone()), propose: proposed_key, message_id },
out
)?;
} else {
// First message has been lost
Self::send(&RotationMessage { confirm: None, propose: proposed_key, message_id: 1 }, out)?;
}
} else {
self.timeout = true;
}
} else {
// No proposed key, our turn to propose a new one
if let Some((key, confirm_key)) = self.pending.take() {
// Send out pending confirmation and register key for receiving
self.message_id += 2;
let message_id = self.message_id;
let (private_key, propose_key) = Self::create_key();
self.proposed = Some(private_key);
self.confirmed = Some((confirm_key.clone(), message_id));
Self::send(&RotationMessage { confirm: Some(confirm_key), propose: propose_key, message_id }, out)?;
return Ok(Some(RotatedKey { key, id: message_id, use_for_sending: false }))
} else {
// Nothing pending nor proposed, still waiting to receive message 1
// Do nothing, peer will retry
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
impl MsgBuffer {
fn msg(&mut self) -> Option<RotationMessage> {
if self.is_empty() {
return None
}
let msg = RotationMessage::read_from(Cursor::new(self.message())).unwrap();
self.set_length(0);
Some(msg)
}
}
#[test]
fn test_encode_decode_message() {
let mut data = Vec::with_capacity(100);
let (_, key) = RotationState::create_key();
let msg = RotationMessage { message_id: 1, propose: key, confirm: None };
msg.write_to(&mut data).unwrap();
let msg2 = RotationMessage::read_from(Cursor::new(&data)).unwrap();
assert_eq!(msg.message_id, msg2.message_id);
assert_eq!(msg.propose.bytes(), msg2.propose.bytes());
assert_eq!(msg.confirm.map(|v| v.bytes().to_vec()), msg2.confirm.map(|v| v.bytes().to_vec()));
let mut data = Vec::with_capacity(100);
let (_, key1) = RotationState::create_key();
let (_, key2) = RotationState::create_key();
let msg = RotationMessage { message_id: 2, propose: key1, confirm: Some(key2) };
msg.write_to(&mut data).unwrap();
let msg2 = RotationMessage::read_from(Cursor::new(&data)).unwrap();
assert_eq!(msg.message_id, msg2.message_id);
assert_eq!(msg.propose.bytes(), msg2.propose.bytes());
assert_eq!(msg.confirm.map(|v| v.bytes().to_vec()), msg2.confirm.map(|v| v.bytes().to_vec()));
}
#[test]
fn test_normal_rotation() {
let mut out1 = MsgBuffer::new(8);
let mut out2 = MsgBuffer::new(8);
// Initialization
let mut node1 = RotationState::new(true, &mut out1).unwrap();
let mut node2 = RotationState::new(false, &mut out2).unwrap();
assert!(!out1.is_empty());
let msg1 = out1.msg().unwrap();
assert_eq!(msg1.message_id, 1);
assert!(out2.is_empty());
// Message 1
let key = node2.process_message(msg1);
assert!(key.is_none());
// Cycle 1
let key1 = node1.cycle(&mut out1).unwrap();
let key2 = node2.cycle(&mut out2).unwrap();
assert!(key1.is_none());
assert!(out1.is_empty());
assert!(key2.is_some());
let key2 = key2.unwrap();
assert_eq!(key2.id, 2);
assert_eq!(key2.use_for_sending, false);
assert!(!out2.is_empty());
let msg2 = out2.msg().unwrap();
assert_eq!(msg2.message_id, 2);
assert!(msg2.confirm.is_some());
// Message 2
let key = node1.process_message(msg2);
assert!(key.is_some());
let key = key.unwrap();
assert_eq!(key.id, 2);
assert_eq!(key.use_for_sending, true);
// Cycle 2
let key1 = node1.cycle(&mut out1).unwrap();
let key2 = node2.cycle(&mut out2).unwrap();
assert!(key1.is_some());
let key1 = key1.unwrap();
assert_eq!(key1.id, 3);
assert_eq!(key1.use_for_sending, false);
assert!(!out1.is_empty());
let msg1 = out1.msg().unwrap();
assert_eq!(msg1.message_id, 3);
assert!(msg1.confirm.is_some());
assert!(key2.is_none());
assert!(out2.is_empty());
// Message 3
let key = node2.process_message(msg1);
assert!(key.is_some());
let key = key.unwrap();
assert_eq!(key.id, 3);
assert_eq!(key.use_for_sending, true);
// Cycle 3
let key1 = node1.cycle(&mut out1).unwrap();
let key2 = node2.cycle(&mut out2).unwrap();
assert!(key1.is_none());
assert!(out1.is_empty());
assert!(key2.is_some());
let key2 = key2.unwrap();
assert_eq!(key2.id, 4);
assert_eq!(key2.use_for_sending, false);
assert!(!out2.is_empty());
let msg2 = out2.msg().unwrap();
assert_eq!(msg2.message_id, 4);
assert!(msg2.confirm.is_some());
// Message 4
let key = node1.process_message(msg2);
assert!(key.is_some());
let key = key.unwrap();
assert_eq!(key.id, 4);
assert_eq!(key.use_for_sending, true);
}
// TODO: test duplication
// TODO: test lost message
// TODO: test potential attack: reflect message back to sender
}

View File

@ -2,39 +2,43 @@
// Copyright (C) 2015-2020 Dennis Schwerdel // Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
use libc::{c_short, c_ulong, ioctl, IFF_NO_PI, IFF_TAP, IFF_TUN, IF_NAMESIZE};
use std::{ use std::{
cmp,
collections::VecDeque, collections::VecDeque,
fmt, fs, fmt,
io::{self, Error as IoError, ErrorKind, Read, Write}, fs::{self, File},
io::{self, BufRead, BufReader, Cursor, Error as IoError, Read, Write},
net::{Ipv4Addr, UdpSocket},
os::unix::io::{AsRawFd, RawFd}, os::unix::io::{AsRawFd, RawFd},
str, str,
str::FromStr str::FromStr
}; };
use super::types::Error; use crate::{crypto, error::Error, util::MsgBuffer};
static TUNSETIFF: c_ulong = 1074025674; static TUNSETIFF: libc::c_ulong = 1074025674;
#[repr(C)] #[repr(C)]
union IfReqData { union IfReqData {
flags: c_short, flags: libc::c_short,
value: libc::c_int,
addr: (libc::c_short, Ipv4Addr),
_dummy: [u8; 24] _dummy: [u8; 24]
} }
#[repr(C)] #[repr(C)]
struct IfReq { struct IfReq {
ifr_name: [u8; IF_NAMESIZE], ifr_name: [u8; libc::IF_NAMESIZE],
data: IfReqData data: IfReqData
} }
impl IfReq { impl IfReq {
fn new(name: &str, flags: c_short) -> Self { fn new(name: &str) -> Self {
assert!(name.len() < IF_NAMESIZE); assert!(name.len() < libc::IF_NAMESIZE);
let mut ifr_name = [0 as u8; IF_NAMESIZE]; let mut ifr_name = [0 as u8; libc::IF_NAMESIZE];
ifr_name[..name.len()].clone_from_slice(name.as_bytes()); ifr_name[..name.len()].clone_from_slice(name.as_bytes());
Self { ifr_name, data: IfReqData { flags } } Self { ifr_name, data: IfReqData { _dummy: [0; 24] } }
} }
} }
@ -94,7 +98,7 @@ pub trait Device: AsRawFd {
/// ///
/// # Errors /// # Errors
/// This method will return an error if the underlying read call fails. /// This method will return an error if the underlying read call fails.
fn read(&mut self, buffer: &mut [u8]) -> Result<(usize, usize), Error>; fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>;
/// Writes a packet/frame to the device /// Writes a packet/frame to the device
/// ///
@ -105,13 +109,15 @@ pub trait Device: AsRawFd {
/// ///
/// # Errors /// # Errors
/// This method will return an error if the underlying read call fails. /// This method will return an error if the underlying read call fails.
fn write(&mut self, data: &mut [u8], start: usize) -> Result<(), Error>; fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error>;
fn get_ip(&self) -> Result<Ipv4Addr, Error>;
} }
/// Represents a tun/tap device /// Represents a tun/tap device
pub struct TunTapDevice { pub struct TunTapDevice {
fd: fs::File, fd: File,
ifname: String, ifname: String,
type_: Type type_: Type
} }
@ -142,16 +148,19 @@ impl TunTapDevice {
} }
let fd = fs::OpenOptions::new().read(true).write(true).open(path)?; let fd = fs::OpenOptions::new().read(true).write(true).open(path)?;
let flags = match type_ { let flags = match type_ {
Type::Tun => IFF_TUN | IFF_NO_PI, Type::Tun => libc::IFF_TUN | libc::IFF_NO_PI,
Type::Tap => IFF_TAP | IFF_NO_PI, Type::Tap => libc::IFF_TAP | libc::IFF_NO_PI,
Type::Dummy => unreachable!() Type::Dummy => unreachable!()
}; };
let mut ifreq = IfReq::new(ifname, flags as c_short); let mut ifreq = IfReq::new(ifname);
let res = unsafe { ioctl(fd.as_raw_fd(), TUNSETIFF, &mut ifreq) }; ifreq.data.flags = flags as libc::c_short;
let res = unsafe { libc::ioctl(fd.as_raw_fd(), TUNSETIFF, &mut ifreq) };
match res { match res {
0 => { 0 => {
let nul_range_end = ifreq.ifr_name.iter().position(|&c| c == b'\0').unwrap_or(ifreq.ifr_name.len()); let mut ifname = String::with_capacity(32);
let ifname = unsafe { str::from_utf8_unchecked(&ifreq.ifr_name[0..nul_range_end]) }.to_string(); let mut cursor = Cursor::new(ifreq.ifr_name);
cursor.read_to_string(&mut ifname)?;
ifname = ifname.trim_end_matches('\0').to_owned();
Ok(Self { fd, ifname, type_ }) Ok(Self { fd, ifname, type_ })
} }
_ => Err(IoError::last_os_error()) _ => Err(IoError::last_os_error())
@ -191,9 +200,7 @@ impl TunTapDevice {
#[cfg(any(target_os = "linux", target_os = "android"))] #[cfg(any(target_os = "linux", target_os = "android"))]
#[inline] #[inline]
fn correct_data_after_read(&mut self, _buffer: &mut [u8], start: usize, read: usize) -> (usize, usize) { fn correct_data_after_read(&mut self, _buffer: &mut MsgBuffer) {}
(start, read)
}
#[cfg(any( #[cfg(any(
target_os = "bitrig", target_os = "bitrig",
@ -205,21 +212,17 @@ impl TunTapDevice {
target_os = "openbsd" target_os = "openbsd"
))] ))]
#[inline] #[inline]
fn correct_data_after_read(&mut self, buffer: &mut [u8], start: usize, read: usize) -> (usize, usize) { fn correct_data_after_read(&mut self, buffer: &mut MsgBuffer) {
if self.type_ == Type::Tun { if self.type_ == Type::Tun {
// BSD-based systems add a 4-byte header containing the Ethertype for TUN // BSD-based systems add a 4-byte header containing the Ethertype for TUN
assert!(read >= 4); buffer.set_start(buffer.get_start() + 4);
(start + 4, read - 4)
} else { } else {
(start, read)
} }
} }
#[cfg(any(target_os = "linux", target_os = "android"))] #[cfg(any(target_os = "linux", target_os = "android"))]
#[inline] #[inline]
fn correct_data_before_write(&mut self, _buffer: &mut [u8], start: usize) -> usize { fn correct_data_before_write(&mut self, _buffer: &mut MsgBuffer) {}
start
}
#[cfg(any( #[cfg(any(
target_os = "bitrig", target_os = "bitrig",
@ -231,21 +234,63 @@ impl TunTapDevice {
target_os = "openbsd" target_os = "openbsd"
))] ))]
#[inline] #[inline]
fn correct_data_before_write(&mut self, buffer: &mut [u8], start: usize) -> usize { fn correct_data_before_write(&mut self, buffer: &mut MsgBuffer) {
if self.type_ == Type::Tun { if self.type_ == Type::Tun {
// BSD-based systems add a 4-byte header containing the Ethertype for TUN // BSD-based systems add a 4-byte header containing the Ethertype for TUN
assert!(start >= 4); buffer.set_start(buffer.get_start() - 4);
match buffer[start] >> 4 { match buffer.message()[4] >> 4 {
// IP version // IP version
4 => buffer[start - 4..start].copy_from_slice(&[0x00, 0x00, 0x08, 0x00]), 4 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x08, 0x00]),
6 => buffer[start - 4..start].copy_from_slice(&[0x00, 0x00, 0x86, 0xdd]), 6 => buffer.message_mut()[0..4].copy_from_slice(&[0x00, 0x00, 0x86, 0xdd]),
_ => unreachable!() _ => unreachable!()
} }
start - 4
} else {
start
} }
} }
pub fn get_overhead(&self) -> usize {
40 /* for outer IPv6 header, can't be sure to only have IPv4 peers */
+ 8 /* for outer UDP header */
+ crypto::EXTRA_LEN + crypto::TAG_LEN /* crypto overhead */
+ 1 /* message type header */
+ match self.type_ {
Type::Tap => 12, /* inner ethernet header */
Type::Tun | Type::Dummy => 0
}
}
pub fn set_mtu(&self, value: Option<usize>) -> io::Result<()> {
let value = match value {
Some(value) => value,
None => {
let default_device = get_default_device()?;
get_device_mtu(&default_device)? - self.get_overhead()
}
};
info!("Setting MTU {} on device {}", value, self.ifname);
set_device_mtu(&self.ifname, value)
}
pub fn configure(&self, addr: Ipv4Addr, netmask: Ipv4Addr) -> io::Result<()> {
set_device_addr(&self.ifname, addr)?;
set_device_netmask(&self.ifname, netmask)?;
set_device_enabled(&self.ifname, true)
}
pub fn get_rp_filter(&self) -> io::Result<u8> {
Ok(cmp::max(get_rp_filter("all")?, get_rp_filter(&self.ifname)?))
}
pub fn fix_rp_filter(&self) -> io::Result<()> {
if get_rp_filter("all")? > 1 {
info!("Setting net.ipv4.conf.all.rp_filter=1");
set_rp_filter("all", 1)?
}
if get_rp_filter(&self.ifname)? != 1 {
info!("Setting net.ipv4.conf.{}.rp_filter=1", self.ifname);
set_rp_filter(&self.ifname, 1)?
}
Ok(())
}
} }
impl Device for TunTapDevice { impl Device for TunTapDevice {
@ -257,19 +302,25 @@ impl Device for TunTapDevice {
&self.ifname &self.ifname
} }
fn read(&mut self, mut buffer: &mut [u8]) -> Result<(usize, usize), Error> { fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
let read = self.fd.read(&mut buffer).map_err(|e| Error::TunTapDev("Read error", e))?; buffer.clear();
let (start, read) = self.correct_data_after_read(&mut buffer, 0, read); let read = self.fd.read(buffer.buffer()).map_err(|e| Error::DeviceIo("Read error", e))?;
Ok((start, read)) buffer.set_length(read);
self.correct_data_after_read(buffer);
Ok(())
} }
fn write(&mut self, mut data: &mut [u8], start: usize) -> Result<(), Error> { fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
let start = self.correct_data_before_write(&mut data, start); self.correct_data_before_write(buffer);
match self.fd.write_all(&data[start..]) { match self.fd.write_all(buffer.message()) {
Ok(_) => self.fd.flush().map_err(|e| Error::TunTapDev("Flush error", e)), Ok(_) => self.fd.flush().map_err(|e| Error::DeviceIo("Flush error", e)),
Err(e) => Err(Error::TunTapDev("Write error", e)) Err(e) => Err(Error::DeviceIo("Write error", e))
} }
} }
fn get_ip(&self) -> Result<Ipv4Addr, Error> {
get_device_addr(&self.ifname).map_err(|e| Error::DeviceIo("Error getting IP address", e))
}
} }
impl AsRawFd for TunTapDevice { impl AsRawFd for TunTapDevice {
@ -312,19 +363,25 @@ impl Device for MockDevice {
unimplemented!() unimplemented!()
} }
fn read(&mut self, buffer: &mut [u8]) -> Result<(usize, usize), Error> { fn read(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
if let Some(data) = self.inbound.pop_front() { if let Some(data) = self.inbound.pop_front() {
buffer[0..data.len()].copy_from_slice(&data); buffer.clear();
Ok((0, data.len())) buffer.set_length(data.len());
buffer.message_mut().copy_from_slice(&data);
Ok(())
} else { } else {
Err(Error::TunTapDev("empty", io::Error::from(ErrorKind::UnexpectedEof))) Err(Error::Device("empty"))
} }
} }
fn write(&mut self, data: &mut [u8], start: usize) -> Result<(), Error> { fn write(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> {
self.outbound.push_back(data[start..].to_owned()); self.outbound.push_back(buffer.message().to_owned());
Ok(()) Ok(())
} }
fn get_ip(&self) -> Result<Ipv4Addr, Error> {
Err(Error::Device("Dummy devices have no IP address"))
}
} }
impl Default for MockDevice { impl Default for MockDevice {
@ -339,3 +396,134 @@ impl AsRawFd for MockDevice {
unimplemented!() unimplemented!()
} }
} }
fn set_device_mtu(ifname: &str, mtu: usize) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
ifreq.data.value = mtu as libc::c_int;
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU, &mut ifreq) };
match res {
0 => Ok(()),
_ => Err(IoError::last_os_error())
}
}
fn get_device_mtu(ifname: &str) -> io::Result<usize> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFMTU, &mut ifreq) };
match res {
0 => Ok(unsafe { ifreq.data.value as usize }),
_ => Err(IoError::last_os_error())
}
}
fn get_device_addr(ifname: &str) -> io::Result<Ipv4Addr> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFADDR, &mut ifreq) };
match res {
0 => {
let af = unsafe { ifreq.data.addr.0 };
if af as libc::c_int != libc::AF_INET {
return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned()))
}
let ip = unsafe { ifreq.data.addr.1 };
Ok(ip)
}
_ => Err(IoError::last_os_error())
}
}
fn set_device_addr(ifname: &str, addr: Ipv4Addr) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
ifreq.data.addr = (libc::AF_INET as libc::c_short, addr);
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFADDR, &mut ifreq) };
match res {
0 => Ok(()),
_ => Err(IoError::last_os_error())
}
}
#[allow(dead_code)]
fn get_device_netmask(ifname: &str) -> io::Result<Ipv4Addr> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFNETMASK, &mut ifreq) };
match res {
0 => {
let af = unsafe { ifreq.data.addr.0 };
if af as libc::c_int != libc::AF_INET {
return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid address family".to_owned()))
}
let ip = unsafe { ifreq.data.addr.1 };
Ok(ip)
}
_ => Err(IoError::last_os_error())
}
}
fn set_device_netmask(ifname: &str, addr: Ipv4Addr) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
ifreq.data.addr = (libc::AF_INET as libc::c_short, addr);
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFNETMASK, &mut ifreq) };
match res {
0 => Ok(()),
_ => Err(IoError::last_os_error())
}
}
fn set_device_enabled(ifname: &str, up: bool) -> io::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:0")?;
let mut ifreq = IfReq::new(ifname);
if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCGIFFLAGS, &mut ifreq) } != 0 {
return Err(IoError::last_os_error())
}
if up {
unsafe { ifreq.data.value |= libc::IFF_UP | libc::IFF_RUNNING }
} else {
unsafe { ifreq.data.value &= !libc::IFF_UP }
}
let res = unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFFLAGS, &mut ifreq) };
match res {
0 => Ok(()),
_ => Err(IoError::last_os_error())
}
}
fn get_default_device() -> io::Result<String> {
let fd = BufReader::new(File::open("/proc/net/route")?);
let mut best = None;
for line in fd.lines() {
let line = line?;
let parts = line.split('\t').collect::<Vec<_>>();
if parts[1] == "00000000" {
best = Some(parts[0].to_string());
break
}
if parts[2] != "00000000" {
best = Some(parts[0].to_string())
}
}
if let Some(ifname) = best {
Ok(ifname)
} else {
Err(io::Error::new(io::ErrorKind::NotFound, "No default interface found".to_string()))
}
}
fn get_rp_filter(device: &str) -> io::Result<u8> {
let mut fd = File::open(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?;
let mut contents = String::with_capacity(10);
fd.read_to_string(&mut contents)?;
u8::from_str(contents.trim()).map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid rp_filter value"))
}
fn set_rp_filter(device: &str, val: u8) -> io::Result<()> {
let mut fd = File::create(format!("/proc/sys/net/ipv4/conf/{}/rp_filter", device))?;
writeln!(fd, "{}", val)
}

49
src/error.rs Normal file
View File

@ -0,0 +1,49 @@
use thiserror::Error;
use std::io;
#[derive(Error, Debug)]
pub enum Error {
#[error("Unauthorized message: {0}")]
Unauthorized(&'static str),
#[error("Crypto initialization error: {0}")]
CryptoInit(&'static str),
#[error("Crypto error: {0}")]
Crypto(&'static str),
#[error("Invalid crypto state: {0}")]
InvalidCryptoState(&'static str),
#[error("Invalid config: {0}")]
InvalidConfig(&'static str),
#[error("Socker error: {0}")]
Socket(&'static str),
#[error("Socker error: {0}")]
SocketIo(&'static str, #[source] io::Error),
#[error("Device error: {0}")]
Device(&'static str),
#[error("Device error: {0}")]
DeviceIo(&'static str, #[source] io::Error),
#[error("File error: {0}")]
FileIo(&'static str, #[source] io::Error),
#[error("Message error: {0}")]
Message(&'static str),
#[error("Beacon error: {0}")]
BeaconIo(&'static str, #[source] io::Error),
#[error("Parse error: {0}")]
Parse(&'static str),
#[error("Name can not be resolved: {0}")]
NameUnresolvable(String)
}

View File

@ -1,228 +0,0 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use std::{
collections::{hash_map::Entry, HashMap},
hash::BuildHasherDefault,
io::{self, Write},
marker::PhantomData,
net::SocketAddr
};
use fnv::FnvHasher;
use super::{
types::{Address, Error, NodeId, Protocol, Table},
util::{addr_nice, Duration, Time, TimeSource}
};
/// An ethernet frame dissector
///
/// This dissector is able to extract the source and destination addresses of ethernet frames.
///
/// If the ethernet frame contains a VLAN tag, both addresses will be prefixed with that tag,
/// resulting in 8-byte addresses. Additional nested tags will be ignored.
pub struct Frame;
impl Protocol for Frame {
/// Parses an ethernet frame and extracts the source and destination addresses
///
/// # Errors
/// This method will fail when the given data is not a valid ethernet frame.
fn parse(data: &[u8]) -> Result<(Address, Address), Error> {
if data.len() < 14 {
return Err(Error::Parse("Frame is too short"))
}
let mut pos = 0;
let dst_data = &data[pos..pos + 6];
pos += 6;
let src_data = &data[pos..pos + 6];
pos += 6;
if data[pos] == 0x81 && data[pos + 1] == 0x00 {
pos += 2;
if data.len() < pos + 2 {
return Err(Error::Parse("Vlan frame is too short"))
}
let mut src = [0; 16];
let mut dst = [0; 16];
src[0] = data[pos];
src[1] = data[pos + 1];
dst[0] = data[pos];
dst[1] = data[pos + 1];
src[2..8].copy_from_slice(src_data);
dst[2..8].copy_from_slice(dst_data);
Ok((Address { data: src, len: 8 }, Address { data: dst, len: 8 }))
} else {
let src = Address::read_from_fixed(src_data, 6)?;
let dst = Address::read_from_fixed(dst_data, 6)?;
Ok((src, dst))
}
}
}
struct SwitchTableValue {
address: SocketAddr,
timeout: Time
}
type Hash = BuildHasherDefault<FnvHasher>;
/// A table used to implement a learning switch
///
/// This table is a simple hash map between an address and the destination peer. It learns
/// addresses as they are seen and forgets them after some time.
pub struct SwitchTable<TS> {
/// The table storing the actual mapping
table: HashMap<Address, SwitchTableValue, Hash>,
/// Timeout period for forgetting learnt addresses
timeout: Duration,
// Timeout period for not overwriting learnt addresses
protection_period: Duration,
_dummy_ts: PhantomData<TS>
}
impl<TS: TimeSource> SwitchTable<TS> {
/// Creates a new switch table
pub fn new(timeout: Duration, protection_period: Duration) -> Self {
Self { table: HashMap::default(), timeout, protection_period, _dummy_ts: PhantomData }
}
}
impl<TS: TimeSource> Table for SwitchTable<TS> {
/// Forget addresses that have not been seen for the configured timeout
fn housekeep(&mut self) {
let now = TS::now();
let mut del: Vec<Address> = Vec::new();
for (key, val) in &self.table {
if val.timeout < now {
del.push(*key);
}
}
for key in del {
info!("Forgot address {}", key);
self.table.remove(&key);
}
}
/// Write out the table
fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error> {
let now = TS::now();
writeln!(out, "switch_table:")?;
for (addr, val) in &self.table {
writeln!(
out,
" - \"{}\": {{ peer: \"{}\", ttl_secs: {} }}",
addr,
addr_nice(val.address),
val.timeout - now
)?;
}
Ok(())
}
/// Learns the given address, inserting it in the hash map
#[inline]
fn learn(&mut self, key: Address, _prefix_len: Option<u8>, _: NodeId, addr: SocketAddr) {
let deadline = TS::now() + Time::from(self.timeout);
match self.table.entry(key) {
Entry::Vacant(entry) => {
entry.insert(SwitchTableValue { address: addr, timeout: deadline });
info!("Learned address {} => {}", key, addr_nice(addr));
}
Entry::Occupied(mut entry) => {
let mut entry = entry.get_mut();
if entry.timeout + Time::from(self.protection_period) > deadline {
// Do not override recently learnt entries
return
}
entry.timeout = deadline;
entry.address = addr;
}
}
}
/// Retrieves a peer for an address if it is inside the hash map
#[inline]
fn lookup(&mut self, key: &Address) -> Option<SocketAddr> {
match self.table.get(key) {
Some(value) => Some(value.address),
None => None
}
}
/// Removes an address from the map and returns whether something has been removed
#[inline]
fn remove(&mut self, key: &Address) -> bool {
self.table.remove(key).is_some()
}
/// Removed all addresses associated with a certain peer
fn remove_all(&mut self, addr: &SocketAddr) {
let mut remove = Vec::new();
for (key, val) in &self.table {
if &val.address == addr {
remove.push(*key);
}
}
for key in remove {
self.table.remove(&key);
}
}
fn len(&self) -> usize {
self.table.len()
}
}
#[cfg(test)] use super::util::MockTimeSource;
#[cfg(test)] use std::net::ToSocketAddrs;
#[cfg(test)] use std::str::FromStr;
#[test]
fn decode_frame_without_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 });
assert_eq!(dst, Address { data: [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 });
}
#[test]
fn decode_frame_with_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0, 4, 210, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [4, 210, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 });
assert_eq!(dst, Address { data: [4, 210, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 });
}
#[test]
fn decode_invalid_frame() {
assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8]).is_ok());
// truncated frame
assert!(Frame::parse(&[]).is_err());
// truncated vlan frame
assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0x00]).is_err());
}
#[test]
fn switch() {
MockTimeSource::set_time(1000);
let mut table = SwitchTable::<MockTimeSource>::new(10, 1);
let addr = Address::from_str("12:34:56:78:90:ab").unwrap();
let node_id = [0; 16];
let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap();
let peer2 = "1.2.3.5:7890".to_socket_addrs().unwrap().next().unwrap();
assert!(table.lookup(&addr).is_none());
MockTimeSource::set_time(1000);
table.learn(addr, None, node_id, peer);
assert_eq!(table.lookup(&addr), Some(peer));
MockTimeSource::set_time(1000);
table.learn(addr, None, node_id, peer2);
assert_eq!(table.lookup(&addr), Some(peer));
MockTimeSource::set_time(1010);
table.learn(addr, None, node_id, peer2);
assert_eq!(table.lookup(&addr), Some(peer2));
}

319
src/ip.rs
View File

@ -1,319 +0,0 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use std::{
collections::{hash_map, HashMap},
hash::BuildHasherDefault,
io::{self, Write},
net::SocketAddr
};
use fnv::FnvHasher;
use super::{
types::{Address, Error, NodeId, Protocol, Table},
util::addr_nice
};
/// An IP packet dissector
///
/// This dissector is able to extract the source and destination ip addresses of ipv4 packets and
/// ipv6 packets.
#[allow(dead_code)]
pub struct Packet;
impl Protocol for Packet {
/// Parses an ip packet and extracts the source and destination addresses
///
/// # Errors
/// This method will fail when the given data is not a valid ipv4 and ipv6 packet.
fn parse(data: &[u8]) -> Result<(Address, Address), Error> {
if data.is_empty() {
return Err(Error::Parse("Empty header"))
}
let version = data[0] >> 4;
match version {
4 => {
if data.len() < 20 {
return Err(Error::Parse("Truncated IPv4 header"))
}
let src = Address::read_from_fixed(&data[12..], 4)?;
let dst = Address::read_from_fixed(&data[16..], 4)?;
Ok((src, dst))
}
6 => {
if data.len() < 40 {
return Err(Error::Parse("Truncated IPv6 header"))
}
let src = Address::read_from_fixed(&data[8..], 16)?;
let dst = Address::read_from_fixed(&data[24..], 16)?;
Ok((src, dst))
}
_ => Err(Error::Parse("Invalid version"))
}
}
}
struct RoutingEntry {
address: SocketAddr,
node_id: NodeId,
bytes: Address,
prefix_len: u8
}
type Hash = BuildHasherDefault<FnvHasher>;
/// A prefix-based routing table
///
/// This table contains a mapping of prefixes associated with peer addresses.
/// To speed up lookup, prefixes are grouped into full bytes and map to a list of prefixes with
/// more fine grained prefixes.
#[derive(Default)]
pub struct RoutingTable(HashMap<[u8; 16], Vec<RoutingEntry>, Hash>);
impl RoutingTable {
/// Creates a new empty routing table
pub fn new() -> Self {
RoutingTable(HashMap::default())
}
}
impl Table for RoutingTable {
/// Learns the given address, inserting it in the hash map
fn learn(&mut self, addr: Address, prefix_len: Option<u8>, node_id: NodeId, address: SocketAddr) {
// If prefix length is not set, treat the whole address as significant
let prefix_len = match prefix_len {
Some(val) => val,
None => addr.len * 8
};
info!("New routing entry: {}/{} => {}", addr, prefix_len, addr_nice(address));
// Round the prefix length down to the next multiple of 8 and extract a prefix of that
// length.
let group_len = prefix_len as usize / 8;
assert!(group_len <= 16);
let mut group_bytes = [0; 16];
group_bytes[..group_len].copy_from_slice(&addr.data[..group_len]);
// Create an entry
let routing_entry = RoutingEntry { address, bytes: addr, node_id, prefix_len };
// Add the entry to the routing table, creating a new list if the prefix group is empty.
match self.0.entry(group_bytes) {
hash_map::Entry::Occupied(mut entry) => {
let list = entry.get_mut();
list.retain(|e| e.node_id != routing_entry.node_id || e.address == routing_entry.address);
list.push(routing_entry)
}
hash_map::Entry::Vacant(entry) => {
entry.insert(vec![routing_entry]);
}
}
}
/// Retrieves a peer for an address if it is inside the routing table
#[allow(unknown_lints, clippy::needless_range_loop)]
fn lookup(&mut self, addr: &Address) -> Option<SocketAddr> {
let len = addr.len as usize;
let mut found = None;
let mut found_len: isize = -1;
// Iterate over the prefix length from longest prefix group to shortest (empty) prefix
// group
let mut group_bytes = addr.data;
for i in len..16 {
group_bytes[i] = 0;
}
for i in (0..=len).rev() {
if i < len {
group_bytes[i] = 0;
}
if let Some(group) = self.0.get(&group_bytes) {
// If the group is not empty, check every entry
for entry in group {
// Calculate the match length of the address and the prefix
let mut match_len = 0;
for j in 0..addr.len as usize {
let b = addr.data[j] ^ entry.bytes.data[j];
if b == 0 {
match_len += 8;
} else {
match_len += b.leading_zeros();
break
}
}
// If the full prefix matches and the match is longer than the longest prefix
// found so far, remember the peer
if match_len as u8 >= entry.prefix_len && entry.prefix_len as isize > found_len {
found = Some(entry.address);
found_len = entry.prefix_len as isize;
}
}
}
}
// Return the longest match found (if any).
found
}
/// This method does not do anything.
fn housekeep(&mut self) {
// nothing to do
}
/// Write out the table
fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error> {
writeln!(out, "routing_table:")?;
for entries in self.0.values() {
for entry in entries {
writeln!(
out,
" - \"{}/{}\": {{ peer: \"{}\" }}",
entry.bytes,
entry.prefix_len,
addr_nice(entry.address)
)?;
}
}
Ok(())
}
/// Removes an address from the map and returns whether something has been removed
#[inline]
fn remove(&mut self, _addr: &Address) -> bool {
// Do nothing, removing single address from prefix-based routing tables does not make sense
false
}
/// Removed all addresses associated with a certain peer
fn remove_all(&mut self, addr: &SocketAddr) {
for entry in &mut self.0.values_mut() {
entry.retain(|entr| &entr.address != addr);
}
}
fn len(&self) -> usize {
self.0.len()
}
}
#[cfg(test)] use std::net::ToSocketAddrs;
#[cfg(test)] use std::str::FromStr;
#[test]
fn decode_ipv4_packet() {
let data = [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2];
let (src, dst) = Packet::parse(&data).unwrap();
assert_eq!(src, Address { data: [192, 168, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 });
assert_eq!(dst, Address { data: [192, 168, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 });
}
#[test]
fn decode_ipv6_packet() {
let data = [
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2, 1
];
let (src, dst) = Packet::parse(&data).unwrap();
assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6], len: 16 });
assert_eq!(dst, Address { data: [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5, 4, 3, 2, 1], len: 16 });
}
#[test]
fn decode_invalid_packet() {
assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2]).is_ok());
assert!(Packet::parse(&[
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2, 1
])
.is_ok());
// no data
assert!(Packet::parse(&[]).is_err());
// wrong version
assert!(Packet::parse(&[0x20]).is_err());
// truncated ipv4
assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1]).is_err());
// truncated ipv6
assert!(Packet::parse(&[
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2
])
.is_err());
}
#[test]
fn routing_table_ipv4() {
let mut table = RoutingTable::new();
let peer1 = "1.2.3.4:1".to_socket_addrs().unwrap().next().unwrap();
let node1 = [1; 16];
let peer2 = "1.2.3.4:2".to_socket_addrs().unwrap().next().unwrap();
let node2 = [2; 16];
let peer3 = "1.2.3.4:3".to_socket_addrs().unwrap().next().unwrap();
let node3 = [3; 16];
assert!(table.lookup(&Address::from_str("192.168.1.1").unwrap()).is_none());
table.learn(Address::from_str("192.168.1.1").unwrap(), Some(32), node1, peer1);
assert_eq!(table.lookup(&Address::from_str("192.168.1.1").unwrap()), Some(peer1));
table.learn(Address::from_str("192.168.1.2").unwrap(), None, node2, peer2);
assert_eq!(table.lookup(&Address::from_str("192.168.1.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.2").unwrap()), Some(peer2));
table.learn(Address::from_str("192.168.1.0").unwrap(), Some(24), node3, peer3);
assert_eq!(table.lookup(&Address::from_str("192.168.1.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("192.168.1.3").unwrap()), Some(peer3));
table.learn(Address::from_str("192.168.0.0").unwrap(), Some(16), node1, peer1);
assert_eq!(table.lookup(&Address::from_str("192.168.2.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("192.168.1.3").unwrap()), Some(peer3));
table.learn(Address::from_str("0.0.0.0").unwrap(), Some(0), node2, peer2);
assert_eq!(table.lookup(&Address::from_str("192.168.2.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("192.168.1.2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("192.168.1.3").unwrap()), Some(peer3));
assert_eq!(table.lookup(&Address::from_str("1.2.3.4").unwrap()), Some(peer2));
table.learn(Address::from_str("192.168.2.0").unwrap(), Some(27), node3, peer3);
assert_eq!(table.lookup(&Address::from_str("192.168.2.31").unwrap()), Some(peer3));
assert_eq!(table.lookup(&Address::from_str("192.168.2.32").unwrap()), Some(peer1));
table.learn(Address::from_str("192.168.2.0").unwrap(), Some(28), node3, peer3);
assert_eq!(table.lookup(&Address::from_str("192.168.2.1").unwrap()), Some(peer3));
}
#[test]
fn routing_table_ipv6() {
let mut table = RoutingTable::new();
let peer1 = "::1:1".to_socket_addrs().unwrap().next().unwrap();
let node1 = [1; 16];
let peer2 = "::1:2".to_socket_addrs().unwrap().next().unwrap();
let node2 = [2; 16];
let peer3 = "::1:3".to_socket_addrs().unwrap().next().unwrap();
let node3 = [3; 16];
assert!(table.lookup(&Address::from_str("::1").unwrap()).is_none());
table.learn(Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap(), Some(128), node1, peer1);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap()), Some(peer1));
table.learn(Address::from_str("dead:beef:dead:beef:dead:beef:dead:2").unwrap(), None, node2, peer2);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:2").unwrap()), Some(peer2));
table.learn(Address::from_str("dead:beef:dead:beef::").unwrap(), Some(64), node3, peer3);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:3").unwrap()), Some(peer3));
table.learn(Address::from_str("dead:beef:dead:be00::").unwrap(), Some(56), node1, peer1);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:1::").unwrap()), Some(peer3));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:be01::").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:3").unwrap()), Some(peer3));
table.learn(Address::from_str("::").unwrap(), Some(0), node2, peer2);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:1::").unwrap()), Some(peer3));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:be01::").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:1").unwrap()), Some(peer1));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:2").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:3").unwrap()), Some(peer3));
assert_eq!(table.lookup(&Address::from_str("::1").unwrap()), Some(peer2));
table.learn(Address::from_str("dead:beef:dead:beef:dead:beef:dead:be00").unwrap(), Some(123), node2, peer2);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:be1f").unwrap()), Some(peer2));
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:be20").unwrap()), Some(peer3));
table.learn(Address::from_str("dead:beef:dead:beef:dead:beef:dead:be00").unwrap(), Some(124), node3, peer3);
assert_eq!(table.lookup(&Address::from_str("dead:beef:dead:beef:dead:beef:dead:be01").unwrap()), Some(peer3));
}

View File

@ -21,21 +21,22 @@ pub mod cloud;
pub mod config; pub mod config;
pub mod crypto; pub mod crypto;
pub mod device; pub mod device;
pub mod ethernet; pub mod error;
pub mod ip; pub mod messages;
pub mod net; pub mod net;
pub mod payload;
pub mod poll; pub mod poll;
pub mod port_forwarding; pub mod port_forwarding;
pub mod table;
pub mod traffic; pub mod traffic;
pub mod types; pub mod types;
pub mod udpmessage;
use structopt::StructOpt; use structopt::StructOpt;
use std::{ use std::{
fs::{self, File, Permissions}, fs::{self, File, Permissions},
io::{self, Write}, io::{self, Write},
net::UdpSocket, net::{Ipv4Addr, UdpSocket},
os::unix::fs::PermissionsExt, os::unix::fs::PermissionsExt,
path::Path, path::Path,
process::Command, process::Command,
@ -46,151 +47,15 @@ use std::{
use crate::{ use crate::{
cloud::GenericCloud, cloud::GenericCloud,
config::Config, config::{Args, Config},
crypto::{Crypto, CryptoMethod}, crypto::Crypto,
device::{Device, TunTapDevice, Type}, device::{Device, TunTapDevice, Type},
ethernet::SwitchTable,
ip::RoutingTable,
port_forwarding::PortForwarding, port_forwarding::PortForwarding,
types::{Error, HeaderMagic, Mode, Protocol, Range}, payload::Protocol,
util::{Duration, SystemTimeSource} util::SystemTimeSource
}; };
const VERSION: u8 = 1;
const MAGIC: HeaderMagic = *b"vpn\x01";
#[derive(StructOpt, Debug, Default)]
pub struct Args {
/// Read configuration options from the specified file.
#[structopt(long)]
config: Option<String>,
/// Set the type of network ("tap" or "tun")
#[structopt(name = "type", short, long)]
type_: Option<Type>,
/// Set the path of the base device
#[structopt(long)]
device_path: Option<String>,
/// The mode of the VPN ("normal", "router", "switch", or "hub")
#[structopt(short, long)]
mode: Option<Mode>,
/// The shared key to encrypt all traffic
#[structopt(short, long, aliases=&["shared-key", "secret-key", "secret"])]
key: Option<String>,
/// The encryption method to use ("aes128", "aes256", or "chacha20")
#[structopt(long)]
crypto: Option<CryptoMethod>,
/// The local subnets to use
#[structopt(short, long)]
subnets: Vec<String>,
/// Name of the virtual device
#[structopt(short, long)]
device: Option<String>,
/// The port number (or ip:port) on which to listen for data
#[structopt(short, long)]
listen: Option<String>,
/// Optional token that identifies the network. (DEPRECATED)
#[structopt(long)]
network_id: Option<String>,
/// Override the 4-byte magic header of each packet
#[structopt(long)]
magic: Option<String>,
/// Address of a peer to connect to
#[structopt(short, long)]
connect: Vec<String>,
/// Peer timeout in seconds
#[structopt(long)]
peer_timeout: Option<Duration>,
/// Periodically send message to keep connections alive
#[structopt(long)]
keepalive: Option<Duration>,
/// Switch table entry timeout in seconds
#[structopt(long)]
dst_timeout: Option<Duration>,
/// The file path or |command to store the beacon
#[structopt(long)]
beacon_store: Option<String>,
/// The file path or |command to load the beacon
#[structopt(long)]
beacon_load: Option<String>,
/// Beacon store/load interval in seconds
#[structopt(long)]
beacon_interval: Option<Duration>,
/// Print debug information
#[structopt(short, long, conflicts_with = "quiet")]
verbose: bool,
/// Only print errors and warnings
#[structopt(short, long)]
quiet: bool,
/// A command to setup the network interface
#[structopt(long)]
ifup: Option<String>,
/// A command to bring down the network interface
#[structopt(long)]
ifdown: Option<String>,
/// Print the version and exit
#[structopt(long)]
version: bool,
/// Disable automatic port forwarding
#[structopt(long)]
no_port_forwarding: bool,
/// Run the process in the background
#[structopt(long)]
daemon: bool,
/// Store the process id in this file when daemonizing
#[structopt(long)]
pid_file: Option<String>,
/// Print statistics to this file
#[structopt(long)]
stats_file: Option<String>,
/// Send statistics to this statsd server
#[structopt(long)]
statsd_server: Option<String>,
/// Use the given prefix for statsd records
#[structopt(long)]
statsd_prefix: Option<String>,
/// Run as other user
#[structopt(long)]
user: Option<String>,
/// Run as other group
#[structopt(long)]
group: Option<String>,
/// Print logs also to this file
#[structopt(long)]
log_file: Option<String>
}
struct DualLogger { struct DualLogger {
file: Mutex<Option<File>> file: Mutex<Option<File>>
} }
@ -252,91 +117,21 @@ fn run_script(script: &str, ifname: &str) {
} }
} }
enum AnyTable { fn parse_ip_netmask(addr: &str) -> Result<(Ipv4Addr, Ipv4Addr), String> {
Switch(SwitchTable<SystemTimeSource>), let (ip_str, len_str) = match addr.find('/') {
Routing(RoutingTable) Some(pos) => (&addr[..pos], &addr[pos + 1..]),
None => (addr, "24")
};
let prefix_len = u8::from_str(len_str).map_err(|_| format!("Invalid prefix length: {}", len_str))?;
if prefix_len > 32 {
return Err(format!("Invalid prefix length: {}", prefix_len))
}
let ip = Ipv4Addr::from_str(ip_str).map_err(|_| format!("Invalid ip address: {}", ip_str))?;
let netmask = Ipv4Addr::from(u32::max_value().checked_shl(32 - prefix_len as u32).unwrap());
Ok((ip, netmask))
} }
enum AnyCloud<P: Protocol> { fn setup_device(config: &Config) -> TunTapDevice {
Switch(GenericCloud<TunTapDevice, P, SwitchTable<SystemTimeSource>, UdpSocket, SystemTimeSource>),
Routing(GenericCloud<TunTapDevice, P, RoutingTable, UdpSocket, SystemTimeSource>)
}
impl<P: Protocol> AnyCloud<P> {
#[allow(unknown_lints, clippy::too_many_arguments)]
fn new(
config: &Config, device: TunTapDevice, table: AnyTable, learning: bool, broadcast: bool, addresses: Vec<Range>,
crypto: Crypto, port_forwarding: Option<PortForwarding>, stats_file: Option<File>
) -> Self
{
match table {
AnyTable::Switch(t) => {
AnyCloud::Switch(GenericCloud::<
TunTapDevice,
P,
SwitchTable<SystemTimeSource>,
UdpSocket,
SystemTimeSource
>::new(
config,
device,
t,
learning,
broadcast,
addresses,
crypto,
port_forwarding,
stats_file
))
}
AnyTable::Routing(t) => {
AnyCloud::Routing(GenericCloud::<TunTapDevice, P, RoutingTable, UdpSocket, SystemTimeSource>::new(
config,
device,
t,
learning,
broadcast,
addresses,
crypto,
port_forwarding,
stats_file
))
}
}
}
fn ifname(&self) -> &str {
match *self {
AnyCloud::Switch(ref c) => c.ifname(),
AnyCloud::Routing(ref c) => c.ifname()
}
}
fn run(&mut self) {
match *self {
AnyCloud::Switch(ref mut c) => c.run(),
AnyCloud::Routing(ref mut c) => c.run()
}
}
fn connect(&mut self, a: &str) -> Result<(), Error> {
match *self {
AnyCloud::Switch(ref mut c) => c.connect(a),
AnyCloud::Routing(ref mut c) => c.connect(a)
}
}
fn add_reconnect_peer(&mut self, a: String) {
match *self {
AnyCloud::Switch(ref mut c) => c.add_reconnect_peer(a),
AnyCloud::Routing(ref mut c) => c.add_reconnect_peer(a)
}
}
}
#[allow(clippy::cognitive_complexity)]
fn run<P: Protocol>(config: Config) {
let device = try_fail!( let device = try_fail!(
TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)), TunTapDevice::new(&config.device_name, config.device_type, config.device_path.as_ref().map(|s| s as &str)),
"Failed to open virtual {} interface {}: {}", "Failed to open virtual {} interface {}: {}",
@ -344,27 +139,32 @@ fn run<P: Protocol>(config: Config) {
config.device_name config.device_name
); );
info!("Opened device {}", device.ifname()); info!("Opened device {}", device.ifname());
let mut ranges = Vec::with_capacity(config.subnets.len()); if let Err(err) = device.set_mtu(None) {
for s in &config.subnets { error!("Error setting optimal MTU on {}: {}", device.ifname(), err);
ranges.push(try_fail!(Range::from_str(s), "Invalid subnet format: {} ({})", s));
} }
let dst_timeout = config.dst_timeout; if let Some(ip) = &config.ip {
let (learning, broadcasting, table) = match config.mode { let (ip, netmask) = try_fail!(parse_ip_netmask(ip), "Invalid ip address given: {}");
Mode::Normal => { info!("Configuring device with ip {}, netmask {}", ip, netmask);
match config.device_type { try_fail!(device.configure(ip, netmask), "Failed to configure device: {}");
Type::Tap => (true, true, AnyTable::Switch(SwitchTable::new(dst_timeout, 10))), }
Type::Tun => (false, false, AnyTable::Routing(RoutingTable::new())), if let Some(script) = &config.ifup {
Type::Dummy => (false, false, AnyTable::Switch(SwitchTable::new(dst_timeout, 10))) run_script(script, device.ifname());
} }
if config.fix_rp_filter {
try_fail!(device.fix_rp_filter(), "Failed to change rp_filter settings: {}");
}
if let Ok(val) = device.get_rp_filter() {
if val != 1 {
warn!("Your networking configuration might be affected by a vulnerability (https://seclists.org/oss-sec/2019/q4/122), please change your rp_filter setting to 1 (currently {}).", val);
} }
Mode::Router => (false, false, AnyTable::Routing(RoutingTable::new())), }
Mode::Switch => (true, true, AnyTable::Switch(SwitchTable::new(dst_timeout, 10))), device
Mode::Hub => (false, true, AnyTable::Switch(SwitchTable::new(dst_timeout, 10))) }
};
let crypto = match config.shared_key {
Some(ref key) => Crypto::from_shared_key(config.crypto, key), #[allow(clippy::cognitive_complexity)]
None => Crypto::None fn run<P: Protocol>(config: Config) {
}; let device = setup_device(&config);
let port_forwarding = if config.port_forwarding { PortForwarding::new(config.listen.port()) } else { None }; let port_forwarding = if config.port_forwarding { PortForwarding::new(config.listen.port()) } else { None };
let stats_file = match config.stats_file { let stats_file = match config.stats_file {
None => None, None => None,
@ -382,10 +182,7 @@ fn run<P: Protocol>(config: Config) {
} }
}; };
let mut cloud = let mut cloud =
AnyCloud::<P>::new(&config, device, table, learning, broadcasting, ranges, crypto, port_forwarding, stats_file); GenericCloud::<TunTapDevice, P, UdpSocket, SystemTimeSource>::new(&config, device, port_forwarding, stats_file);
if let Some(script) = config.ifup {
run_script(&script, cloud.ifname());
}
for addr in config.peers { for addr in config.peers {
try_fail!(cloud.connect(&addr as &str), "Failed to send message to {}: {}", &addr); try_fail!(cloud.connect(&addr as &str), "Failed to send message to {}: {}", &addr);
cloud.add_reconnect_peer(addr); cloud.add_reconnect_peer(addr);
@ -425,7 +222,15 @@ fn run<P: Protocol>(config: Config) {
fn main() { fn main() {
let args: Args = Args::from_args(); let args: Args = Args::from_args();
if args.version { if args.version {
println!("VpnCloud v{}, protocol version {}", env!("CARGO_PKG_VERSION"), VERSION); println!("VpnCloud v{}", env!("CARGO_PKG_VERSION"));
return
}
if args.genkey {
let (privkey, pubkey) = Crypto::generate_keypair();
println!("Private key: {}\nPublic key: {}\n", privkey, pubkey);
println!(
"Attention: Keep the private key secret and use only the public key on other nodes to establish trust."
);
return return
} }
let logger = try_fail!(DualLogger::new(args.log_file.as_ref()), "Failed to open logfile: {}"); let logger = try_fail!(DualLogger::new(args.log_file.as_ref()), "Failed to open logfile: {}");
@ -448,8 +253,8 @@ fn main() {
config.merge_args(args); config.merge_args(args);
debug!("Config: {:?}", config); debug!("Config: {:?}", config);
match config.device_type { match config.device_type {
Type::Tap => run::<ethernet::Frame>(config), Type::Tap => run::<payload::Frame>(config),
Type::Tun => run::<ip::Packet>(config), Type::Tun => run::<payload::Packet>(config),
Type::Dummy => run::<ethernet::Frame>(config) Type::Dummy => run::<payload::Frame>(config)
} }
} }

210
src/messages.rs Normal file
View File

@ -0,0 +1,210 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use crate::{
crypto::Payload,
error::Error,
types::{NodeId, Range, RangeList, NODE_ID_BYTES},
util::MsgBuffer
};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use smallvec::{smallvec, SmallVec};
use std::{
io::{self, Cursor, Read, Seek, SeekFrom, Take, Write},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}
};
pub const MESSAGE_TYPE_DATA: u8 = 0;
pub const MESSAGE_TYPE_NODE_INFO: u8 = 1;
pub const MESSAGE_TYPE_KEEPALIVE: u8 = 2;
pub const MESSAGE_TYPE_CLOSE: u8 = 0xff;
pub type AddrList = SmallVec<[SocketAddr; 4]>;
pub type PeerList = SmallVec<[PeerInfo; 16]>;
#[derive(Debug, PartialEq)]
pub struct PeerInfo {
pub node_id: Option<NodeId>,
pub addrs: AddrList
}
#[derive(Debug, PartialEq)]
pub struct NodeInfo {
pub peers: PeerList,
pub claims: RangeList,
pub peer_timeout: Option<u16>
}
impl NodeInfo {
const PART_CLAIMS: u8 = 2;
const PART_END: u8 = 0;
const PART_PEERS: u8 = 1;
const PART_PEER_TIMEOUT: u8 = 3;
fn decode_peer_list_part<R: Read>(r: &mut Take<R>) -> Result<PeerList, io::Error> {
let mut peers = smallvec![];
while r.limit() > 0 {
let flags = r.read_u8()?;
let has_node_id = (flags & 0x80) != 0;
let num_ipv4_addrs = (flags & 0x07) as usize;
let num_ipv6_addrs = (flags & 0x38) as usize / 8;
let mut node_id = None;
if has_node_id {
let mut id = [0; NODE_ID_BYTES];
r.read_exact(&mut id)?;
node_id = Some(id)
}
let mut addrs = SmallVec::with_capacity(num_ipv4_addrs + num_ipv6_addrs);
for _ in 0..num_ipv6_addrs {
let mut ip = [0u8; 16];
r.read_exact(&mut ip)?;
let port = r.read_u16::<NetworkEndian>()?;
let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0));
addrs.push(addr);
}
for _ in 0..num_ipv4_addrs {
let mut ip = [0u8; 4];
r.read_exact(&mut ip)?;
let port = r.read_u16::<NetworkEndian>()?;
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port));
addrs.push(addr);
}
peers.push(PeerInfo { addrs, node_id })
}
Ok(peers)
}
fn decode_claims_part<R: Read>(mut r: &mut Take<R>) -> Result<RangeList, Error> {
let mut claims = smallvec![];
while r.limit() > 0 {
claims.push(Range::read_from(&mut r)?);
}
Ok(claims)
}
fn decode_internal<R: Read>(mut r: R) -> Result<Self, Error> {
let mut peers = smallvec![];
let mut claims = smallvec![];
let mut peer_timeout = None;
loop {
let part = r.read_u8().map_err(|_| Error::Message("Truncated message"))?;
if part == Self::PART_END {
break
}
let part_len = r.read_u16::<NetworkEndian>().map_err(|_| Error::Message("Truncated message"))? as usize;
let mut rp = r.take(part_len as u64);
match part {
Self::PART_PEERS => {
peers = Self::decode_peer_list_part(&mut rp).map_err(|_| Error::Message("Truncated message"))?
}
Self::PART_CLAIMS => claims = Self::decode_claims_part(&mut rp)?,
Self::PART_PEER_TIMEOUT => {
peer_timeout =
Some(rp.read_u16::<NetworkEndian>().map_err(|_| Error::Message("Truncated message"))?)
}
_ => {
let mut data = vec![0; part_len];
rp.read_exact(&mut data).map_err(|_| Error::Message("Truncated message"))?;
}
}
r = rp.into_inner();
}
Ok(Self { peers, claims, peer_timeout })
}
pub fn decode<R: Read>(r: R) -> Result<Self, Error> {
Self::decode_internal(r).map_err(|_| Error::Message("Input data too short"))
}
fn encode_peer_list_part<W: Write>(&self, mut out: W) -> Result<(), io::Error> {
for p in &self.peers {
let mut addr_ipv4 = vec![];
let mut addr_ipv6 = vec![];
for a in &p.addrs {
match a {
SocketAddr::V4(addr) => addr_ipv4.push(*addr),
SocketAddr::V6(addr) => addr_ipv6.push(*addr)
}
}
while addr_ipv4.len() >= 8 {
addr_ipv4.pop();
}
while addr_ipv6.len() >= 8 {
addr_ipv6.pop();
}
let mut flags = addr_ipv6.len() as u8 * 8 + addr_ipv4.len() as u8;
if p.node_id.is_some() {
flags += 0x80;
}
out.write_u8(flags)?;
if let Some(node_id) = &p.node_id {
out.write_all(node_id)?;
}
for a in addr_ipv6 {
out.write_all(&a.ip().octets())?;
out.write_u16::<NetworkEndian>(a.port())?;
}
for a in addr_ipv4 {
out.write_all(&a.ip().octets())?;
out.write_u16::<NetworkEndian>(a.port())?;
}
}
Ok(())
}
fn encode_part<F: FnOnce(&mut Cursor<&mut [u8]>) -> Result<(), io::Error>>(
cursor: &mut Cursor<&mut [u8]>, part: u8, f: F
) -> Result<(), io::Error> {
cursor.write_u8(part)?;
cursor.write_u16::<NetworkEndian>(0)?;
let part_start = cursor.position();
f(cursor)?;
let part_end = cursor.position();
let len = part_end - part_start;
cursor.seek(SeekFrom::Start(part_start - 2))?;
cursor.write_u16::<NetworkEndian>(len as u16)?;
cursor.seek(SeekFrom::Start(part_end))?;
Ok(())
}
fn encode_internal(&self, buffer: &mut MsgBuffer) -> Result<(), io::Error> {
let len;
{
let mut cursor = Cursor::new(buffer.buffer());
Self::encode_part(&mut cursor, Self::PART_PEERS, |cursor| self.encode_peer_list_part(cursor))?;
Self::encode_part(&mut cursor, Self::PART_CLAIMS, |mut cursor| {
for c in &self.claims {
c.write_to(&mut cursor);
}
Ok(())
})?;
if let Some(timeout) = self.peer_timeout {
Self::encode_part(&mut cursor, Self::PART_PEER_TIMEOUT, |cursor| {
cursor.write_u16::<NetworkEndian>(timeout)
})?
}
cursor.write_u8(Self::PART_END)?;
len = cursor.position() as usize;
}
buffer.set_length(len);
Ok(())
}
pub fn encode(&self, buffer: &mut MsgBuffer) {
self.encode_internal(buffer).expect("Buffer too small")
}
}
impl Payload for NodeInfo {
fn write_to(&self, buffer: &mut MsgBuffer) {
self.encode(buffer)
}
fn read_from<R: Read>(r: R) -> Result<Self, Error> {
Self::decode(r)
}
}

View File

@ -5,17 +5,24 @@
use std::{ use std::{
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
io::{self, ErrorKind}, io::{self, ErrorKind},
net::{SocketAddr, UdpSocket}, net::{IpAddr, SocketAddr, UdpSocket},
os::unix::io::{AsRawFd, RawFd}, os::unix::io::{AsRawFd, RawFd},
sync::atomic::{AtomicBool, Ordering} sync::atomic::{AtomicBool, Ordering}
}; };
use super::util::{MockTimeSource, Time, TimeSource}; use super::util::{MockTimeSource, MsgBuffer, Time, TimeSource};
pub fn mapped_addr(addr: SocketAddr) -> SocketAddr {
match addr {
SocketAddr::V4(addr4) => SocketAddr::new(IpAddr::V6(addr4.ip().to_ipv6_mapped()), addr4.port()),
_ => addr
}
}
pub trait Socket: AsRawFd + Sized { pub trait Socket: AsRawFd + Sized {
fn listen(addr: SocketAddr) -> Result<Self, io::Error>; fn listen(addr: SocketAddr) -> Result<Self, io::Error>;
fn receive(&mut self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), io::Error>; fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error>;
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>; fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error>;
fn address(&self) -> Result<SocketAddr, io::Error>; fn address(&self) -> Result<SocketAddr, io::Error>;
} }
@ -24,12 +31,18 @@ impl Socket for UdpSocket {
fn listen(addr: SocketAddr) -> Result<Self, io::Error> { fn listen(addr: SocketAddr) -> Result<Self, io::Error> {
UdpSocket::bind(addr) UdpSocket::bind(addr)
} }
fn receive(&mut self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), io::Error> {
self.recv_from(buffer) fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
buffer.clear();
let (size, addr) = self.recv_from(buffer.buffer())?;
buffer.set_length(size);
Ok(addr)
} }
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> { fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> {
self.send_to(data, addr) self.send_to(data, addr)
} }
fn address(&self) -> Result<SocketAddr, io::Error> { fn address(&self) -> Result<SocketAddr, io::Error> {
self.local_addr() self.local_addr()
} }
@ -96,14 +109,18 @@ impl Socket for MockSocket {
fn listen(addr: SocketAddr) -> Result<Self, io::Error> { fn listen(addr: SocketAddr) -> Result<Self, io::Error> {
Ok(Self::new(addr)) Ok(Self::new(addr))
} }
fn receive(&mut self, buffer: &mut [u8]) -> Result<(usize, SocketAddr), io::Error> {
fn receive(&mut self, buffer: &mut MsgBuffer) -> Result<SocketAddr, io::Error> {
if let Some((addr, data)) = self.inbound.pop_front() { if let Some((addr, data)) = self.inbound.pop_front() {
buffer[0..data.len()].copy_from_slice(&data); buffer.clear();
Ok((data.len(), addr)) buffer.set_length(data.len());
buffer.message_mut().copy_from_slice(&data);
Ok(addr)
} else { } else {
Err(io::Error::new(ErrorKind::Other, "nothing in queue")) Err(io::Error::new(ErrorKind::Other, "nothing in queue"))
} }
} }
fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> { fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result<usize, io::Error> {
self.outbound.push_back((addr, data.to_owned())); self.outbound.push_back((addr, data.to_owned()));
if self.nat { if self.nat {
@ -111,6 +128,7 @@ impl Socket for MockSocket {
} }
Ok(data.len()) Ok(data.len())
} }
fn address(&self) -> Result<SocketAddr, io::Error> { fn address(&self) -> Result<SocketAddr, io::Error> {
Ok(self.address) Ok(self.address)
} }

161
src/payload.rs Normal file
View File

@ -0,0 +1,161 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use crate::{error::Error, types::Address};
pub trait Protocol: Sized {
fn parse(_: &[u8]) -> Result<(Address, Address), Error>;
}
/// An ethernet frame dissector
///
/// This dissector is able to extract the source and destination addresses of ethernet frames.
///
/// If the ethernet frame contains a VLAN tag, both addresses will be prefixed with that tag,
/// resulting in 8-byte addresses. Additional nested tags will be ignored.
pub struct Frame;
impl Protocol for Frame {
/// Parses an ethernet frame and extracts the source and destination addresses
///
/// # Errors
/// This method will fail when the given data is not a valid ethernet frame.
fn parse(data: &[u8]) -> Result<(Address, Address), Error> {
if data.len() < 14 {
return Err(Error::Parse("Frame is too short"))
}
let mut pos = 0;
let dst_data = &data[pos..pos + 6];
pos += 6;
let src_data = &data[pos..pos + 6];
pos += 6;
if data[pos] == 0x81 && data[pos + 1] == 0x00 {
pos += 2;
if data.len() < pos + 2 {
return Err(Error::Parse("Vlan frame is too short"))
}
let mut src = [0; 16];
let mut dst = [0; 16];
src[0] = data[pos];
src[1] = data[pos + 1];
dst[0] = data[pos];
dst[1] = data[pos + 1];
src[2..8].copy_from_slice(src_data);
dst[2..8].copy_from_slice(dst_data);
Ok((Address { data: src, len: 8 }, Address { data: dst, len: 8 }))
} else {
let src = Address::read_from_fixed(src_data, 6)?;
let dst = Address::read_from_fixed(dst_data, 6)?;
Ok((src, dst))
}
}
}
#[test]
fn decode_frame_without_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 });
assert_eq!(dst, Address { data: [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 });
}
#[test]
fn decode_frame_with_vlan() {
let data = [6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0, 4, 210, 1, 2, 3, 4, 5, 6, 7, 8];
let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address { data: [4, 210, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 });
assert_eq!(dst, Address { data: [4, 210, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 });
}
#[test]
fn decode_invalid_frame() {
assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8]).is_ok());
// truncated frame
assert!(Frame::parse(&[]).is_err());
// truncated vlan frame
assert!(Frame::parse(&[6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 0x81, 0x00]).is_err());
}
/// An IP packet dissector
///
/// This dissector is able to extract the source and destination ip addresses of ipv4 packets and
/// ipv6 packets.
#[allow(dead_code)]
pub struct Packet;
impl Protocol for Packet {
/// Parses an ip packet and extracts the source and destination addresses
///
/// # Errors
/// This method will fail when the given data is not a valid ipv4 and ipv6 packet.
fn parse(data: &[u8]) -> Result<(Address, Address), Error> {
if data.is_empty() {
return Err(Error::Parse("Empty header"))
}
let version = data[0] >> 4;
match version {
4 => {
if data.len() < 20 {
return Err(Error::Parse("Truncated IPv4 header"))
}
let src = Address::read_from_fixed(&data[12..], 4)?;
let dst = Address::read_from_fixed(&data[16..], 4)?;
Ok((src, dst))
}
6 => {
if data.len() < 40 {
return Err(Error::Parse("Truncated IPv6 header"))
}
let src = Address::read_from_fixed(&data[8..], 16)?;
let dst = Address::read_from_fixed(&data[24..], 16)?;
Ok((src, dst))
}
_ => Err(Error::Parse("Invalid IP protocol version"))
}
}
}
#[test]
fn decode_ipv4_packet() {
let data = [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2];
let (src, dst) = Packet::parse(&data).unwrap();
assert_eq!(src, Address { data: [192, 168, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 });
assert_eq!(dst, Address { data: [192, 168, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 });
}
#[test]
fn decode_ipv6_packet() {
let data = [
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2, 1
];
let (src, dst) = Packet::parse(&data).unwrap();
assert_eq!(src, Address { data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6], len: 16 });
assert_eq!(dst, Address { data: [0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5, 4, 3, 2, 1], len: 16 });
}
#[test]
fn decode_invalid_packet() {
assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1, 2]).is_ok());
assert!(Packet::parse(&[
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2, 1
])
.is_ok());
// no data
assert!(Packet::parse(&[]).is_err());
// wrong version
assert!(Packet::parse(&[0x20]).is_err());
// truncated ipv4
assert!(Packet::parse(&[0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 192, 168, 1, 1, 192, 168, 1]).is_err());
// truncated ipv6
assert!(Packet::parse(&[
0x60, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 6, 5,
4, 3, 2
])
.is_err());
}

134
src/table.rs Normal file
View File

@ -0,0 +1,134 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use fnv::FnvHasher;
use std::{
cmp::min, collections::HashMap, hash::BuildHasherDefault, io, io::Write, marker::PhantomData, net::SocketAddr
};
use crate::{
types::{Address, Range, RangeList},
util::{Duration, Time, TimeSource, addr_nice}
};
type Hash = BuildHasherDefault<FnvHasher>;
struct CacheValue {
peer: SocketAddr,
timeout: Time
}
struct ClaimEntry {
peer: SocketAddr,
claim: Range,
timeout: Time
}
pub struct ClaimTable<TS: TimeSource> {
cache: HashMap<Address, CacheValue, Hash>,
cache_timeout: Duration,
claims: Vec<ClaimEntry>,
claim_timeout: Duration,
_dummy: PhantomData<TS>
}
impl<TS: TimeSource> ClaimTable<TS> {
pub fn new(cache_timeout: Duration, claim_timeout: Duration) -> Self {
Self { cache: HashMap::default(), cache_timeout, claims: vec![], claim_timeout, _dummy: PhantomData }
}
pub fn cache(&mut self, addr: Address, peer: SocketAddr) {
self.cache.insert(addr, CacheValue { peer, timeout: TS::now() + self.cache_timeout as Time });
}
pub fn set_claims(&mut self, peer: SocketAddr, mut claims: RangeList) {
for entry in &mut self.claims {
if entry.peer == peer {
let pos = claims.iter().position(|r| r == &entry.claim);
if let Some(pos) = pos {
entry.timeout = TS::now() + self.claim_timeout as Time;
claims.swap_remove(pos);
if claims.is_empty() {
break
}
} else {
entry.timeout = 0
}
}
}
for claim in claims {
self.claims.push(ClaimEntry { peer, claim, timeout: TS::now() + self.claim_timeout as Time })
}
self.housekeep()
}
pub fn remove_claims(&mut self, peer: SocketAddr) {
for entry in &mut self.claims {
if entry.peer == peer {
entry.timeout = 0
}
}
self.housekeep()
}
pub fn lookup(&mut self, addr: Address) -> Option<SocketAddr> {
if let Some(entry) = self.cache.get(&addr) {
return Some(entry.peer)
}
for entry in &self.claims {
if entry.claim.matches(addr) {
self.cache.insert(addr, CacheValue {
peer: entry.peer,
timeout: min(TS::now() + self.cache_timeout as Time, entry.timeout)
});
return Some(entry.peer)
}
}
None
}
pub fn housekeep(&mut self) {
let now = TS::now();
// TODO: also remove cache when removing claims
self.cache.retain(|_, v| v.timeout >= now);
self.claims.retain(|e| e.timeout >= now);
}
pub fn cache_len(&self) -> usize {
self.cache.len()
}
pub fn claim_len(&self) -> usize {
self.claims.len()
}
/// Write out the table
pub fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error> {
let now = TS::now();
writeln!(out, "forwarding_table:")?;
writeln!(out, " claims:")?;
for entry in &self.claims {
writeln!(
out,
" - \"{}\": {{ peer: \"{}\", timeout: {} }}",
entry.claim,
addr_nice(entry.peer),
entry.timeout - now
)?;
}
writeln!(out, " cache:")?;
for (addr, entry) in &self.cache {
writeln!(
out,
" - \"{}\": {{ peer: \"{}\", timeout: {} }}",
addr,
addr_nice(entry.peer),
entry.timeout - now
)?;
}
Ok(())
}
}

View File

@ -1,70 +0,0 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
macro_rules! assert_clean {
($($node: expr),*) => {
$(
assert_eq!($node.socket().pop_outbound().map(|(addr, mut msg)| (addr, $node.decode_message(&mut msg).unwrap().without_data())), None);
assert_eq!($node.device().pop_outbound(), None);
)*
};
}
macro_rules! assert_message4 {
($from: expr, $from_addr: expr, $to: expr, $to_addr: expr, $message: expr) => {
let (addr, mut data) = msg_get(&mut $from);
assert_eq!($to_addr, addr);
{
let message = $from.decode_message(&mut data).unwrap();
assert_eq!($message, message.without_data());
}
msg_put(&mut $to, $from_addr, data);
};
}
#[allow(unused_macros)]
macro_rules! assert_message6 {
($from: expr, $from_addr: expr, $to: expr, $to_addr: expr, $message: expr) => {
let (addr, mut data) = msg6_get(&mut $from);
assert_eq!($to_addr, addr);
{
let message = $from.decode_message(&mut data).unwrap();
assert_eq!($message, message.without_data());
}
msg6_put(&mut $to, $from_addr, data);
};
}
macro_rules! simulate {
($($node: expr => $addr: expr),*) => {
simulate(&mut [$((&mut $node, $addr)),*]);
};
}
macro_rules! simulate_time {
($time:expr, $($node: expr => $addr: expr),*) => {
for _ in 0..$time {
use crate::util::{MockTimeSource, TimeSource};
MockTimeSource::set_time(MockTimeSource::now()+1);
$(
$node.trigger_housekeep();
)*
simulate(&mut [$((&mut $node, $addr)),*]);
}
};
}
macro_rules! assert_connected {
($($node:expr),*) => {
for node1 in [$(&$node),*].iter() {
for node2 in [$(&$node),*].iter() {
if node1.node_id() == node2.node_id() {
continue
}
assert!(node1.peers().contains_node(&node2.node_id()));
}
}
};
}

View File

@ -2,13 +2,12 @@
// Copyright (C) 2015-2020 Dennis Schwerdel // Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
#[macro_use]
mod helper;
mod nat; mod nat;
mod payload; mod payload;
mod peers; mod peers;
use std::{ use std::{
collections::{HashMap, VecDeque},
io::Write, io::Write,
net::{IpAddr, Ipv6Addr, SocketAddr}, net::{IpAddr, Ipv6Addr, SocketAddr},
sync::{ sync::{
@ -19,15 +18,12 @@ use std::{
pub use super::{ pub use super::{
cloud::GenericCloud, cloud::GenericCloud,
config::Config, config::{Config, CryptoConfig},
crypto::Crypto, device::{MockDevice, Type},
device::MockDevice,
ethernet::{self, SwitchTable},
ip::{self, RoutingTable},
net::MockSocket, net::MockSocket,
types::{Protocol, Range, Table}, payload::{Frame, Packet, Protocol},
udpmessage::Message, types::Range,
util::MockTimeSource util::{MockTimeSource, Time, TimeSource}
}; };
@ -40,8 +36,16 @@ pub fn init_debug_logger() {
}) })
} }
static CURRENT_NODE: AtomicUsize = AtomicUsize::new(0);
struct DebugLogger; struct DebugLogger;
impl DebugLogger {
pub fn set_node(node: usize) {
CURRENT_NODE.store(node, Ordering::SeqCst);
}
}
impl log::Log for DebugLogger { impl log::Log for DebugLogger {
#[inline] #[inline]
fn enabled(&self, _metadata: &log::Metadata) -> bool { fn enabled(&self, _metadata: &log::Metadata) -> bool {
@ -51,7 +55,7 @@ impl log::Log for DebugLogger {
#[inline] #[inline]
fn log(&self, record: &log::Record) { fn log(&self, record: &log::Record) {
if self.enabled(record.metadata()) { if self.enabled(record.metadata()) {
eprintln!("{} - {}", record.level(), record.args()); eprintln!("Node {} - {} - {}", CURRENT_NODE.load(Ordering::SeqCst), record.level(), record.args());
} }
} }
@ -62,83 +66,152 @@ impl log::Log for DebugLogger {
} }
type TestNode<P, T> = GenericCloud<MockDevice, P, T, MockSocket, MockTimeSource>; type TestNode<P> = GenericCloud<MockDevice, P, MockSocket, MockTimeSource>;
type TapTestNode = TestNode<ethernet::Frame, SwitchTable<MockTimeSource>>; pub struct Simulator<P: Protocol> {
next_port: u16,
nodes: HashMap<SocketAddr, TestNode<P>>,
messages: VecDeque<(SocketAddr, SocketAddr, Vec<u8>)>
}
pub type TapSimulator = Simulator<Frame>;
#[allow(dead_code)] #[allow(dead_code)]
type TunTestNode = TestNode<ip::Packet, RoutingTable>; pub type TunSimulator = Simulator<Packet>;
thread_local! { impl<P: Protocol> Simulator<P> {
static NEXT_PORT: AtomicUsize = AtomicUsize::new(1); pub fn new() -> Self {
} init_debug_logger();
MockTimeSource::set_time(0);
fn next_sock_addr() -> SocketAddr { Self { next_port: 1, nodes: HashMap::default(), messages: VecDeque::default() }
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), NEXT_PORT.with(|p| p.fetch_add(1, Ordering::Relaxed)) as u16)
}
fn create_tap_node(nat: bool) -> TapTestNode {
create_tap_node_with_config(nat, Config::default())
}
fn create_tap_node_with_config(nat: bool, mut config: Config) -> TapTestNode {
MockSocket::set_nat(nat);
config.listen = next_sock_addr();
TestNode::new(&config, MockDevice::new(), SwitchTable::new(1800, 10), true, true, vec![], Crypto::None, None, None)
}
#[allow(dead_code)]
fn create_tun_node(nat: bool, addresses: Vec<Range>) -> TunTestNode {
MockSocket::set_nat(nat);
TestNode::new(
&Config { listen: next_sock_addr(), ..Config::default() },
MockDevice::new(),
RoutingTable::new(),
false,
false,
addresses,
Crypto::None,
None,
None
)
}
fn msg_get<P: Protocol, T: Table>(node: &mut TestNode<P, T>) -> (SocketAddr, Vec<u8>) {
let msg = node.socket().pop_outbound();
assert!(msg.is_some());
msg.unwrap()
}
fn msg_put<P: Protocol, T: Table>(node: &mut TestNode<P, T>, from: SocketAddr, msg: Vec<u8>) {
if node.socket().put_inbound(from, msg) {
node.trigger_socket_event();
} }
}
fn simulate<P: Protocol, T: Table>(nodes: &mut [(&mut TestNode<P, T>, SocketAddr)]) { pub fn add_node(&mut self, nat: bool, config: &Config) -> SocketAddr {
for (ref mut node, ref _from_addr) in nodes.iter_mut() { let mut config = config.clone();
while node.device().has_inbound() { MockSocket::set_nat(nat);
node.trigger_device_event(); config.listen = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.next_port);
if config.crypto.password.is_none() && config.crypto.private_key.is_none() {
config.crypto.password = Some("test123".to_string())
} }
DebugLogger::set_node(self.next_port as usize);
self.next_port += 1;
let node = TestNode::new(&config, MockDevice::new(), None, None);
DebugLogger::set_node(0);
self.nodes.insert(config.listen, node);
config.listen
} }
let mut clean = false;
while !clean { #[allow(dead_code)]
clean = true; pub fn get_node(&mut self, addr: SocketAddr) -> &mut TestNode<P> {
let mut msgs = Vec::new(); let node = self.nodes.get_mut(&addr).unwrap();
for (ref mut node, ref from_addr) in nodes.iter_mut() { DebugLogger::set_node(node.get_num());
while let Some((to_addr, msg)) = node.socket().pop_outbound() { node
msgs.push((msg, *from_addr, to_addr)); }
}
} pub fn simulate_next_message(&mut self) {
clean &= msgs.is_empty(); if let Some((src, dst, data)) = self.messages.pop_front() {
for (msg, from_addr, to_addr) in msgs { if let Some(node) = self.nodes.get_mut(&dst) {
for (ref mut node, ref addr) in nodes.iter_mut() { if node.socket().put_inbound(src, data) {
if *addr == to_addr { DebugLogger::set_node(node.get_num());
msg_put(node, from_addr, msg); node.trigger_socket_event();
break DebugLogger::set_node(0);
let sock = node.socket();
let src = dst;
while let Some((dst, data)) = sock.pop_outbound() {
self.messages.push_back((src, dst, data));
}
} }
} else {
warn!("Message to unknown node {}", dst);
} }
} }
} }
pub fn simulate_all_messages(&mut self) {
while !self.messages.is_empty() {
self.simulate_next_message()
}
}
pub fn trigger_node_housekeep(&mut self, addr: SocketAddr) {
let node = self.nodes.get_mut(&addr).unwrap();
DebugLogger::set_node(node.get_num());
node.trigger_housekeep();
DebugLogger::set_node(0);
let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() {
self.messages.push_back((addr, dst, data));
}
}
pub fn trigger_housekeep(&mut self) {
for (src, node) in &mut self.nodes {
DebugLogger::set_node(node.get_num());
node.trigger_housekeep();
DebugLogger::set_node(0);
let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() {
self.messages.push_back((*src, dst, data));
}
}
}
pub fn set_time(&mut self, time: Time) {
MockTimeSource::set_time(time);
}
pub fn simulate_time(&mut self, time: Time) {
let mut t = MockTimeSource::now();
while t < time {
t += 1;
self.set_time(t);
self.trigger_housekeep();
self.simulate_all_messages();
}
}
pub fn connect(&mut self, src: SocketAddr, dst: SocketAddr) {
let node = self.nodes.get_mut(&src).unwrap();
DebugLogger::set_node(node.get_num());
node.connect(dst).unwrap();
DebugLogger::set_node(0);
let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() {
self.messages.push_back((src, dst, data));
}
}
pub fn is_connected(&self, src: SocketAddr, dst: SocketAddr) -> bool {
self.nodes.get(&src).unwrap().is_connected(&dst)
}
#[allow(dead_code)]
pub fn node_addresses(&self) -> Vec<SocketAddr> {
self.nodes.keys().copied().collect()
}
#[allow(dead_code)]
pub fn message_count(&self) -> usize {
self.messages.len()
}
pub fn put_payload(&mut self, addr: SocketAddr, data: Vec<u8>) {
let node = self.nodes.get_mut(&addr).unwrap();
node.device().put_inbound(data);
DebugLogger::set_node(node.get_num());
node.trigger_device_event();
DebugLogger::set_node(0);
let sock = node.socket();
while let Some((dst, data)) = sock.pop_outbound() {
self.messages.push_back((addr, dst, data));
}
}
pub fn pop_payload(&mut self, node: SocketAddr) -> Option<Vec<u8>> {
self.nodes.get_mut(&node).unwrap().device().pop_outbound()
}
pub fn drop_message(&mut self) {
self.messages.pop_front();
}
} }

View File

@ -6,82 +6,68 @@ use super::*;
#[test] #[test]
fn connect_nat_2_peers() { fn connect_nat_2_peers() {
init_debug_logger(); let config = Config { port_forwarding: false, ..Default::default() };
MockTimeSource::set_time(0); let mut sim = TapSimulator::new();
let mut node1 = create_tap_node(true); let node1 = sim.add_node(true, &config);
let node1_addr = addr!("1.2.3.4:5678"); let node2 = sim.add_node(true, &config);
let mut node2 = create_tap_node(false);
let node2_addr = addr!("2.3.4.5:6789");
node2.connect("1.2.3.4:5678").unwrap(); sim.connect(node1, node2);
sim.connect(node2, node1);
simulate!(node1 => node1_addr, node2 => node2_addr); sim.simulate_time(60);
assert!(!node1.peers().contains_node(&node2.node_id())); assert!(sim.is_connected(node1, node2));
assert!(!node2.peers().contains_node(&node1.node_id())); assert!(sim.is_connected(node2, node1));
node1.connect("2.3.4.5:6789").unwrap();
simulate!(node1 => node1_addr, node2 => node2_addr);
assert_connected!(node1, node2);
} }
#[test] #[test]
fn connect_nat_3_peers() { fn connect_nat_3_peers() {
init_debug_logger(); let config = Config::default();
MockTimeSource::set_time(0); let mut sim = TapSimulator::new();
let mut node1 = create_tap_node(true); let node1 = sim.add_node(true, &config);
let node1_addr = addr!("1.2.3.4:5678"); let node2 = sim.add_node(true, &config);
let mut node2 = create_tap_node(false); let node3 = sim.add_node(true, &config);
let node2_addr = addr!("2.3.4.5:6789");
let mut node3 = create_tap_node(false);
let node3_addr = addr!("3.4.5.6:7890");
node2.connect("1.2.3.4:5678").unwrap();
node3.connect("1.2.3.4:5678").unwrap();
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr);
assert!(!node1.peers().contains_node(&node2.node_id())); sim.connect(node1, node2);
assert!(!node2.peers().contains_node(&node1.node_id())); sim.connect(node2, node1);
assert!(!node3.peers().contains_node(&node1.node_id())); sim.connect(node1, node3);
assert!(!node3.peers().contains_node(&node2.node_id())); sim.connect(node3, node1);
assert!(!node1.peers().contains_node(&node3.node_id()));
assert!(!node2.peers().contains_node(&node3.node_id()));
node1.connect("3.4.5.6:7890").unwrap(); sim.simulate_time(300);
node2.connect("3.4.5.6:7890").unwrap(); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
simulate_time!(1000, node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); assert!(sim.is_connected(node1, node3));
assert!(sim.is_connected(node3, node1));
assert_connected!(node1, node3); assert!(sim.is_connected(node2, node3));
assert_connected!(node2, node3); assert!(sim.is_connected(node3, node2));
assert_connected!(node1, node2);
} }
#[test] #[test]
#[allow(clippy::cognitive_complexity)]
fn nat_keepalive() { fn nat_keepalive() {
init_debug_logger(); let config = Config::default();
MockTimeSource::set_time(0); let mut sim = TapSimulator::new();
let mut node1 = create_tap_node(true); let node1 = sim.add_node(true, &config);
let node1_addr = addr!("1.2.3.4:5678"); let node2 = sim.add_node(true, &config);
let mut node2 = create_tap_node(false); let node3 = sim.add_node(true, &config);
let node2_addr = addr!("2.3.4.5:6789");
let mut node3 = create_tap_node(false);
let node3_addr = addr!("3.4.5.6:7890");
node1.connect("3.4.5.6:7890").unwrap();
node2.connect("3.4.5.6:7890").unwrap();
simulate_time!(1000, node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); sim.connect(node1, node2);
sim.connect(node2, node1);
sim.connect(node1, node3);
sim.connect(node3, node1);
assert_connected!(node1, node3); sim.simulate_time(1000);
assert_connected!(node2, node3); assert!(sim.is_connected(node1, node2));
assert_connected!(node1, node2); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3));
assert!(sim.is_connected(node3, node1));
assert!(sim.is_connected(node2, node3));
assert!(sim.is_connected(node3, node2));
simulate_time!(10000, node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); sim.simulate_time(10000);
assert!(sim.is_connected(node1, node2));
assert_connected!(node1, node3); assert!(sim.is_connected(node2, node1));
assert_connected!(node2, node3); assert!(sim.is_connected(node1, node3));
assert_connected!(node1, node2); assert!(sim.is_connected(node3, node1));
assert!(sim.is_connected(node2, node3));
assert!(sim.is_connected(node3, node2));
} }

View File

@ -6,81 +6,80 @@ use super::*;
#[test] #[test]
fn ethernet_delivers() { fn ethernet_delivers() {
let mut node1 = create_tap_node(false); let config = Config { device_type: Type::Tap, ..Config::default() };
let node1_addr = addr!("1.2.3.4:5678"); let mut sim = TapSimulator::new();
let mut node2 = create_tap_node(false); let node1 = sim.add_node(false, &config);
let node2_addr = addr!("2.3.4.5:6789"); let node2 = sim.add_node(false, &config);
node1.connect("2.3.4.5:6789").unwrap(); sim.connect(node1, node2);
simulate!(node1 => node1_addr, node2 => node2_addr); sim.simulate_all_messages();
assert_connected!(node1, node2); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5]; let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5];
node1.device().put_inbound(payload.clone()); sim.put_payload(node1, payload.clone());
sim.simulate_all_messages();
simulate!(node1 => node1_addr, node2 => node2_addr); assert_eq!(Some(payload), sim.pop_payload(node2));
assert_eq!(Some(payload), node2.device().pop_outbound());
assert_clean!(node1, node2);
} }
#[test] #[test]
fn switch_learns() { fn switch_learns() {
let mut node1 = create_tap_node(false); let config = Config { device_type: Type::Tap, ..Config::default() };
let node1_addr = addr!("1.2.3.4:5678"); let mut sim = TapSimulator::new();
let mut node2 = create_tap_node(false); let node1 = sim.add_node(false, &config);
let node2_addr = addr!("2.3.4.5:6789"); let node2 = sim.add_node(false, &config);
let mut node3 = create_tap_node(false); let node3 = sim.add_node(false, &config);
let node3_addr = addr!("3.4.5.6:7890");
node1.connect("2.3.4.5:6789").unwrap(); sim.connect(node1, node2);
node1.connect("3.4.5.6:7890").unwrap(); sim.connect(node1, node3);
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); sim.connect(node2, node3);
assert_connected!(node1, node2, node3); sim.simulate_all_messages();
assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3));
assert!(sim.is_connected(node3, node1));
assert!(sim.is_connected(node2, node3));
assert!(sim.is_connected(node3, node2));
let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5]; let payload = vec![2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5];
// Nothing learnt so far, node1 broadcasts // Nothing learnt so far, node1 broadcasts
node1.device().put_inbound(payload.clone()); sim.put_payload(node1, payload.clone());
sim.simulate_all_messages();
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); assert_eq!(Some(payload.clone()), sim.pop_payload(node2));
assert_eq!(Some(payload), sim.pop_payload(node3));
assert_eq!(Some(&payload), node2.device().pop_outbound().as_ref());
assert_eq!(Some(&payload), node3.device().pop_outbound().as_ref());
let payload = vec![1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 5, 4, 3, 2, 1]; let payload = vec![1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 5, 4, 3, 2, 1];
// Node 2 learned the address by receiving it, does not broadcast // Node 2 learned the address by receiving it, does not broadcast
node2.device().put_inbound(payload.clone()); sim.put_payload(node2, payload.clone());
sim.simulate_all_messages();
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr); assert_eq!(Some(payload), sim.pop_payload(node1));
assert_eq!(None, sim.pop_payload(node3));
assert_eq!(Some(&payload), node1.device().pop_outbound().as_ref());
assert_clean!(node3);
assert_clean!(node1, node2, node3);
} }
#[test] #[test]
fn switch_honours_vlans() { fn switch_honours_vlans() {
// TODO // TODO Test
} }
#[test] #[test]
fn switch_forgets() { fn switch_forgets() {
// TODO // TODO Test
} }
#[test] #[test]
fn router_delivers() { fn router_delivers() {
// TODO // TODO Test
} }
#[test] #[test]
fn router_drops_unknown_dest() { fn router_drops_unknown_dest() {
// TODO // TODO Test
} }

View File

@ -5,209 +5,165 @@
use super::*; use super::*;
#[test] #[test]
#[allow(clippy::cognitive_complexity)] fn direct_connect() {
fn connect_v4() { let config = Config::default();
let mut node1 = create_tap_node(false); let mut sim = TapSimulator::new();
let node1_addr = addr!("1.2.3.4:5678"); let node1 = sim.add_node(false, &config);
let mut node2 = create_tap_node(false); let node2 = sim.add_node(false, &config);
let node2_addr = addr!("2.3.4.5:6789");
assert_clean!(node1, node2);
assert!(!node1.peers().contains_node(&node2.node_id()));
assert!(!node2.peers().contains_node(&node1.node_id()));
node1.connect("2.3.4.5:6789").unwrap(); sim.connect(node1, node2);
sim.simulate_all_messages();
// Node 1 -> Node 2: Init 0 assert!(sim.is_connected(node1, node2));
assert_message4!(node1, node1_addr, node2, node2_addr, Message::Init(0, node1.node_id(), vec![], 600)); assert!(sim.is_connected(node2, node1));
assert_clean!(node1);
assert!(node2.peers().contains_node(&node1.node_id()));
// Node 2 -> Node 1: Init 1 | Node 2 -> Node 1: Peers
assert_message4!(node2, node2_addr, node1, node1_addr, Message::Init(1, node2.node_id(), vec![], 600));
assert!(node1.peers().contains_node(&node2.node_id()));
assert_message4!(node2, node2_addr, node1, node1_addr, Message::Peers(vec![node1_addr]));
assert_clean!(node2);
// Node 1 -> Node 2: Peers | Node 1 -> Node 1: Init 0
assert_message4!(node1, node1_addr, node2, node2_addr, Message::Peers(vec![node2_addr]));
assert_message4!(node1, node1_addr, node1, node1_addr, Message::Init(0, node1.node_id(), vec![], 600));
assert!(node1.own_addresses().contains(&node1_addr));
assert_clean!(node1);
// Node 2 -> Node 2: Init 0
assert_message4!(node2, node2_addr, node2, node2_addr, Message::Init(0, node2.node_id(), vec![], 600));
assert_clean!(node2);
assert!(node2.own_addresses().contains(&node2_addr));
assert_connected!(node1, node2);
} }
#[test] #[test]
fn connect_v6() { fn direct_connect_unencrypted() {
let mut node1 = create_tap_node(false); let config = Config {
let node1_addr = addr!("[::1]:5678"); crypto: CryptoConfig { algorithms: vec!["plain".to_string()], ..CryptoConfig::default() },
let mut node2 = create_tap_node(false); ..Config::default()
let node2_addr = addr!("[::2]:6789"); };
let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config);
let node2 = sim.add_node(false, &config);
node1.connect("[::2]:6789").unwrap(); sim.connect(node1, node2);
sim.simulate_all_messages();
simulate!(node1 => node1_addr, node2 => node2_addr); assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
assert_connected!(node1, node2);
} }
#[test] #[test]
#[allow(clippy::cognitive_complexity)]
fn cross_connect() { fn cross_connect() {
let mut node1 = create_tap_node(false); let config = Config::default();
let node1_addr = addr!("1.1.1.1:1111"); let mut sim = TapSimulator::new();
let mut node2 = create_tap_node(false); let node1 = sim.add_node(false, &config);
let node2_addr = addr!("2.2.2.2:2222"); let node2 = sim.add_node(false, &config);
let mut node3 = create_tap_node(false); let node3 = sim.add_node(false, &config);
let node3_addr = addr!("3.3.3.3:3333");
let mut node4 = create_tap_node(false);
let node4_addr = addr!("4.4.4.4:4444");
node1.connect("2.2.2.2:2222").unwrap(); sim.connect(node1, node2);
node3.connect("4.4.4.4:4444").unwrap(); sim.connect(node1, node3);
sim.simulate_all_messages();
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr, node4 => node4_addr); sim.simulate_time(120);
assert_connected!(node1, node2); assert!(sim.is_connected(node1, node2));
assert_connected!(node3, node4); assert!(sim.is_connected(node2, node1));
assert!(sim.is_connected(node1, node3));
node1.connect("3.3.3.3:3333").unwrap(); assert!(sim.is_connected(node3, node1));
assert!(sim.is_connected(node2, node3));
simulate!(node1 => node1_addr, node2 => node2_addr, node3 => node3_addr, node4 => node4_addr); assert!(sim.is_connected(node3, node2));
// existing connections
assert_connected!(node1, node2);
assert_connected!(node3, node4);
// new connection
assert_connected!(node1, node3);
// transient connections 1st degree
assert_connected!(node1, node4);
assert_connected!(node3, node2);
// transient connections 2nd degree
assert_connected!(node2, node4);
} }
#[test] #[test]
fn connect_via_beacons() { fn connect_via_beacons() {
MockTimeSource::set_time(0); let mut sim = TapSimulator::new();
let beacon_path = "target/.vpncloud_test"; let beacon_path = "target/.vpncloud_test";
let mut node1 = let config1 = Config { beacon_store: Some(beacon_path.to_string()), ..Default::default() };
create_tap_node_with_config(false, Config { beacon_store: Some(beacon_path.to_string()), ..Config::default() }); let node1 = sim.add_node(false, &config1);
let node1_addr = node1.address().unwrap(); let config2 = Config { beacon_load: Some(beacon_path.to_string()), ..Default::default() };
let mut node2 = let node2 = sim.add_node(false, &config2);
create_tap_node_with_config(false, Config { beacon_load: Some(beacon_path.to_string()), ..Config::default() });
let node2_addr = addr!("2.2.2.2:2222");
assert!(!node1.peers().contains_node(&node2.node_id())); sim.set_time(100);
assert!(!node2.peers().contains_node(&node1.node_id())); sim.trigger_node_housekeep(node1);
sim.trigger_node_housekeep(node2);
sim.simulate_all_messages();
MockTimeSource::set_time(5000); assert!(sim.is_connected(node1, node2));
node1.trigger_housekeep(); assert!(sim.is_connected(node2, node1));
MockTimeSource::set_time(10000);
node2.trigger_housekeep();
simulate!(node1 => node1_addr, node2 => node2_addr);
assert_clean!(node1, node2);
assert_connected!(node1, node2);
} }
#[test] #[test]
fn reconnect_after_timeout() { fn reconnect_after_timeout() {
MockTimeSource::set_time(0); let config = Config::default();
let mut node1 = create_tap_node(false); let mut sim = TapSimulator::new();
let node1_addr = addr!("1.1.1.1:1111"); let node1 = sim.add_node(false, &config);
let mut node2 = create_tap_node(false); let node2 = sim.add_node(false, &config);
let node2_addr = addr!("2.2.2.2:2222");
node1.add_reconnect_peer("2.2.2.2:2222".to_string()); sim.connect(node1, node2);
node1.connect(node2_addr).unwrap(); sim.simulate_all_messages();
assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
simulate!(node1 => node1_addr, node2 => node2_addr); sim.set_time(5000);
sim.trigger_housekeep();
assert!(!sim.is_connected(node1, node2));
assert!(!sim.is_connected(node2, node1));
assert_connected!(node1, node2); sim.simulate_all_messages();
assert!(sim.is_connected(node1, node2));
MockTimeSource::set_time(5000); assert!(sim.is_connected(node2, node1));
node1.trigger_housekeep();
node2.trigger_housekeep();
assert!(!node1.peers().contains_node(&node2.node_id()));
assert!(!node2.peers().contains_node(&node1.node_id()));
simulate!(node1 => node1_addr, node2 => node2_addr);
assert_connected!(node1, node2);
} }
#[test] #[test]
fn lost_init1() { fn lost_init_ping() {
let mut node1 = create_tap_node(false); let config = Config::default();
let node1_addr = addr!("1.2.3.4:5678"); let mut sim = TapSimulator::new();
let mut node2 = create_tap_node(false); let node1 = sim.add_node(false, &config);
let node2_addr = addr!("2.3.4.5:6789"); let node2 = sim.add_node(false, &config);
node1.connect("2.3.4.5:6789").unwrap(); sim.connect(node1, node2);
sim.drop_message(); // drop init ping
// Node 1 -> Node 2: Init 0 sim.simulate_time(120);
assert_message4!(node1, node1_addr, node2, node2_addr, Message::Init(0, node1.node_id(), vec![], 600)); assert!(sim.is_connected(node1, node2));
assert_clean!(node1); assert!(sim.is_connected(node2, node1));
// Node 2 -> Node 1: Init 1 | Node 2 -> Node 1: Peers
assert!(node2.socket().pop_outbound().is_some());
assert!(!node1.peers().contains_node(&node2.node_id()));
simulate!(node1 => node1_addr, node2 => node2_addr);
assert_connected!(node1, node2);
} }
#[test] #[test]
fn wrong_magic() { fn lost_init_pong() {
let mut node1 = create_tap_node(false); let config = Config::default();
let node1_addr = addr!("1.2.3.4:5678"); let mut sim = TapSimulator::new();
let mut node2 = let node1 = sim.add_node(false, &config);
create_tap_node_with_config(false, Config { magic: Some("hash:different".to_string()), ..Config::default() }); let node2 = sim.add_node(false, &config);
let node2_addr = addr!("2.3.4.5:6789");
node1.connect("2.3.4.5:6789").unwrap();
assert_message4!(node1, node1_addr, node2, node2_addr, Message::Init(0, node1.node_id(), vec![], 600)); sim.connect(node1, node2);
sim.simulate_next_message(); // init ping
sim.drop_message(); // drop init pong
assert_clean!(node1, node2); sim.simulate_time(120);
assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
}
assert!(!node1.peers().contains_node(&node2.node_id())); #[test]
assert!(!node2.peers().contains_node(&node1.node_id())); fn lost_init_peng() {
let config = Config::default();
let mut sim = TapSimulator::new();
let node1 = sim.add_node(false, &config);
let node2 = sim.add_node(false, &config);
sim.connect(node1, node2);
sim.simulate_next_message(); // init ping
sim.simulate_next_message(); // init pong
sim.drop_message(); // drop init peng
sim.simulate_time(120);
assert!(sim.is_connected(node1, node2));
assert!(sim.is_connected(node2, node1));
} }
#[test] #[test]
fn peer_exchange() { fn peer_exchange() {
// TODO // TODO Test
} }
#[test] #[test]
fn lost_peer_exchange() { fn lost_peer_exchange() {
// TODO // TODO Test
} }
#[test] #[test]
fn remove_dead_peers() { fn remove_dead_peers() {
// TODO // TODO Test
} }
#[test] #[test]
fn update_primary_address() { fn update_primary_address() {
// TODO // TODO Test
} }
#[test] #[test]
fn automatic_peer_timeout() { fn automatic_peer_timeout() {
// TODO // TODO Test
} }

View File

@ -2,19 +2,22 @@
// Copyright (C) 2015-2020 Dennis Schwerdel // Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md) // This software is licensed under GPL-3 or newer (see LICENSE.md)
use crate::{
error::Error,
util::{bytes_to_hex, Encoder}
};
use byteorder::{ReadBytesExt, WriteBytesExt};
use smallvec::SmallVec;
use std::{ use std::{
fmt, fmt,
hash::{Hash, Hasher}, hash::{Hash, Hasher},
io::{self, Write}, io::{Read, Write},
net::{Ipv4Addr, Ipv6Addr, SocketAddr}, net::{Ipv4Addr, Ipv6Addr},
str::FromStr str::FromStr
}; };
use super::util::{bytes_to_hex, Encoder};
pub const NODE_ID_BYTES: usize = 16; pub const NODE_ID_BYTES: usize = 16;
pub type HeaderMagic = [u8; 4];
pub type NodeId = [u8; NODE_ID_BYTES]; pub type NodeId = [u8; NODE_ID_BYTES];
@ -26,35 +29,31 @@ pub struct Address {
impl Address { impl Address {
#[inline] #[inline]
pub fn read_from(data: &[u8]) -> Result<(Address, usize), Error> { pub fn read_from<R: Read>(mut r: R) -> Result<Address, Error> {
if data.is_empty() { let len = r.read_u8().map_err(|_| Error::Parse("Address too short"))?;
return Err(Error::Parse("Address too short")) Address::read_from_fixed(r, len)
}
let len = data[0] as usize;
let addr = Address::read_from_fixed(&data[1..], len)?;
Ok((addr, len + 1))
} }
#[inline] #[inline]
pub fn read_from_fixed(data: &[u8], len: usize) -> Result<Address, Error> { pub fn read_from_fixed<R: Read>(mut r: R, len: u8) -> Result<Address, Error> {
if len > 16 { if len > 16 {
return Err(Error::Parse("Invalid address, too long")) return Err(Error::Parse("Invalid address, too long"))
} }
if data.len() < len { let mut data = [0; 16];
return Err(Error::Parse("Address too short")) r.read_exact(&mut data[..len as usize]).map_err(|_| Error::Parse("Address too short"))?;
} Ok(Address { data, len })
let mut bytes = [0; 16];
bytes[0..len].copy_from_slice(&data[0..len]);
Ok(Address { data: bytes, len: len as u8 })
} }
#[inline] #[inline]
pub fn write_to(&self, data: &mut [u8]) -> usize { pub fn write_to<W: Write>(&self, mut w: W) {
assert!(data.len() > self.len as usize); w.write_u8(self.len).expect("Buffer too small");
data[0] = self.len; w.write_all(&self.data[..self.len as usize]).expect("Buffer too small");
let len = self.len as usize; }
data[1..=len].copy_from_slice(&self.data[0..len]);
self.len as usize + 1 pub fn from_ipv4(ip: Ipv4Addr) -> Self {
let mut data = [0; 16];
data[0..4].copy_from_slice(&ip.octets());
Self { data, len: 4 }
} }
} }
@ -70,7 +69,7 @@ impl PartialEq for Address {
impl Hash for Address { impl Hash for Address {
#[inline] #[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) { fn hash<H: Hasher>(&self, hasher: &mut H) {
hasher.write(&self.data[0..self.len as usize]) hasher.write(&self.data[..self.len as usize])
} }
} }
@ -89,7 +88,7 @@ impl fmt::Display for Address {
}, },
16 => write!(formatter, "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}", 16 => write!(formatter, "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15]), d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15]),
_ => write!(formatter, "{}", bytes_to_hex(&d[0..self.len as usize])) _ => write!(formatter, "{}", bytes_to_hex(&d[..self.len as usize]))
} }
} }
} }
@ -138,23 +137,35 @@ pub struct Range {
pub prefix_len: u8 pub prefix_len: u8
} }
pub type RangeList = SmallVec<[Range; 4]>;
impl Range { impl Range {
#[inline] pub fn matches(&self, addr: Address) -> bool {
pub fn read_from(data: &[u8]) -> Result<(Range, usize), Error> { if self.base.len != addr.len {
let (address, read) = Address::read_from(data)?; return false
if data.len() < read + 1 {
return Err(Error::Parse("Range too short"))
} }
let prefix_len = data[read]; let mut match_len = 0;
Ok((Range { base: address, prefix_len }, read + 1)) for i in 0..addr.len as usize {
let m = addr.data[i] ^ self.base.data[i];
match_len += m.leading_zeros() as u8;
if m != 0 {
break
}
}
match_len >= self.prefix_len
} }
#[inline] #[inline]
pub fn write_to(&self, data: &mut [u8]) -> usize { pub fn read_from<R: Read>(mut r: R) -> Result<Range, Error> {
let pos = self.base.write_to(data); let base = Address::read_from(&mut r)?;
assert!(data.len() > pos); let prefix_len = r.read_u8().map_err(|_| Error::Parse("Address too short"))?;
data[pos] = self.prefix_len; Ok(Range { base, prefix_len })
pos + 1 }
#[inline]
pub fn write_to<W: Write>(&self, mut w: W) {
self.base.write_to(&mut w);
w.write_u8(self.prefix_len).expect("Buffer too small")
} }
} }
@ -220,115 +231,81 @@ impl FromStr for Mode {
} }
} }
pub trait Table {
fn learn(&mut self, _: Address, _: Option<u8>, _: NodeId, _: SocketAddr); #[cfg(test)]
fn lookup(&mut self, _: &Address) -> Option<SocketAddr>; mod tests {
fn housekeep(&mut self);
fn write_out<W: Write>(&self, out: &mut W) -> Result<(), io::Error>; use super::*;
fn remove(&mut self, _: &Address) -> bool;
fn remove_all(&mut self, _: &SocketAddr); use std::io::Cursor;
fn len(&self) -> usize;
fn is_empty(&self) -> bool { #[test]
self.len() == 0 fn address_parse_fmt() {
assert_eq!(format!("{}", Address::from_str("120.45.22.5").unwrap()), "120.45.22.5");
assert_eq!(format!("{}", Address::from_str("78:2d:16:05:01:02").unwrap()), "78:2d:16:05:01:02");
assert_eq!(
format!("{}", Address { data: [3, 56, 120, 45, 22, 5, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 }),
"vlan824/78:2d:16:05:01:02"
);
assert_eq!(
format!("{}", Address::from_str("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f").unwrap()),
"0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"
);
assert_eq!(format!("{:?}", Address { data: [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 2 }), "0102");
assert!(Address::from_str("").is_err()); // Failed to parse address
}
#[test]
fn address_decode_encode() {
let mut buf = vec![];
let addr = Address::from_str("120.45.22.5").unwrap();
addr.write_to(Cursor::new(&mut buf));
assert_eq!(&buf[0..5], &[4, 120, 45, 22, 5]);
assert_eq!(addr, Address::read_from(Cursor::new(&buf)).unwrap());
assert_eq!(addr, Address::read_from_fixed(Cursor::new(&buf[1..]), 4).unwrap());
buf.clear();
let addr = Address::from_str("78:2d:16:05:01:02").unwrap();
addr.write_to(Cursor::new(&mut buf));
assert_eq!(&buf[0..7], &[6, 0x78, 0x2d, 0x16, 0x05, 0x01, 0x02]);
assert_eq!(addr, Address::read_from(Cursor::new(&buf)).unwrap());
assert_eq!(addr, Address::read_from_fixed(Cursor::new(&buf[1..]), 6).unwrap());
assert!(Address::read_from(Cursor::new(&buf[0..1])).is_err()); // Address too short
buf[0] = 100;
assert!(Address::read_from(Cursor::new(&buf)).is_err()); // Invalid address, too long
buf[0] = 5;
assert!(Address::read_from(Cursor::new(&buf[0..4])).is_err()); // Address too short
}
#[test]
fn address_eq() {
assert_eq!(
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap(),
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap()
);
assert_ne!(
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap(),
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 5]), 4).unwrap()
);
assert_eq!(
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 3).unwrap(),
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 5]), 3).unwrap()
);
assert_ne!(
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 3).unwrap(),
Address::read_from_fixed(Cursor::new(&[1, 2, 3, 4]), 4).unwrap()
);
}
#[test]
fn address_range_decode_encode() {
let mut buf = vec![];
let range =
Range { base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }, prefix_len: 24 };
range.write_to(Cursor::new(&mut buf));
assert_eq!(&buf[0..6], &[4, 0, 1, 2, 3, 24]);
assert_eq!(range, Range::read_from(Cursor::new(&buf)).unwrap());
assert!(Range::read_from(Cursor::new(&buf[..5])).is_err()); // Missing prefix length
buf[0] = 17;
assert!(Range::read_from(Cursor::new(&buf)).is_err());
} }
} }
pub trait Protocol: Sized {
fn parse(_: &[u8]) -> Result<(Address, Address), Error>;
}
#[derive(Debug)]
pub enum Error {
Parse(&'static str),
WrongHeaderMagic(HeaderMagic),
Socket(&'static str, io::Error),
Name(String),
TunTapDev(&'static str, io::Error),
Crypto(&'static str),
File(&'static str, io::Error),
Beacon(&'static str, io::Error)
}
impl fmt::Display for Error {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match *self {
Error::Parse(msg) => write!(formatter, "{}", msg),
Error::Socket(msg, ref err) => write!(formatter, "{}: {:?}", msg, err),
Error::TunTapDev(msg, ref err) => write!(formatter, "{}: {:?}", msg, err),
Error::Crypto(msg) => write!(formatter, "{}", msg),
Error::Name(ref name) => write!(formatter, "failed to resolve name '{}'", name),
Error::WrongHeaderMagic(net) => write!(formatter, "wrong header magic: {}", bytes_to_hex(&net)),
Error::File(msg, ref err) => write!(formatter, "{}: {:?}", msg, err),
Error::Beacon(msg, ref err) => write!(formatter, "{}: {:?}", msg, err)
}
}
}
#[test]
fn address_parse_fmt() {
assert_eq!(format!("{}", Address::from_str("120.45.22.5").unwrap()), "120.45.22.5");
assert_eq!(format!("{}", Address::from_str("78:2d:16:05:01:02").unwrap()), "78:2d:16:05:01:02");
assert_eq!(
format!("{}", Address { data: [3, 56, 120, 45, 22, 5, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0], len: 8 }),
"vlan824/78:2d:16:05:01:02"
);
assert_eq!(
format!("{}", Address::from_str("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f").unwrap()),
"0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"
);
assert_eq!(format!("{:?}", Address { data: [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 2 }), "0102");
assert!(Address::from_str("").is_err()); // Failed to parse address
}
#[test]
fn address_decode_encode() {
let mut buf = [0; 32];
let addr = Address::from_str("120.45.22.5").unwrap();
assert_eq!(addr.write_to(&mut buf), 5);
assert_eq!(&buf[0..5], &[4, 120, 45, 22, 5]);
assert_eq!((addr, 5), Address::read_from(&buf).unwrap());
assert_eq!(addr, Address::read_from_fixed(&buf[1..], 4).unwrap());
let addr = Address::from_str("78:2d:16:05:01:02").unwrap();
assert_eq!(addr.write_to(&mut buf), 7);
assert_eq!(&buf[0..7], &[6, 0x78, 0x2d, 0x16, 0x05, 0x01, 0x02]);
assert_eq!((addr, 7), Address::read_from(&buf).unwrap());
assert_eq!(addr, Address::read_from_fixed(&buf[1..], 6).unwrap());
assert!(Address::read_from(&buf[0..1]).is_err()); // Address too short
buf[0] = 100;
assert!(Address::read_from(&buf).is_err()); // Invalid address, too long
buf[0] = 5;
assert!(Address::read_from(&buf[0..4]).is_err()); // Address too short
}
#[test]
fn address_eq() {
assert_eq!(
Address::read_from_fixed(&[1, 2, 3, 4], 4).unwrap(),
Address::read_from_fixed(&[1, 2, 3, 4], 4).unwrap()
);
assert_ne!(
Address::read_from_fixed(&[1, 2, 3, 4], 4).unwrap(),
Address::read_from_fixed(&[1, 2, 3, 5], 4).unwrap()
);
assert_eq!(
Address::read_from_fixed(&[1, 2, 3, 4], 3).unwrap(),
Address::read_from_fixed(&[1, 2, 3, 5], 3).unwrap()
);
assert_ne!(
Address::read_from_fixed(&[1, 2, 3, 4], 3).unwrap(),
Address::read_from_fixed(&[1, 2, 3, 4], 4).unwrap()
);
}
#[test]
fn address_range_decode_encode() {
let mut buf = [0; 32];
let range =
Range { base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }, prefix_len: 24 };
assert_eq!(range.write_to(&mut buf), 6);
assert_eq!(&buf[0..6], &[4, 0, 1, 2, 3, 24]);
assert_eq!((range, 6), Range::read_from(&buf).unwrap());
assert!(Range::read_from(&buf[..5]).is_err()); // Missing prefix length
buf[0] = 17;
assert!(Range::read_from(&buf).is_err());
}

View File

@ -1,516 +0,0 @@
// VpnCloud - Peer-to-Peer VPN
// Copyright (C) 2015-2020 Dennis Schwerdel
// This software is licensed under GPL-3 or newer (see LICENSE.md)
use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}
};
use super::{
config::DEFAULT_PEER_TIMEOUT,
crypto::Crypto,
types::{Error, HeaderMagic, NodeId, Range, NODE_ID_BYTES},
util::{bytes_to_hex, Encoder}
};
#[derive(Clone, Copy, Default)]
#[repr(packed)]
struct TopHeader {
magic: HeaderMagic,
crypto_method: u8,
_reserved1: u8,
_reserved2: u8,
msgtype: u8
}
impl TopHeader {
#[inline]
pub fn size() -> usize {
8
}
pub fn read_from(data: &[u8]) -> Result<(TopHeader, usize), Error> {
if data.len() < TopHeader::size() {
return Err(Error::Parse("Empty message"))
}
let mut header = TopHeader::default();
header.magic.copy_from_slice(&data[0..4]);
header.crypto_method = data[4];
header.msgtype = data[7];
Ok((header, TopHeader::size()))
}
#[allow(unknown_lints, clippy::trivially_copy_pass_by_ref)]
pub fn write_to(&self, data: &mut [u8]) -> usize {
assert!(data.len() >= 8);
data[0..4].copy_from_slice(&self.magic);
data[4] = self.crypto_method;
data[5] = 0;
data[6] = 0;
data[7] = self.msgtype;
TopHeader::size()
}
}
pub enum Message<'a> {
Data(&'a mut [u8], usize, usize), // data, start, end
Peers(Vec<SocketAddr>), // peers
Init(u8, NodeId, Vec<Range>, u16), // step, node_id, ranges
Close
}
impl<'a> Message<'a> {
pub fn without_data(self) -> Message<'static> {
match self {
Message::Data(_, start, end) => Message::Data(&mut [], start, end),
Message::Peers(peers) => Message::Peers(peers),
Message::Init(step, node_id, ranges, timeout) => Message::Init(step, node_id, ranges, timeout),
Message::Close => Message::Close
}
}
}
impl<'a> fmt::Debug for Message<'a> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match *self {
Message::Data(_, start, end) => write!(formatter, "Data({} bytes)", end - start),
Message::Peers(ref peers) => {
write!(formatter, "Peers [")?;
let mut first = true;
for p in peers {
if !first {
write!(formatter, ", ")?;
}
first = false;
write!(formatter, "{}", p)?;
}
write!(formatter, "]")
}
Message::Init(stage, ref node_id, ref peers, ref peer_timeout) => {
write!(
formatter,
"Init(stage={}, node_id={}, peer_timeout={}, {:?})",
stage,
bytes_to_hex(node_id),
peer_timeout,
peers
)
}
Message::Close => write!(formatter, "Close")
}
}
}
#[allow(unknown_lints, clippy::needless_range_loop)]
pub fn decode<'a>(data: &'a mut [u8], magic: HeaderMagic, crypto: &Crypto) -> Result<Message<'a>, Error> {
let mut end = data.len();
let (header, mut pos) = TopHeader::read_from(&data[..end])?;
if header.magic != magic {
return Err(Error::WrongHeaderMagic(header.magic))
}
if header.crypto_method != crypto.method() {
return Err(Error::Crypto("Wrong crypto method"))
}
if crypto.method() > 0 {
let len = crypto.nonce_bytes();
if end < pos + len {
return Err(Error::Parse("Truncated crypto header"))
}
{
let (before, after) = data.split_at_mut(pos);
let (nonce, crypto_data) = after.split_at_mut(len);
pos += len;
end = crypto.decrypt(crypto_data, nonce, &before[..TopHeader::size()])? + pos;
}
assert_eq!(end, data.len() - crypto.additional_bytes());
}
let msg = match header.msgtype {
0 => Message::Data(data, pos, end),
1 => {
if end < pos + 1 {
return Err(Error::Parse("Missing IPv4 count"))
}
let mut peers = Vec::new();
let count = data[pos];
pos += 1;
let len = count as usize * 6;
if end < pos + len {
return Err(Error::Parse("IPv4 peer data too short"))
}
for _ in 0..count {
let ip = &data[pos..];
assert!(ip.len() >= 4);
pos += 4;
let port = Encoder::read_u16(&data[pos..]);
pos += 2;
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), port));
peers.push(addr);
}
if end < pos + 1 {
return Err(Error::Parse("Missing IPv6 count"))
}
let count = data[pos];
pos += 1;
let len = count as usize * 18;
if end < pos + len {
return Err(Error::Parse("IPv6 peer data too short"))
}
for _ in 0..count {
let mut ip = [0u16; 8];
for i in 0..8 {
ip[i] = Encoder::read_u16(&data[pos..]);
pos += 2;
}
let port = Encoder::read_u16(&data[pos..]);
pos += 2;
let addr = SocketAddr::V6(SocketAddrV6::new(
Ipv6Addr::new(ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7]),
port,
0,
0
));
peers.push(addr);
}
Message::Peers(peers)
}
2 => {
if end < pos + 2 + NODE_ID_BYTES {
return Err(Error::Parse("Init data too short"))
}
let stage = data[pos];
pos += 1;
let mut node_id = [0; NODE_ID_BYTES];
node_id.copy_from_slice(&data[pos..pos + NODE_ID_BYTES]);
pos += NODE_ID_BYTES;
let count = data[pos] as usize;
pos += 1;
let mut addrs = Vec::with_capacity(count);
for _ in 0..count {
let (range, read) = Range::read_from(&data[pos..end])?;
pos += read;
addrs.push(range);
}
let peer_timeout =
if data.len() >= pos + 2 { Encoder::read_u16(&data[pos..]) } else { DEFAULT_PEER_TIMEOUT };
Message::Init(stage, node_id, addrs, peer_timeout)
}
3 => Message::Close,
_ => return Err(Error::Parse("Unknown message type"))
};
Ok(msg)
}
#[allow(unknown_lints, clippy::needless_range_loop)]
pub fn encode<'a>(
msg: &'a mut Message, mut buf: &'a mut [u8], magic: HeaderMagic, crypto: &mut Crypto
) -> &'a mut [u8] {
let header_type = match msg {
Message::Data(_, _, _) => 0,
Message::Peers(_) => 1,
Message::Init(_, _, _, _) => 2,
Message::Close => 3
};
let mut start = 64;
let mut end = 64;
match *msg {
Message::Data(ref mut data, data_start, data_end) => {
buf = data;
start = data_start;
end = data_end;
}
Message::Peers(ref peers) => {
let mut v4addrs = Vec::new();
let mut v6addrs = Vec::new();
for p in peers {
match *p {
SocketAddr::V4(addr) => v4addrs.push(addr),
SocketAddr::V6(addr) => v6addrs.push(addr)
}
}
assert!(v4addrs.len() <= 255);
assert!(v6addrs.len() <= 255);
let mut pos = start;
assert!(buf.len() >= pos + 2 + v4addrs.len() * 6 + v6addrs.len() * 18);
buf[pos] = v4addrs.len() as u8;
pos += 1;
for addr in v4addrs {
let ip = addr.ip().octets();
buf[pos..pos + 4].copy_from_slice(&ip);
pos += 4;
Encoder::write_u16(addr.port(), &mut buf[pos..]);
pos += 2;
}
buf[pos] = v6addrs.len() as u8;
pos += 1;
for addr in v6addrs {
let ip = addr.ip().segments();
for i in 0..8 {
Encoder::write_u16(ip[i], &mut buf[pos..]);
pos += 2;
}
Encoder::write_u16(addr.port(), &mut buf[pos..]);
pos += 2;
}
end = pos;
}
Message::Init(stage, ref node_id, ref ranges, peer_timeout) => {
let mut pos = start;
assert!(buf.len() >= pos + 2 + NODE_ID_BYTES);
buf[pos] = stage;
pos += 1;
buf[pos..pos + NODE_ID_BYTES].copy_from_slice(node_id);
pos += NODE_ID_BYTES;
assert!(ranges.len() <= 255);
buf[pos] = ranges.len() as u8;
pos += 1;
for range in ranges {
pos += range.write_to(&mut buf[pos..]);
}
Encoder::write_u16(peer_timeout, &mut buf[pos..]);
pos += 2;
end = pos;
}
Message::Close => {}
}
assert!(start >= 64);
assert!(buf.len() >= end + 64);
let crypto_start = start;
start -= crypto.nonce_bytes();
let mut header = TopHeader::default();
header.magic = magic;
header.msgtype = header_type;
header.crypto_method = crypto.method();
start -= TopHeader::size();
header.write_to(&mut buf[start..]);
if crypto.method() > 0 {
let (junk_before, rest) = buf.split_at_mut(start);
let (header, rest) = rest.split_at_mut(TopHeader::size());
let (nonce, rest) = rest.split_at_mut(crypto.nonce_bytes());
debug_assert_eq!(junk_before.len() + header.len() + crypto.nonce_bytes(), crypto_start);
assert!(rest.len() >= end - crypto_start + crypto.additional_bytes());
end = crypto.encrypt(rest, end - crypto_start, nonce, header) + crypto_start;
}
&mut buf[start..end]
}
impl<'a> PartialEq for Message<'a> {
fn eq(&self, other: &Message) -> bool {
match *self {
Message::Data(ref data1, start1, end1) => {
if let Message::Data(ref data2, start2, end2) = *other {
data1[start1..end1] == data2[start2..end2]
} else {
false
}
}
Message::Peers(ref peers1) => {
if let Message::Peers(ref peers2) = *other {
peers1 == peers2
} else {
false
}
}
Message::Init(step1, node_id1, ref ranges1, peer_timeout1) => {
if let Message::Init(step2, node_id2, ref ranges2, peer_timeout2) = *other {
step1 == step2 && node_id1 == node_id2 && ranges1 == ranges2 && peer_timeout1 == peer_timeout2
} else {
false
}
}
Message::Close => {
if let Message::Close = *other {
true
} else {
false
}
}
}
}
}
#[cfg(test)] use super::crypto::CryptoMethod;
#[cfg(test)] use super::types::Address;
#[cfg(test)] use super::MAGIC;
#[cfg(test)] use std::str::FromStr;
#[test]
#[allow(unused_assignments)]
fn udpmessage_packet() {
let mut crypto = Crypto::None;
let mut payload = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
];
let mut msg = Message::Data(&mut payload, 64, 69);
let mut buf = [0; 1024];
let mut len = 0;
{
let res = encode(&mut msg, &mut [], MAGIC, &mut crypto);
assert_eq!(res.len(), 13);
assert_eq!(&res[..8], &[118, 112, 110, 1, 0, 0, 0, 0]);
buf[..res.len()].clone_from_slice(&res);
len = res.len();
}
let msg2 = decode(&mut buf[..len], MAGIC, &crypto).unwrap();
assert_eq!(msg, msg2);
}
#[test]
#[allow(unused_assignments)]
fn udpmessage_encrypted() {
let mut crypto = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test");
let mut payload = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
];
let mut orig_payload = [0; 133];
orig_payload[..payload.len()].clone_from_slice(&payload);
let orig_msg = Message::Data(&mut orig_payload, 64, 69);
let mut msg = Message::Data(&mut payload, 64, 69);
let mut buf = [0; 1024];
let mut len = 0;
{
let res = encode(&mut msg, &mut [], MAGIC, &mut crypto);
assert_eq!(res.len(), 41);
assert_eq!(&res[..8], &[118, 112, 110, 1, 1, 0, 0, 0]);
buf[..res.len()].clone_from_slice(&res);
len = res.len();
}
let msg2 = decode(&mut buf[..len], MAGIC, &crypto).unwrap();
assert_eq!(orig_msg, msg2);
}
#[test]
fn udpmessage_peers() {
use std::str::FromStr;
let mut crypto = Crypto::None;
let mut msg = Message::Peers(vec![
SocketAddr::from_str("1.2.3.4:123").unwrap(),
SocketAddr::from_str("5.6.7.8:12345").unwrap(),
SocketAddr::from_str("[0001:0203:0405:0607:0809:0a0b:0c0d:0e0f]:6789").unwrap(),
]);
let mut should = [
118, 112, 110, 1, 0, 0, 0, 1, 2, 1, 2, 3, 4, 0, 123, 5, 6, 7, 8, 48, 57, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 26, 133
];
{
let mut buf = [0; 1024];
let res = encode(&mut msg, &mut buf[..], MAGIC, &mut crypto);
assert_eq!(res.len(), 40);
for i in 0..res.len() {
assert_eq!(res[i], should[i]);
}
}
let msg2 = decode(&mut should, MAGIC, &crypto).unwrap();
assert_eq!(msg, msg2);
// Missing IPv4 count
assert!(decode(&mut [118, 112, 110, 1, 0, 0, 0, 1], MAGIC, &crypto).is_err());
// Truncated IPv4
assert!(decode(&mut [118, 112, 110, 1, 0, 0, 0, 1, 1], MAGIC, &crypto).is_err());
// Missing IPv6 count
assert!(decode(&mut [118, 112, 110, 1, 0, 0, 0, 1, 1, 1, 2, 3, 4, 0, 0], MAGIC, &crypto).is_err());
// Truncated IPv6
assert!(decode(&mut [118, 112, 110, 1, 0, 0, 0, 1, 1, 1, 2, 3, 4, 0, 0, 1], MAGIC, &crypto).is_err());
}
#[test]
fn udpmessage_init() {
use super::types::Address;
let mut crypto = Crypto::None;
let addrs = vec![
Range { base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 }, prefix_len: 24 },
Range { base: Address { data: [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 }, prefix_len: 16 },
];
let node_id = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
let mut msg = Message::Init(0, node_id, addrs, 1800);
let mut should = [
118, 112, 110, 1, 0, 0, 0, 2, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2, 4, 0, 1, 2, 3, 24, 6,
0, 1, 2, 3, 4, 5, 16, 7, 8
];
{
let mut buf = [0; 1024];
let res = encode(&mut msg, &mut buf[..], MAGIC, &mut crypto);
assert_eq!(res.len(), 42);
for i in 0..res.len() {
assert_eq!(res[i], should[i]);
}
}
let msg2 = decode(&mut should, MAGIC, &crypto).unwrap();
assert_eq!(msg, msg2);
}
#[test]
fn udpmessage_close() {
let mut crypto = Crypto::None;
let mut msg = Message::Close;
let mut should = [118, 112, 110, 1, 0, 0, 0, 3];
{
let mut buf = [0; 1024];
let res = encode(&mut msg, &mut buf[..], MAGIC, &mut crypto);
assert_eq!(res.len(), 8);
assert_eq!(&res, &should);
}
let msg2 = decode(&mut should, MAGIC, &crypto).unwrap();
assert_eq!(msg, msg2);
}
#[test]
fn udpmessage_invalid() {
let crypto = Crypto::None;
assert!(decode(&mut [0x76, 0x70, 0x6e, 1, 0, 0, 0, 0], MAGIC, &crypto).is_ok());
// too short
assert!(decode(&mut [], MAGIC, &crypto).is_err());
// invalid protocol
assert!(decode(&mut [0, 1, 2, 0, 0, 0, 0, 0], MAGIC, &crypto).is_err());
// invalid version
assert!(decode(&mut [0x76, 0x70, 0x6e, 0xaa, 0, 0, 0, 0], MAGIC, &crypto).is_err());
// invalid crypto
assert!(decode(&mut [0x76, 0x70, 0x6e, 1, 0xaa, 0, 0, 0], MAGIC, &crypto).is_err());
// invalid msg type
assert!(decode(&mut [0x76, 0x70, 0x6e, 1, 0, 0, 0, 0xaa], MAGIC, &crypto).is_err());
}
#[test]
fn udpmessage_invalid_crypto() {
let crypto = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test");
// truncated crypto
assert!(decode(&mut [0x76, 0x70, 0x6e, 1, 1, 0, 0, 0], MAGIC, &crypto).is_err());
}
#[test]
fn message_fmt() {
assert_eq!(format!("{:?}", Message::Data(&mut [1, 2, 3, 4, 5], 0, 5)), "Data(5 bytes)");
assert_eq!(
format!(
"{:?}",
Message::Peers(vec![
SocketAddr::from_str("1.2.3.4:123").unwrap(),
SocketAddr::from_str("5.6.7.8:12345").unwrap(),
SocketAddr::from_str("[0001:0203:0405:0607:0809:0a0b:0c0d:0e0f]:6789").unwrap()
])
),
"Peers [1.2.3.4:123, 5.6.7.8:12345, [1:203:405:607:809:a0b:c0d:e0f]:6789]"
);
assert_eq!(
format!(
"{:?}",
Message::Init(0, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], vec![
Range {
base: Address { data: [0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 4 },
prefix_len: 24
},
Range {
base: Address { data: [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], len: 6 },
prefix_len: 16
}
], 1800)
),
"Init(stage=0, node_id=000102030405060708090a0b0c0d0e0f, peer_timeout=1800, [0.1.2.3/24, 00:01:02:03:04:05/16])"
);
assert_eq!(format!("{:?}", Message::Close), "Close");
}

View File

@ -8,7 +8,7 @@ use std::{
sync::atomic::{AtomicIsize, Ordering} sync::atomic::{AtomicIsize, Ordering}
}; };
use super::types::Error; use crate::error::Error;
#[cfg(not(target_os = "linux"))] use time; #[cfg(not(target_os = "linux"))] use time;
@ -20,6 +20,84 @@ pub type Duration = u32;
pub type Time = i64; pub type Time = i64;
#[derive(Clone)]
pub struct MsgBuffer {
space_before: usize,
buffer: [u8; 65535],
start: usize,
end: usize
}
impl MsgBuffer {
pub fn new(space_before: usize) -> Self {
Self { buffer: [0; 65535], space_before, start: space_before, end: space_before }
}
pub fn get_start(&self) -> usize {
self.start
}
pub fn set_start(&mut self, start: usize) {
self.start = start
}
pub fn prepend_byte(&mut self, byte: u8) {
self.start -= 1;
self.buffer[self.start] = byte
}
pub fn take_prefix(&mut self) -> u8 {
let byte = self.buffer[self.start];
self.start += 1;
byte
}
pub fn buffer(&mut self) -> &mut [u8] {
&mut self.buffer[self.start..]
}
pub fn message(&self) -> &[u8] {
&self.buffer[self.start..self.end]
}
pub fn take(&mut self) -> Option<&[u8]> {
if self.start != self.end {
let end = self.end;
self.end = self.start;
Some(&self.buffer[self.start..end])
} else {
None
}
}
pub fn message_mut(&mut self) -> &mut [u8] {
&mut self.buffer[self.start..self.end]
}
pub fn set_length(&mut self, length: usize) {
self.end = self.start + length
}
pub fn clone_from(&mut self, other: &[u8]) {
self.set_length(other.len());
self.message_mut().clone_from_slice(other);
}
pub fn len(&self) -> usize {
self.end - self.start
}
pub fn is_empty(&self) -> bool {
self.start == self.end
}
pub fn clear(&mut self) {
self.set_start(self.space_before);
self.set_length(0)
}
}
const HEX_CHARS: &[u8] = b"0123456789abcdef"; const HEX_CHARS: &[u8] = b"0123456789abcdef";
pub fn bytes_to_hex(bytes: &[u8]) -> String { pub fn bytes_to_hex(bytes: &[u8]) -> String {
@ -139,7 +217,7 @@ pub fn get_internal_ip() -> Ipv4Addr {
#[allow(unknown_lints, clippy::needless_pass_by_value)] #[allow(unknown_lints, clippy::needless_pass_by_value)]
pub fn resolve<Addr: ToSocketAddrs + fmt::Debug>(addr: Addr) -> Result<Vec<SocketAddr>, Error> { pub fn resolve<Addr: ToSocketAddrs + fmt::Debug>(addr: Addr) -> Result<Vec<SocketAddr>, Error> {
let addrs = addr.to_socket_addrs().map_err(|_| Error::Name(format!("{:?}", addr)))?; let addrs = addr.to_socket_addrs().map_err(|_| Error::NameUnresolvable(format!("{:?}", addr)))?;
// Remove duplicates in addrs (why are there duplicates???) // Remove duplicates in addrs (why are there duplicates???)
let mut addrs = addrs.collect::<Vec<_>>(); let mut addrs = addrs.collect::<Vec<_>>();
// Try IPv4 first as it usually is faster // Try IPv4 first as it usually is faster

View File

@ -182,7 +182,7 @@ in 3 different modes:
*Switch mode*:: In this mode, the VPN will dynamically learn addresses *Switch mode*:: In this mode, the VPN will dynamically learn addresses
as they are used as source addresses and use them to forward data to its as they are used as source addresses and use them to forward data to its
destination. Addresses that have not been seen for some time destination. Addresses that have not been seen for some time
(option *dst_timeout*) will be forgotten. Data for unknown addresses will be (option *switch_timeout*) will be forgotten. Data for unknown addresses will be
broadcast to all peers. This mode is the default mode for TAP devices that broadcast to all peers. This mode is the default mode for TAP devices that
process Ethernet frames but it can also be used with TUN devices and IP process Ethernet frames but it can also be used with TUN devices and IP
packets. packets.
@ -295,7 +295,7 @@ detailed descriptions of the options.
*beacon_load*:: Path or command to load beacons. Same as *--beacon-load* *beacon_load*:: Path or command to load beacons. Same as *--beacon-load*
*beacon_interval*:: Interval for loading and storing beacons in seconds. Same as *--beacon-interval* *beacon_interval*:: Interval for loading and storing beacons in seconds. Same as *--beacon-interval*
*mode*:: The mode of the VPN. Same as *--mode* *mode*:: The mode of the VPN. Same as *--mode*
*dst_timeout*:: Switch table entry timeout in seconds. Same as *--dst-timeout* *switch_timeout*:: Switch table entry timeout in seconds. Same as *--dst-timeout*
*subnets*:: A list of local subnets to use. See *--subnet* *subnets*:: A list of local subnets to use. See *--subnet*
*port_forwarding*:: Whether to activate port forwardig. See *--no-port-forwarding* *port_forwarding*:: Whether to activate port forwardig. See *--no-port-forwarding*
*user*:: The name of a user to run the background process under. Same as *--user* *user*:: The name of a user to run the background process under. Same as *--user*