From 0bcdfbe21ece95a0253c031826b181630b53968c Mon Sep 17 00:00:00 2001 From: Dennis Schwerdel Date: Wed, 6 Jul 2016 21:14:09 +0200 Subject: [PATCH] More tests, bugfixes and code cleanup --- CHANGELOG.md | 2 ++ src/cloud.rs | 10 +++++----- src/crypto.rs | 4 ++-- src/device.rs | 6 +++--- src/ethernet.rs | 4 ++-- src/ip.rs | 8 ++++---- src/tests.rs | 39 ++++++++++++++++++++++++++++++++++----- src/types.rs | 44 +++++++++++++++++--------------------------- src/udpmessage.rs | 25 ++++++++++++++----------- src/util.rs | 2 +- 10 files changed, 84 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5640d78..579cde4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ This project follows [semantic versioning](http://semver.org). ### UNRELEASED +- [added] Added more tests - [added] Added pluggable polling system - [added] Added documentation - [changed] Code cleanup @@ -14,6 +15,7 @@ This project follows [semantic versioning](http://semver.org). - [fixed] Reconnecting to lost peers when receiving from them or sending to them - [fixed] Sending peer list more often to prevent timeouts - [fixed] Removing learnt addresses of lost peers +- [fixed] Fixed possible crash in message decoding ### v0.6.0 (2016-06-02) diff --git a/src/cloud.rs b/src/cloud.rs index 370e2f0..bfbb911 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -237,10 +237,10 @@ impl GenericCloud

{ }; try!(match socket.send_to(msg_data, addr) { Ok(written) if written == msg_data.len() => Ok(()), - Ok(_) => Err(Error::SocketError("Sent out truncated packet")), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), Err(e) => { error!("Failed to send via network {:?}", e); - Err(Error::SocketError("IOError when sending")) + Err(Error::Socket("IOError when sending")) } }) } @@ -263,10 +263,10 @@ impl GenericCloud

{ }; match socket.send_to(msg_data, addr) { Ok(written) if written == msg_data.len() => Ok(()), - Ok(_) => Err(Error::SocketError("Sent out truncated packet")), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), Err(e) => { error!("Failed to send via network {:?}", e); - Err(Error::SocketError("IOError when sending")) + Err(Error::Socket("IOError when sending")) } } } @@ -504,7 +504,7 @@ impl GenericCloud

{ Ok(()) => (), Err(e) => { error!("Failed to send via device: {}", e); - return Err(Error::TunTapDevError("Failed to write to device")); + return Err(Error::TunTapDev("Failed to write to device")); } } if self.learning { diff --git a/src/crypto.rs b/src/crypto.rs index 0ed1566..571c234 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -259,7 +259,7 @@ impl Crypto { ) }; match res { 0 => Ok(mlen as usize), - _ => Err(Error::CryptoError("Failed to decrypt")) + _ => Err(Error::Crypto("Failed to decrypt")) } }, Crypto::AES256GCM{ref state, ..} => { @@ -277,7 +277,7 @@ impl Crypto { ) }; match res { 0 => Ok(mlen as usize), - _ => Err(Error::CryptoError("Failed to decrypt")) + _ => Err(Error::Crypto("Failed to decrypt")) } } } diff --git a/src/device.rs b/src/device.rs index 0369889..758c947 100644 --- a/src/device.rs +++ b/src/device.rs @@ -133,7 +133,7 @@ impl Device { /// This method will return an error if the underlying read call fails. #[inline] pub fn read(&mut self, mut buffer: &mut [u8]) -> Result<(usize, usize), Error> { - let read = try!(self.fd.read(&mut buffer).map_err(|_| Error::TunTapDevError("Read error"))); + let read = try!(self.fd.read(&mut buffer).map_err(|_| Error::TunTapDev("Read error"))); let (start, read) = self.correct_data_after_read(&mut buffer, 0, read); Ok((start, read)) } @@ -171,8 +171,8 @@ impl Device { pub fn write(&mut self, mut data: &mut [u8], start: usize) -> Result<(), Error> { let start = self.correct_data_before_write(&mut data, start); match self.fd.write_all(&data[start..]) { - Ok(_) => self.fd.flush().map_err(|_| Error::TunTapDevError("Flush error")), - Err(_) => Err(Error::TunTapDevError("Write error")) + Ok(_) => self.fd.flush().map_err(|_| Error::TunTapDev("Flush error")), + Err(_) => Err(Error::TunTapDev("Write error")) } } diff --git a/src/ethernet.rs b/src/ethernet.rs index 9d6a992..0271dee 100644 --- a/src/ethernet.rs +++ b/src/ethernet.rs @@ -26,7 +26,7 @@ impl Protocol for Frame { /// This method will fail when the given data is not a valid ethernet frame. fn parse(data: &[u8]) -> Result<(Address, Address), Error> { if data.len() < 14 { - return Err(Error::ParseError("Frame is too short")); + return Err(Error::Parse("Frame is too short")); } let mut pos = 0; let dst_data = &data[pos..pos+6]; @@ -36,7 +36,7 @@ impl Protocol for Frame { if data[pos] == 0x81 && data[pos+1] == 0x00 { pos += 2; if data.len() < pos + 2 { - return Err(Error::ParseError("Vlan frame is too short")); + return Err(Error::Parse("Vlan frame is too short")); } let mut src = [0; 16]; let mut dst = [0; 16]; diff --git a/src/ip.rs b/src/ip.rs index e734001..eea7c22 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -25,13 +25,13 @@ impl Protocol for Packet { /// This method will fail when the given data is not a valid ipv4 and ipv6 packet. fn parse(data: &[u8]) -> Result<(Address, Address), Error> { if data.len() < 1 { - return Err(Error::ParseError("Empty header")); + return Err(Error::Parse("Empty header")); } let version = data[0] >> 4; match version { 4 => { if data.len() < 20 { - return Err(Error::ParseError("Truncated IPv4 header")); + return Err(Error::Parse("Truncated IPv4 header")); } let src = try!(Address::read_from_fixed(&data[12..], 4)); let dst = try!(Address::read_from_fixed(&data[16..], 4)); @@ -39,13 +39,13 @@ impl Protocol for Packet { }, 6 => { if data.len() < 40 { - return Err(Error::ParseError("Truncated IPv6 header")); + return Err(Error::Parse("Truncated IPv6 header")); } let src = try!(Address::read_from_fixed(&data[8..], 16)); let dst = try!(Address::read_from_fixed(&data[24..], 16)); Ok((src, dst)) }, - _ => Err(Error::ParseError("Invalid version")) + _ => Err(Error::Parse("Invalid version")) } } } diff --git a/src/tests.rs b/src/tests.rs index b043eba..6f5a726 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -7,7 +7,7 @@ use std::str::FromStr; use super::ethernet::{Frame, SwitchTable}; use super::ip::{RoutingTable, Packet}; -use super::types::{Error, Protocol, Address, Range, Table}; +use super::types::{Protocol, Address, Range, Table}; use super::udpmessage::{Options, Message, decode, encode}; use super::crypto::{Crypto, CryptoMethod}; @@ -109,6 +109,14 @@ fn udpmessage_peers() { let (options2, msg2) = decode(&mut should, &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); + // Missing IPv4 count + assert!(decode(&mut[118,112,110,1,0,0,0,1], &mut crypto).is_err()); + // Truncated IPv4 + assert!(decode(&mut[118,112,110,1,0,0,0,1,1], &mut crypto).is_err()); + // Missing IPv6 count + assert!(decode(&mut[118,112,110,1,0,0,0,1,1,1,2,3,4,0,0], &mut crypto).is_err()); + // Truncated IPv6 + assert!(decode(&mut[118,112,110,1,0,0,0,1,1,1,2,3,4,0,0,1], &mut crypto).is_err()); } #[test] @@ -292,6 +300,8 @@ fn routing_table() { table.learn(Address::from_str("192.168.2.0").unwrap(), Some(27), peer3.clone()); assert_eq!(table.lookup(&Address::from_str("192.168.2.31").unwrap()), Some(peer3)); assert_eq!(table.lookup(&Address::from_str("192.168.2.32").unwrap()), Some(peer1)); + table.learn(Address::from_str("192.168.2.0").unwrap(), Some(28), peer3.clone()); + assert_eq!(table.lookup(&Address::from_str("192.168.2.1").unwrap()), Some(peer3)); } #[test] @@ -301,7 +311,7 @@ fn address_parse_fmt() { assert_eq!(format!("{}", Address{data: [3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], len: 8}), "vlan824/78:2d:16:05:01:02"); assert_eq!(format!("{}", Address::from_str("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f").unwrap()), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); assert_eq!(format!("{:?}", Address{data: [1,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0], len: 2}), "0102"); - assert_eq!(Address::from_str(""), Err(Error::ParseError("Failed to parse address"))); + assert!(Address::from_str("").is_err()); // Failed to parse address } #[test] @@ -317,11 +327,30 @@ fn address_decode_encode() { assert_eq!(&buf[0..7], &[6, 0x78, 0x2d, 0x16, 0x05, 0x01, 0x02]); assert_eq!((addr, 7), Address::read_from(&buf).unwrap()); assert_eq!(addr, Address::read_from_fixed(&buf[1..], 6).unwrap()); - assert_eq!(Address::read_from(&buf[0..0]), Err(Error::ParseError("Address too short"))); + assert!(Address::read_from(&buf[0..0]).is_err()); // Address too short buf[0] = 100; - assert_eq!(Address::read_from(&buf), Err(Error::ParseError("Invalid address, too long"))); + assert!(Address::read_from(&buf).is_err()); // Invalid address, too long buf[0] = 5; - assert_eq!(Address::read_from(&buf[0..4]), Err(Error::ParseError("Address too short"))); + assert!(Address::read_from(&buf[0..4]).is_err()); // Address too short +} + +#[test] +fn address_eq() { + assert!(Address::read_from_fixed(&[1,2,3,4], 4) == Address::read_from_fixed(&[1,2,3,4], 4)); + assert!(Address::read_from_fixed(&[1,2,3,4], 4) != Address::read_from_fixed(&[1,2,3,5], 4)); + assert!(Address::read_from_fixed(&[1,2,3,4], 3) == Address::read_from_fixed(&[1,2,3,5], 3)); + assert!(Address::read_from_fixed(&[1,2,3,4], 3) != Address::read_from_fixed(&[1,2,3,4], 4)); +} + +#[test] +fn address_range_decode_encode() { + let mut buf = [0; 32]; + let range = Range{base: Address{data: [0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}, prefix_len: 24}; + assert_eq!(range.write_to(&mut buf), 6); + assert_eq!(&buf[0..6], &[4, 0, 1, 2, 3, 24]); + assert_eq!((range, 6), Range::read_from(&buf).unwrap()); + buf[0] = 17; + assert!(Range::read_from(&buf).is_err()); } #[test] diff --git a/src/types.rs b/src/types.rs index 76a67fa..9fa4475 100644 --- a/src/types.rs +++ b/src/types.rs @@ -25,7 +25,7 @@ impl Address { #[inline] pub fn read_from(data: &[u8]) -> Result<(Address, usize), Error> { if data.len() < 1 { - return Err(Error::ParseError("Address too short")); + return Err(Error::Parse("Address too short")); } let len = data[0] as usize; let addr = try!(Address::read_from_fixed(&data[1..], len)); @@ -35,10 +35,10 @@ impl Address { #[inline] pub fn read_from_fixed(data: &[u8], len: usize) -> Result { if len > 16 { - return Err(Error::ParseError("Invalid address, too long")); + return Err(Error::Parse("Invalid address, too long")); } if data.len() < len { - return Err(Error::ParseError("Address too short")); + return Err(Error::Parse("Address too short")); } let mut bytes = [0; 16]; bytes[0..len].copy_from_slice(&data[0..len]); @@ -59,15 +59,7 @@ impl Address { impl PartialEq for Address { #[inline] fn eq(&self, rhs: &Self) -> bool { - if self.len != rhs.len { - return false; - } - for i in 0..self.len as usize { - if self.data[i] != rhs.data[i] { - return false; - } - } - true + self.len == rhs.len && self.data[..self.len as usize] == rhs.data[..self.len as usize] } } @@ -129,11 +121,11 @@ impl FromStr for Address { if parts.len() == 6 { let mut bytes = [0; 16]; for i in 0..6 { - bytes[i] = try!(u8::from_str_radix(&parts[i], 16).map_err(|_| Error::ParseError("Failed to parse mac"))); + bytes[i] = try!(u8::from_str_radix(&parts[i], 16).map_err(|_| Error::Parse("Failed to parse mac"))); } return Ok(Address{data: bytes, len: 6}); } - Err(Error::ParseError("Failed to parse address")) + Err(Error::Parse("Failed to parse address")) } } @@ -149,7 +141,7 @@ impl Range { pub fn read_from(data: &[u8]) -> Result<(Range, usize), Error> { let (address, read) = try!(Address::read_from(data)); if data.len() < read + 1 { - return Err(Error::ParseError("Range too short")); + return Err(Error::Parse("Range too short")); } let prefix_len = data[read]; Ok((Range{base: address, prefix_len: prefix_len}, read + 1)) @@ -170,10 +162,10 @@ impl FromStr for Range { fn from_str(text: &str) -> Result { let pos = match text.find('/') { Some(pos) => pos, - None => return Err(Error::ParseError("Invalid range format")) + None => return Err(Error::Parse("Invalid range format")) }; let prefix_len = try!(u8::from_str(&text[pos+1..]) - .map_err(|_| Error::ParseError("Failed to parse prefix length"))); + .map_err(|_| Error::Parse("Failed to parse prefix length"))); let base = try!(Address::from_str(&text[..pos])); Ok(Range{base: base, prefix_len: prefix_len}) } @@ -221,21 +213,19 @@ pub trait Protocol: Sized { #[derive(Debug, PartialEq)] pub enum Error { - ParseError(&'static str), + Parse(&'static str), WrongNetwork(Option), - SocketError(&'static str), - NameError(String), - TunTapDevError(&'static str), - CryptoError(&'static str) + Socket(&'static str), + Name(String), + TunTapDev(&'static str), + Crypto(&'static str) } impl fmt::Display for Error { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { - Error::ParseError(ref msg) => write!(formatter, "{}", msg), - Error::SocketError(ref msg) => write!(formatter, "{}", msg), - Error::TunTapDevError(ref msg) => write!(formatter, "{}", msg), - Error::CryptoError(ref msg) => write!(formatter, "{}", msg), - Error::NameError(ref name) => write!(formatter, "failed to resolve name '{}'", name), + Error::Parse(ref msg) | Error::Socket(ref msg) | + Error::TunTapDev(ref msg) | Error::Crypto(ref msg) => write!(formatter, "{}", msg), + Error::Name(ref name) => write!(formatter, "failed to resolve name '{}'", name), Error::WrongNetwork(Some(net)) => write!(formatter, "wrong network id: {}", net), Error::WrongNetwork(None) => write!(formatter, "wrong network id: none"), } diff --git a/src/udpmessage.rs b/src/udpmessage.rs index 9e3e1f5..a1208bd 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -33,7 +33,7 @@ impl TopHeader { pub fn read_from(data: &[u8]) -> Result<(TopHeader, usize), Error> { if data.len() < TopHeader::size() { - return Err(Error::ParseError("Empty message")); + return Err(Error::Parse("Empty message")); } let mut header = TopHeader::default(); header.magic.copy_from_slice(&data[0..3]); @@ -100,18 +100,18 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M let mut end = data.len(); let (header, mut pos) = try!(TopHeader::read_from(&data[..end])); if header.magic != MAGIC { - return Err(Error::ParseError("Wrong protocol")); + return Err(Error::Parse("Wrong protocol")); } if header.version != VERSION { - return Err(Error::ParseError("Wrong version")); + return Err(Error::Parse("Wrong version")); } if header.crypto_method != crypto.method() { - return Err(Error::CryptoError("Wrong crypto method")); + return Err(Error::Crypto("Wrong crypto method")); } if crypto.method() > 0 { let len = crypto.nonce_bytes(); if end < pos + len { - return Err(Error::ParseError("Truncated crypto header")); + return Err(Error::Parse("Truncated crypto header")); } { let (before, after) = data.split_at_mut(pos); @@ -124,7 +124,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M let mut options = Options::default(); if header.flags & 0x01 > 0 { if end < pos + NETWORK_ID_BYTES { - return Err(Error::ParseError("Truncated options")); + return Err(Error::Parse("Truncated options")); } options.network_id = Some(Encoder::read_u64(&data[pos..pos+NETWORK_ID_BYTES])); pos += NETWORK_ID_BYTES; @@ -133,14 +133,14 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M 0 => Message::Data(data, pos, end), 1 => { if end < pos + 1 { - return Err(Error::ParseError("Empty peers")); + return Err(Error::Parse("Missing IPv4 count")); } let mut peers = Vec::new(); let count = data[pos]; pos += 1; let len = count as usize * 6; if end < pos + len { - return Err(Error::ParseError("Peer data too short")); + return Err(Error::Parse("IPv4 peer data too short")); } for _ in 0..count { let ip = &data[pos..]; @@ -150,11 +150,14 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), port)); peers.push(addr); } + if end < pos + 1 { + return Err(Error::Parse("Missing IPv6 count")); + } let count = data[pos]; pos += 1; let len = count as usize * 18; if end < pos + len { - return Err(Error::ParseError("Peer data too short")); + return Err(Error::Parse("IPv6 peer data too short")); } for _ in 0..count { let mut ip = [0u16; 8]; @@ -172,7 +175,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M }, 2 => { if end < pos + 2 + NODE_ID_BYTES { - return Err(Error::ParseError("Init data too short")); + return Err(Error::Parse("Init data too short")); } let stage = data[pos]; pos += 1; @@ -190,7 +193,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M Message::Init(stage, node_id, addrs) }, 3 => Message::Close, - _ => return Err(Error::ParseError("Unknown message type")) + _ => return Err(Error::Parse("Unknown message type")) }; Ok((options, msg)) } diff --git a/src/util.rs b/src/util.rs index 24e3667..2cb2371 100644 --- a/src/util.rs +++ b/src/util.rs @@ -110,7 +110,7 @@ macro_rules! try_fail { pub fn resolve(addr: Addr) -> Result, Error> { - let addrs = try!(addr.to_socket_addrs().map_err(|_| Error::NameError(format!("{}", addr)))); + let addrs = try!(addr.to_socket_addrs().map_err(|_| Error::Name(format!("{}", addr)))); // Remove duplicates in addrs (why are there duplicates???) let mut addrs = addrs.collect::>(); addrs.dedup();