diff --git a/src/benches.rs b/src/benches.rs index 3413e15..56677b6 100644 --- a/src/benches.rs +++ b/src/benches.rs @@ -1,7 +1,5 @@ use test::Bencher; -use time::Duration; - use std::str::FromStr; use std::net::ToSocketAddrs; @@ -39,7 +37,7 @@ fn message_decode(b: &mut Bencher) { #[bench] fn switch_learn(b: &mut Bencher) { - let mut table = SwitchTable::new(Duration::seconds(10)); + let mut table = SwitchTable::new(10); let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); b.iter(|| { @@ -49,7 +47,7 @@ fn switch_learn(b: &mut Bencher) { #[bench] fn switch_lookup(b: &mut Bencher) { - let mut table = SwitchTable::new(Duration::seconds(10)); + let mut table = SwitchTable::new(10); let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); table.learn(addr.clone(), None, peer); diff --git a/src/ethernet.rs b/src/ethernet.rs index fca5cb6..5846304 100644 --- a/src/ethernet.rs +++ b/src/ethernet.rs @@ -1,4 +1,3 @@ -use std::ptr; use std::net::SocketAddr; use std::collections::HashMap; @@ -25,21 +24,17 @@ impl Protocol for Frame { } let mut src = [0; 16]; let mut dst = [0; 16]; - unsafe { - ptr::copy_nonoverlapping(data[pos..].as_ptr(), src.as_mut_ptr(), 2); - ptr::copy_nonoverlapping(src_data.as_ptr(), src[2..].as_mut_ptr(), 6); - ptr::copy_nonoverlapping(data[pos..].as_ptr(), dst.as_mut_ptr(), 2); - ptr::copy_nonoverlapping(dst_data.as_ptr(), dst[2..].as_mut_ptr(), 6); + src[0] = data[pos]; src[1] = data[pos+1]; + dst[0] = data[pos]; dst[1] = data[pos+1]; + for i in 0..6 { + src[i+2] = src_data[i]; + dst[i+2] = dst_data[i]; } - Ok((Address(src, 8), Address(dst, 8))) + Ok((Address{data: src, len: 8}, Address{data: dst, len: 8})) } else { - let mut src = [0; 16]; - let mut dst = [0; 16]; - unsafe { - ptr::copy_nonoverlapping(src_data.as_ptr(), src.as_mut_ptr(), 6); - ptr::copy_nonoverlapping(dst_data.as_ptr(), dst.as_mut_ptr(), 6); - } - Ok((Address(src, 6), Address(dst, 6))) + let src = try!(Address::read_from_fixed(&src_data, 6)); + let dst = try!(Address::read_from_fixed(&dst_data, 6)); + Ok((src, dst)) } } } @@ -104,14 +99,14 @@ impl Table for SwitchTable { fn without_vlan() { let data = [6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8]; let (src, dst) = Frame::parse(&data).unwrap(); - assert_eq!(src, Address([1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0], 6)); - assert_eq!(dst, Address([6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0], 6)); + assert_eq!(src, Address{data: [1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0], len: 6}); + assert_eq!(dst, Address{data: [6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0], len: 6}); } #[test] fn with_vlan() { let data = [6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8]; let (src, dst) = Frame::parse(&data).unwrap(); - assert_eq!(src, Address([4,210,1,2,3,4,5,6,0,0,0,0,0,0,0,0], 8)); - assert_eq!(dst, Address([4,210,6,5,4,3,2,1,0,0,0,0,0,0,0,0], 8)); + assert_eq!(src, Address{data: [4,210,1,2,3,4,5,6,0,0,0,0,0,0,0,0], len: 8}); + assert_eq!(dst, Address{data: [4,210,6,5,4,3,2,1,0,0,0,0,0,0,0,0], len: 8}); } diff --git a/src/ip.rs b/src/ip.rs index 097823a..a849781 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -1,7 +1,6 @@ use std::net::SocketAddr; use std::collections::{hash_map, HashMap}; use std::io::Read; -use std::ptr; use super::types::{Protocol, Error, Table, Address}; @@ -15,28 +14,16 @@ impl Protocol for Packet { return Err(Error::ParseError("Empty header")); } let version = data[0] >> 4; - let mut src = [0; 16]; - let mut dst = [0; 16]; match version { 4 => { - if data.len() < 20 { - return Err(Error::ParseError("Truncated header")); - } - unsafe { - ptr::copy_nonoverlapping(data[12..].as_ptr(), src.as_mut_ptr(), 4); - ptr::copy_nonoverlapping(data[16..].as_ptr(), dst.as_mut_ptr(), 4); - } - Ok((Address(src, 4), Address(dst, 4))) + let src = try!(Address::read_from_fixed(&data[12..], 4)); + let dst = try!(Address::read_from_fixed(&data[16..], 4)); + Ok((src, dst)) }, 6 => { - if data.len() < 40 { - return Err(Error::ParseError("Truncated header")); - } - unsafe { - ptr::copy_nonoverlapping(data[8..].as_ptr(), src.as_mut_ptr(), 16); - ptr::copy_nonoverlapping(data[24..].as_ptr(), dst.as_mut_ptr(), 16); - } - Ok((Address(src, 16), Address(dst, 16))) + let src = try!(Address::read_from_fixed(&data[8..], 16)); + let dst = try!(Address::read_from_fixed(&data[24..], 16)); + Ok((src, dst)) }, _ => Err(Error::ParseError("Invalid version")) } @@ -62,15 +49,16 @@ impl Table for RoutingTable { fn learn(&mut self, addr: Address, prefix_len: Option, address: SocketAddr) { let prefix_len = match prefix_len { Some(val) => val, - None => addr.0.len() as u8 * 8 + None => addr.len * 8 }; info!("New routing entry: {:?}/{} => {}", addr, prefix_len, address); let group_len = (prefix_len as usize / 16) * 2; assert!(group_len <= 16); let mut group_bytes = [0; 16]; - unsafe { ptr::copy_nonoverlapping(addr.0.as_ptr(), group_bytes.as_mut_ptr(), group_len) }; - - let routing_entry = RoutingEntry{address: address, bytes: addr.0, prefix_len: prefix_len}; + for i in 0..group_len { + group_bytes[i] = addr.data[i]; + } + let routing_entry = RoutingEntry{address: address, bytes: addr.data, prefix_len: prefix_len}; match self.0.entry(group_bytes) { hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(routing_entry), hash_map::Entry::Vacant(entry) => { entry.insert(vec![routing_entry]); () } @@ -78,16 +66,13 @@ impl Table for RoutingTable { } fn lookup(&mut self, addr: &Address) -> Option { - let len = addr.0.len()/2 * 2; + let len = addr.len as usize/2 * 2; for i in 0..(len/2)+1 { - if let Some(group) = self.0.get(&addr.0[0..len-2*i]) { + if let Some(group) = self.0.get(&addr.data[0..len-2*i]) { for entry in group { - if entry.bytes.len() != addr.0.len() { - continue; - } let mut match_len = 0; - for i in 0..addr.0.len() { - let b = addr.0[i] ^ entry.bytes[i]; + for j in 0..addr.len as usize { + let b = addr.data[j] ^ entry.bytes[j]; if b == 0 { match_len += 8; } else { diff --git a/src/types.rs b/src/types.rs index 58cc3ed..388b8af 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,15 +8,53 @@ use super::util::{as_bytes, as_obj}; pub type NetworkId = u64; #[derive(PartialOrd, Eq, Ord, Clone, Hash)] -pub struct Address(pub [u8; 16], pub u8); +pub struct Address { + pub data: [u8; 16], + pub len: u8 +} + +impl Address { + #[inline] + pub fn read_from(data: &[u8]) -> Result<(Address, usize), Error> { + if data.len() < 1 { + return Err(Error::ParseError("Address too short")); + } + let len = data[0] as usize; + let addr = try!(Address::read_from_fixed(&data[1..], len)); + Ok((addr, len + 1)) + } + + #[inline] + pub fn read_from_fixed(data: &[u8], len: usize) -> Result { + if len > 16 { + return Err(Error::ParseError("Invalid address, too long")); + } + if data.len() < len { + return Err(Error::ParseError("Address too short")); + } + let mut bytes = [0; 16]; + unsafe { ptr::copy_nonoverlapping(data.as_ptr(), bytes.as_mut_ptr(), len) }; + Ok(Address{data: bytes, len: len as u8}) + } + + #[inline] + pub fn write_to(&self, data: &mut[u8]) -> usize { + assert!(data.len() >= self.len as usize + 1); + data[0] = self.len; + unsafe { ptr::copy_nonoverlapping(self.data.as_ptr(), data[1..].as_mut_ptr(), self.len as usize) }; + self.len as usize + 1 + } +} + impl PartialEq for Address { + #[inline] fn eq(&self, rhs: &Self) -> bool { - if self.1 != rhs.1 { + if self.len != rhs.len { return false; } - for i in 0..self.1 as usize { - if self.0[i] != rhs.0[i] { + for i in 0..self.len as usize { + if self.data[i] != rhs.data[i] { return false; } } @@ -27,20 +65,19 @@ impl PartialEq for Address { impl fmt::Debug for Address { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match self.1 { - 4 => write!(formatter, "{}.{}.{}.{}", self.0[0], self.0[1], self.0[2], self.0[3]), + let d = &self.data; + match self.len { + 4 => write!(formatter, "{}.{}.{}.{}", d[0], d[1], d[2], d[3]), 6 => write!(formatter, "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", - self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5]), + d[0], d[1], d[2], d[3], d[4], d[5]), 8 => { - let vlan = u16::from_be( *unsafe { as_obj(&self.0[0..2]) }); + let vlan = u16::from_be( *unsafe { as_obj(&d[0..2]) }); write!(formatter, "vlan{}/{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", - vlan, self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7]) + vlan, d[2], d[3], d[4], d[5], d[6], d[7]) }, 16 => write!(formatter, "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}", - self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7], - self.0[8], self.0[9], self.0[10], self.0[11], self.0[12], self.0[13], self.0[14], self.0[15] - ), - _ => self.0.fmt(formatter) + d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15]), + _ => d.fmt(formatter) } } } @@ -55,7 +92,7 @@ impl FromStr for Address { unsafe { ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len()); } - return Ok(Address(res, 4)); + return Ok(Address{data: res, len: 4}); } if let Ok(addr) = Ipv6Addr::from_str(text) { let mut segments = addr.segments(); @@ -67,7 +104,7 @@ impl FromStr for Address { unsafe { ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); } - return Ok(Address(res, 16)); + return Ok(Address{data: res, len: 16}); } let parts: Vec<&str> = text.split(':').collect(); if parts.len() == 6 { @@ -75,7 +112,7 @@ impl FromStr for Address { for i in 0..6 { bytes[i] = try!(u8::from_str_radix(&parts[i], 16).map_err(|_| Error::ParseError("Failed to parse mac"))); } - return Ok(Address(bytes, 6)); + return Ok(Address{data: bytes, len: 6}); } return Err(Error::ParseError("Failed to parse address")) } @@ -88,6 +125,26 @@ pub struct Range { pub prefix_len: u8 } +impl Range { + #[inline] + pub fn read_from(data: &[u8]) -> Result<(Range, usize), Error> { + let (address, read) = try!(Address::read_from(data)); + if data.len() < read + 1 { + return Err(Error::ParseError("Range too short")); + } + let prefix_len = data[read]; + Ok((Range{base: address, prefix_len: prefix_len}, read + 1)) + } + + #[inline] + pub fn write_to(&self, data: &mut[u8]) -> usize { + let pos = self.base.write_to(data); + assert!(data.len() >= pos + 1); + data[pos] = self.prefix_len; + pos + 1 + } +} + impl FromStr for Range { type Err=Error; @@ -165,10 +222,11 @@ impl fmt::Display for Error { } } + #[test] fn address_fmt() { - assert_eq!(format!("{:?}", Address([120,45,22,5,0,0,0,0,0,0,0,0,0,0,0,0], 4)), "120.45.22.5"); - assert_eq!(format!("{:?}", Address([120,45,22,5,1,2,0,0,0,0,0,0,0,0,0,0], 6)), "78:2d:16:05:01:02"); - assert_eq!(format!("{:?}", Address([3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], 8)), "vlan824/78:2d:16:05:01:02"); - assert_eq!(format!("{:?}", Address([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], 16)), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); + assert_eq!(format!("{:?}", Address{data: [120,45,22,5,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}), "120.45.22.5"); + assert_eq!(format!("{:?}", Address{data: [120,45,22,5,1,2,0,0,0,0,0,0,0,0,0,0], len: 6}), "78:2d:16:05:01:02"); + assert_eq!(format!("{:?}", Address{data: [3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], len: 8}), "vlan824/78:2d:16:05:01:02"); + assert_eq!(format!("{:?}", Address{data: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], len: 16}), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); } diff --git a/src/udpmessage.rs b/src/udpmessage.rs index c209c9d..77401c3 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -1,9 +1,8 @@ -use std::{mem, ptr, fmt}; +use std::{mem, fmt}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr, SocketAddrV6, Ipv6Addr}; -use std::u16; use super::types::{Error, NetworkId, Range, Address}; -use super::util::{as_obj, as_bytes}; +use super::util::{Encoder, memcopy}; use super::crypto::Crypto; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; @@ -22,6 +21,34 @@ struct TopHeader { msgtype: u8 } +impl TopHeader { + pub fn read_from(data: &[u8]) -> Result<(TopHeader, usize), Error> { + if data.len() < 8 { + return Err(Error::ParseError("Empty message")); + } + let mut header = TopHeader::default(); + for i in 0..3 { + header.magic[i] = data[i]; + } + header.version = data[3]; + header.crypto_method = data[4]; + header.flags = data[6]; + header.msgtype = data[7]; + Ok((header, 8)) + } + + pub fn write_to(&self, data: &mut [u8]) -> usize { + for i in 0..3 { + data[i] = self.magic[i]; + } + data[3] = self.version; + data[4] = self.crypto_method; + data[6] = self.flags; + data[7] = self.msgtype; + 8 + } +} + impl Default for TopHeader { fn default() -> Self { TopHeader{magic: MAGIC, version: VERSION, crypto_method: 0, _reserved: 0, flags: 0, msgtype: 0} @@ -66,12 +93,7 @@ impl<'a> fmt::Debug for Message<'a> { pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, Message<'a>), Error> { let mut end = data.len(); - if end < mem::size_of::() { - return Err(Error::ParseError("Empty message")); - } - let mut pos = 0; - let header = unsafe { as_obj::(&data[pos..]) }.clone(); - pos += mem::size_of::(); + let (header, mut pos) = try!(TopHeader::read_from(&data[..end])); if header.magic != MAGIC { return Err(Error::ParseError("Wrong protocol")); } @@ -99,8 +121,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M if end < pos + NETWORK_ID_BYTES { return Err(Error::ParseError("Truncated options")); } - let id = u64::from_be(*unsafe { as_obj::(&data[pos..pos+NETWORK_ID_BYTES]) }); - options.network_id = Some(id); + options.network_id = Some(Encoder::read_u64(&data[pos..pos+NETWORK_ID_BYTES])); pos += NETWORK_ID_BYTES; } let msg = match header.msgtype { @@ -117,14 +138,10 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M return Err(Error::ParseError("Peer data too short")); } for _ in 0..count { - let (ip, port) = unsafe { - let ip = as_obj::<[u8; 4]>(&data[pos..]); - pos += 4; - let port = *as_obj::(&data[pos..]); - let port = u16::from_be(port); - pos += 2; - (ip, port) - }; + let ip = &data[pos..]; + pos += 4; + let port = Encoder::read_u16(&data[pos..]); + pos += 2; let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), port)); peers.push(addr); } @@ -135,17 +152,15 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M return Err(Error::ParseError("Peer data too short")); } for _ in 0..count { - let (ip, port) = unsafe { - let ip = as_obj::<[u16; 8]>(&data[pos..]); - pos += 16; - let port = *as_obj::(&data[pos..]); - let port = u16::from_be(port); + let mut ip = [0u16; 8]; + for i in 0..8 { + ip[i] = Encoder::read_u16(&data[pos..]); pos += 2; - (ip, port) - }; - let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(u16::from_be(ip[0]), - u16::from_be(ip[1]), u16::from_be(ip[2]), u16::from_be(ip[3]), u16::from_be(ip[4]), - u16::from_be(ip[5]), u16::from_be(ip[6]), u16::from_be(ip[7])), port, 0, 0)); + } + let port = Encoder::read_u16(&data[pos..]); + pos += 2; + let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(ip[0], ip[1], ip[2], + ip[3], ip[4], ip[5], ip[6], ip[7]), port, 0, 0)); peers.push(addr); } Message::Peers(peers) @@ -158,20 +173,9 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M pos += 1; let mut addrs = Vec::with_capacity(count); for _ in 0..count { - if end < pos + 1 { - return Err(Error::ParseError("Init data too short")); - } - let len = data[pos] as usize; - pos += 1; - if end < pos + len + 1 { - return Err(Error::ParseError("Init data too short")); - } - let mut addr = [0; 16]; - unsafe { ptr::copy_nonoverlapping(data[pos..].as_ptr(), addr.as_mut_ptr(), len) }; - let base = Address(addr, len as u8); - let prefix_len = data[pos+len]; - pos += len + 1; - addrs.push(Range{base: base, prefix_len: prefix_len}); + let (range, read) = try!(Range::read_from(&data[pos..end])); + pos += read; + addrs.push(range); } Message::Init(addrs) }, @@ -182,7 +186,6 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M } pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Crypto) -> usize { - assert!(buf.len() >= mem::size_of::()); let mut pos = 0; let mut header = TopHeader::default(); header.msgtype = match msg { @@ -195,22 +198,16 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry if options.network_id.is_some() { header.flags |= 0x01; } - let header_dat = unsafe { as_bytes(&header) }; - unsafe { ptr::copy_nonoverlapping(header_dat.as_ptr(), buf[pos..].as_mut_ptr(), header_dat.len()) }; - pos += header_dat.len(); + pos += header.write_to(&mut buf[pos..]); pos += crypto.nonce_bytes(); if let Some(id) = options.network_id { assert!(buf.len() >= pos + NETWORK_ID_BYTES); - unsafe { - let id_dat = mem::transmute::(id.to_be()); - ptr::copy_nonoverlapping(id_dat.as_ptr(), buf[pos..pos+NETWORK_ID_BYTES].as_mut_ptr(), id_dat.len()); - } + Encoder::write_u64(id, &mut buf[pos..]); pos += NETWORK_ID_BYTES; } match msg { &Message::Data(ref data) => { - assert!(buf.len() >= pos + data.len()); - unsafe { ptr::copy_nonoverlapping(data.as_ptr(), buf[pos..].as_mut_ptr(), data.len()) }; + memcopy(data, &mut buf[pos..]); pos += data.len(); }, &Message::Peers(ref peers) => { @@ -229,30 +226,23 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry pos += 1; for addr in v4addrs { let ip = addr.ip().octets(); - let port = addr.port(); - unsafe { - ptr::copy_nonoverlapping(ip.as_ptr(), buf[pos..].as_mut_ptr(), ip.len()); - pos += ip.len(); - let port = mem::transmute::(port.to_be()); - ptr::copy_nonoverlapping(port.as_ptr(), buf[pos..].as_mut_ptr(), port.len()); - pos += port.len(); + for i in 0..4 { + buf[pos+i] = ip[i]; } + pos += 4; + Encoder::write_u16(addr.port(), &mut buf[pos..]); + pos += 2; }; buf[pos] = v6addrs.len() as u8; pos += 1; for addr in v6addrs { - let mut ip = addr.ip().segments(); - for i in 0..ip.len() { - ip[i] = ip[i].to_be(); - } - let port = addr.port(); - unsafe { - ptr::copy_nonoverlapping(ip.as_ptr() as *const u8, buf[pos..].as_mut_ptr(), 16); - pos += ip.len(); - let port = mem::transmute::(port.to_be()); - ptr::copy_nonoverlapping(port.as_ptr(), buf[pos..].as_mut_ptr(), port.len()); - pos += port.len(); + let ip = addr.ip().segments(); + for i in 0..8 { + Encoder::write_u16(ip[i], &mut buf[pos..]); + pos += 2; } + Encoder::write_u16(addr.port(), &mut buf[pos..]); + pos += 2; }; }, &Message::Init(ref ranges) => { @@ -261,15 +251,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry buf[pos] = ranges.len() as u8; pos += 1; for range in ranges { - let base = &range.base; - let len = base.1 as usize; - assert!(buf.len() >= pos + 1 + len + 1); - buf[pos] = len as u8; - pos += 1; - unsafe { ptr::copy_nonoverlapping(base.0.as_ptr(), buf[pos..].as_mut_ptr(), len) }; - pos += len; - buf[pos] = range.prefix_len; - pos += 1; + pos += range.write_to(&mut buf[pos..]); } }, &Message::Close => { @@ -351,8 +333,8 @@ fn encode_option_network_id() { fn encode_message_init() { let mut options = Options::default(); let mut crypto = Crypto::None; - let addrs = vec![Range{base: Address([0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], 4), prefix_len: 24}, - Range{base: Address([0,1,2,3,4,5,0,0,0,0,0,0,0,0,0,0], 6), prefix_len: 16}]; + let addrs = vec![Range{base: Address{data: [0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}, prefix_len: 24}, + Range{base: Address{data: [0,1,2,3,4,5,0,0,0,0,0,0,0,0,0,0], len: 6}, prefix_len: 16}]; let msg = Message::Init(addrs); let mut buf = [0; 1024]; let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); diff --git a/src/util.rs b/src/util.rs index 9f586c5..a51b894 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,4 @@ -use std::{mem, slice}; +use std::{mem, slice, ptr}; use libc; pub type Duration = u32; @@ -30,6 +30,49 @@ pub unsafe fn as_obj(data: &[u8]) -> &T { mem::transmute(data.as_ptr()) } +#[inline(always)] +pub fn memcopy(src: &[u8], dst: &mut[u8]) { + assert!(dst.len() >= src.len()); + unsafe { ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), src.len()) }; +} + + +pub struct Encoder; + +impl Encoder { + #[inline(always)] + pub fn read_u16(data: &[u8]) -> u16 { + ((data[0] as u16) << 8) | data[1] as u16 + } + + #[inline(always)] + pub fn write_u16(val: u16, data: &mut [u8]) { + data[0] = ((val >> 8) & 0xff) as u8; + data[1] = (val & 0xff) as u8; + } + + #[inline(always)] + pub fn read_u64(data: &[u8]) -> u64 { + ((data[0] as u64) << 56) | ((data[1] as u64) << 48) | + ((data[2] as u64) << 40) | ((data[3] as u64) << 32) | + ((data[4] as u64) << 24) | ((data[5] as u64) << 16) | + ((data[6] as u64) << 8) | data[7] as u64 + } + + #[inline(always)] + pub fn write_u64(val: u64, data: &mut [u8]) { + data[0] = ((val >> 56) & 0xff) as u8; + data[1] = ((val >> 48) & 0xff) as u8; + data[2] = ((val >> 40) & 0xff) as u8; + data[3] = ((val >> 32) & 0xff) as u8; + data[4] = ((val >> 24) & 0xff) as u8; + data[5] = ((val >> 16) & 0xff) as u8; + data[6] = ((val >> 8) & 0xff) as u8; + data[7] = (val & 0xff) as u8; + } +} + + macro_rules! fail { ($format:expr) => ( { use std::process;