diff --git a/src/tests.rs b/src/tests.rs index 1364428..8da499c 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -179,11 +179,11 @@ fn routing_table() { } #[test] -fn address_fmt() { - assert_eq!(format!("{}", Address{data: [120,45,22,5,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}), "120.45.22.5"); - assert_eq!(format!("{}", Address{data: [120,45,22,5,1,2,0,0,0,0,0,0,0,0,0,0], len: 6}), "78:2d:16:05:01:02"); +fn address_parse_fmt() { + assert_eq!(format!("{}", Address::from_str("120.45.22.5").unwrap()), "120.45.22.5"); + assert_eq!(format!("{}", Address::from_str("78:2d:16:05:01:02").unwrap()), "78:2d:16:05:01:02"); assert_eq!(format!("{}", Address{data: [3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], len: 8}), "vlan824/78:2d:16:05:01:02"); - assert_eq!(format!("{}", Address{data: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], len: 16}), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); + assert_eq!(format!("{}", Address::from_str("0001:0203:0405:0607:0809:0a0b:0c0d:0e0f").unwrap()), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); } #[test] diff --git a/src/types.rs b/src/types.rs index 93b6695..9adc03e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,9 +1,9 @@ use std::net::{SocketAddr, Ipv4Addr, Ipv6Addr}; use std::hash::Hasher; -use std::{fmt, ptr}; +use std::fmt; use std::str::FromStr; -use super::util::{as_bytes, as_obj}; +use super::util::Encoder; pub type NetworkId = u64; @@ -33,7 +33,9 @@ impl Address { return Err(Error::ParseError("Address too short")); } let mut bytes = [0; 16]; - unsafe { ptr::copy_nonoverlapping(data.as_ptr(), bytes.as_mut_ptr(), len) }; + for i in 0..len { + bytes[i] = data[i]; + } Ok(Address{data: bytes, len: len as u8}) } @@ -41,7 +43,9 @@ impl Address { pub fn write_to(&self, data: &mut[u8]) -> usize { assert!(data.len() >= self.len as usize + 1); data[0] = self.len; - unsafe { ptr::copy_nonoverlapping(self.data.as_ptr(), data[1..].as_mut_ptr(), self.len as usize) }; + for i in 0..self.len as usize { + data[i+1] = self.data[i]; + } self.len as usize + 1 } } @@ -70,7 +74,7 @@ impl fmt::Display for Address { 6 => write!(formatter, "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", d[0], d[1], d[2], d[3], d[4], d[5]), 8 => { - let vlan = u16::from_be( *unsafe { as_obj(&d[0..2]) }); + let vlan = Encoder::read_u16(&d[0..2]); write!(formatter, "vlan{}/{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", vlan, d[2], d[3], d[4], d[5], d[6], d[7]) }, @@ -94,20 +98,16 @@ impl FromStr for Address { if let Ok(addr) = Ipv4Addr::from_str(text) { let ip = addr.octets(); let mut res = [0; 16]; - unsafe { - ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len()); + for i in 0..4 { + res[i] = ip[i]; } return Ok(Address{data: res, len: 4}); } if let Ok(addr) = Ipv6Addr::from_str(text) { - let mut segments = addr.segments(); - for i in 0..8 { - segments[i] = segments[i].to_be(); - } - let bytes = unsafe { as_bytes(&segments) }; + let segments = addr.segments(); let mut res = [0; 16]; - unsafe { - ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); + for i in 0..8 { + Encoder::write_u16(segments[i], &mut res[2*i..]); } return Ok(Address{data: res, len: 16}); } diff --git a/src/udpmessage.rs b/src/udpmessage.rs index 06568b6..fe7e097 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -1,4 +1,4 @@ -use std::{mem, fmt}; +use std::fmt; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr, SocketAddrV6, Ipv6Addr}; use super::types::{Error, NetworkId, Range}; @@ -22,8 +22,13 @@ struct TopHeader { } impl TopHeader { + #[inline(always)] + pub fn size() -> usize { + 8 + } + pub fn read_from(data: &[u8]) -> Result<(TopHeader, usize), Error> { - if data.len() < 8 { + if data.len() < TopHeader::size() { return Err(Error::ParseError("Empty message")); } let mut header = TopHeader::default(); @@ -34,7 +39,7 @@ impl TopHeader { header.crypto_method = data[4]; header.flags = data[6]; header.msgtype = data[7]; - Ok((header, 8)) + Ok((header, TopHeader::size())) } pub fn write_to(&self, data: &mut [u8]) -> usize { @@ -45,7 +50,7 @@ impl TopHeader { data[4] = self.crypto_method; data[6] = self.flags; data[7] = self.msgtype; - 8 + TopHeader::size() } } @@ -112,7 +117,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M let (before, after) = data.split_at_mut(pos); let (nonce, crypto_data) = after.split_at_mut(len); pos += len; - end = try!(crypto.decrypt(crypto_data, nonce, &before[..mem::size_of::()])) + pos; + end = try!(crypto.decrypt(crypto_data, nonce, &before[..TopHeader::size()])) + pos; } assert_eq!(end, data.len()-crypto.additional_bytes()); } @@ -262,7 +267,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry } } if crypto.method() > 0 { - let (header, rest) = buf.split_at_mut(mem::size_of::()); + let (header, rest) = buf.split_at_mut(TopHeader::size()); let (nonce, rest) = rest.split_at_mut(crypto.nonce_bytes()); let crypto_start = header.len() + nonce.len(); assert!(rest.len() >= pos - crypto_start + crypto.additional_bytes()); diff --git a/src/util.rs b/src/util.rs index a51b894..1069f38 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,4 @@ -use std::{mem, slice, ptr}; +use std::ptr; use libc; pub type Duration = u32; @@ -18,18 +18,6 @@ pub fn time_rand() -> i64 { tv.tv_sec ^ tv.tv_nsec } - -#[inline(always)] -pub unsafe fn as_bytes(obj: &T) -> &[u8] { - slice::from_raw_parts(mem::transmute::<&T, *const u8>(obj), mem::size_of::()) -} - -#[inline(always)] -pub unsafe fn as_obj(data: &[u8]) -> &T { - assert!(data.len() >= mem::size_of::()); - mem::transmute(data.as_ptr()) -} - #[inline(always)] pub fn memcopy(src: &[u8], dst: &mut[u8]) { assert!(dst.len() >= src.len());