Removed lots of unsafe blocks

This commit is contained in:
Dennis Schwerdel 2015-11-26 10:45:25 +01:00
parent 0597b617af
commit f933c541f8
6 changed files with 216 additions and 155 deletions

View File

@ -1,7 +1,5 @@
use test::Bencher; use test::Bencher;
use time::Duration;
use std::str::FromStr; use std::str::FromStr;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
@ -39,7 +37,7 @@ fn message_decode(b: &mut Bencher) {
#[bench] #[bench]
fn switch_learn(b: &mut Bencher) { fn switch_learn(b: &mut Bencher) {
let mut table = SwitchTable::new(Duration::seconds(10)); let mut table = SwitchTable::new(10);
let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let addr = Address::from_str("12:34:56:78:90:ab").unwrap();
let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap();
b.iter(|| { b.iter(|| {
@ -49,7 +47,7 @@ fn switch_learn(b: &mut Bencher) {
#[bench] #[bench]
fn switch_lookup(b: &mut Bencher) { fn switch_lookup(b: &mut Bencher) {
let mut table = SwitchTable::new(Duration::seconds(10)); let mut table = SwitchTable::new(10);
let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let addr = Address::from_str("12:34:56:78:90:ab").unwrap();
let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap();
table.learn(addr.clone(), None, peer); table.learn(addr.clone(), None, peer);

View File

@ -1,4 +1,3 @@
use std::ptr;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::collections::HashMap; use std::collections::HashMap;
@ -25,21 +24,17 @@ impl Protocol for Frame {
} }
let mut src = [0; 16]; let mut src = [0; 16];
let mut dst = [0; 16]; let mut dst = [0; 16];
unsafe { src[0] = data[pos]; src[1] = data[pos+1];
ptr::copy_nonoverlapping(data[pos..].as_ptr(), src.as_mut_ptr(), 2); dst[0] = data[pos]; dst[1] = data[pos+1];
ptr::copy_nonoverlapping(src_data.as_ptr(), src[2..].as_mut_ptr(), 6); for i in 0..6 {
ptr::copy_nonoverlapping(data[pos..].as_ptr(), dst.as_mut_ptr(), 2); src[i+2] = src_data[i];
ptr::copy_nonoverlapping(dst_data.as_ptr(), dst[2..].as_mut_ptr(), 6); dst[i+2] = dst_data[i];
} }
Ok((Address(src, 8), Address(dst, 8))) Ok((Address{data: src, len: 8}, Address{data: dst, len: 8}))
} else { } else {
let mut src = [0; 16]; let src = try!(Address::read_from_fixed(&src_data, 6));
let mut dst = [0; 16]; let dst = try!(Address::read_from_fixed(&dst_data, 6));
unsafe { Ok((src, dst))
ptr::copy_nonoverlapping(src_data.as_ptr(), src.as_mut_ptr(), 6);
ptr::copy_nonoverlapping(dst_data.as_ptr(), dst.as_mut_ptr(), 6);
}
Ok((Address(src, 6), Address(dst, 6)))
} }
} }
} }
@ -104,14 +99,14 @@ impl Table for SwitchTable {
fn without_vlan() { fn without_vlan() {
let data = [6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8]; let data = [6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8];
let (src, dst) = Frame::parse(&data).unwrap(); let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address([1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0], 6)); assert_eq!(src, Address{data: [1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0], len: 6});
assert_eq!(dst, Address([6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0], 6)); assert_eq!(dst, Address{data: [6,5,4,3,2,1,0,0,0,0,0,0,0,0,0,0], len: 6});
} }
#[test] #[test]
fn with_vlan() { fn with_vlan() {
let data = [6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8]; let data = [6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8];
let (src, dst) = Frame::parse(&data).unwrap(); let (src, dst) = Frame::parse(&data).unwrap();
assert_eq!(src, Address([4,210,1,2,3,4,5,6,0,0,0,0,0,0,0,0], 8)); assert_eq!(src, Address{data: [4,210,1,2,3,4,5,6,0,0,0,0,0,0,0,0], len: 8});
assert_eq!(dst, Address([4,210,6,5,4,3,2,1,0,0,0,0,0,0,0,0], 8)); assert_eq!(dst, Address{data: [4,210,6,5,4,3,2,1,0,0,0,0,0,0,0,0], len: 8});
} }

View File

@ -1,7 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::collections::{hash_map, HashMap}; use std::collections::{hash_map, HashMap};
use std::io::Read; use std::io::Read;
use std::ptr;
use super::types::{Protocol, Error, Table, Address}; use super::types::{Protocol, Error, Table, Address};
@ -15,28 +14,16 @@ impl Protocol for Packet {
return Err(Error::ParseError("Empty header")); return Err(Error::ParseError("Empty header"));
} }
let version = data[0] >> 4; let version = data[0] >> 4;
let mut src = [0; 16];
let mut dst = [0; 16];
match version { match version {
4 => { 4 => {
if data.len() < 20 { let src = try!(Address::read_from_fixed(&data[12..], 4));
return Err(Error::ParseError("Truncated header")); let dst = try!(Address::read_from_fixed(&data[16..], 4));
} Ok((src, dst))
unsafe {
ptr::copy_nonoverlapping(data[12..].as_ptr(), src.as_mut_ptr(), 4);
ptr::copy_nonoverlapping(data[16..].as_ptr(), dst.as_mut_ptr(), 4);
}
Ok((Address(src, 4), Address(dst, 4)))
}, },
6 => { 6 => {
if data.len() < 40 { let src = try!(Address::read_from_fixed(&data[8..], 16));
return Err(Error::ParseError("Truncated header")); let dst = try!(Address::read_from_fixed(&data[24..], 16));
} Ok((src, dst))
unsafe {
ptr::copy_nonoverlapping(data[8..].as_ptr(), src.as_mut_ptr(), 16);
ptr::copy_nonoverlapping(data[24..].as_ptr(), dst.as_mut_ptr(), 16);
}
Ok((Address(src, 16), Address(dst, 16)))
}, },
_ => Err(Error::ParseError("Invalid version")) _ => Err(Error::ParseError("Invalid version"))
} }
@ -62,15 +49,16 @@ impl Table for RoutingTable {
fn learn(&mut self, addr: Address, prefix_len: Option<u8>, address: SocketAddr) { fn learn(&mut self, addr: Address, prefix_len: Option<u8>, address: SocketAddr) {
let prefix_len = match prefix_len { let prefix_len = match prefix_len {
Some(val) => val, Some(val) => val,
None => addr.0.len() as u8 * 8 None => addr.len * 8
}; };
info!("New routing entry: {:?}/{} => {}", addr, prefix_len, address); info!("New routing entry: {:?}/{} => {}", addr, prefix_len, address);
let group_len = (prefix_len as usize / 16) * 2; let group_len = (prefix_len as usize / 16) * 2;
assert!(group_len <= 16); assert!(group_len <= 16);
let mut group_bytes = [0; 16]; let mut group_bytes = [0; 16];
unsafe { ptr::copy_nonoverlapping(addr.0.as_ptr(), group_bytes.as_mut_ptr(), group_len) }; for i in 0..group_len {
group_bytes[i] = addr.data[i];
let routing_entry = RoutingEntry{address: address, bytes: addr.0, prefix_len: prefix_len}; }
let routing_entry = RoutingEntry{address: address, bytes: addr.data, prefix_len: prefix_len};
match self.0.entry(group_bytes) { match self.0.entry(group_bytes) {
hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(routing_entry), hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(routing_entry),
hash_map::Entry::Vacant(entry) => { entry.insert(vec![routing_entry]); () } hash_map::Entry::Vacant(entry) => { entry.insert(vec![routing_entry]); () }
@ -78,16 +66,13 @@ impl Table for RoutingTable {
} }
fn lookup(&mut self, addr: &Address) -> Option<SocketAddr> { fn lookup(&mut self, addr: &Address) -> Option<SocketAddr> {
let len = addr.0.len()/2 * 2; let len = addr.len as usize/2 * 2;
for i in 0..(len/2)+1 { for i in 0..(len/2)+1 {
if let Some(group) = self.0.get(&addr.0[0..len-2*i]) { if let Some(group) = self.0.get(&addr.data[0..len-2*i]) {
for entry in group { for entry in group {
if entry.bytes.len() != addr.0.len() {
continue;
}
let mut match_len = 0; let mut match_len = 0;
for i in 0..addr.0.len() { for j in 0..addr.len as usize {
let b = addr.0[i] ^ entry.bytes[i]; let b = addr.data[j] ^ entry.bytes[j];
if b == 0 { if b == 0 {
match_len += 8; match_len += 8;
} else { } else {

View File

@ -8,15 +8,53 @@ use super::util::{as_bytes, as_obj};
pub type NetworkId = u64; pub type NetworkId = u64;
#[derive(PartialOrd, Eq, Ord, Clone, Hash)] #[derive(PartialOrd, Eq, Ord, Clone, Hash)]
pub struct Address(pub [u8; 16], pub u8); pub struct Address {
pub data: [u8; 16],
pub len: u8
}
impl Address {
#[inline]
pub fn read_from(data: &[u8]) -> Result<(Address, usize), Error> {
if data.len() < 1 {
return Err(Error::ParseError("Address too short"));
}
let len = data[0] as usize;
let addr = try!(Address::read_from_fixed(&data[1..], len));
Ok((addr, len + 1))
}
#[inline]
pub fn read_from_fixed(data: &[u8], len: usize) -> Result<Address, Error> {
if len > 16 {
return Err(Error::ParseError("Invalid address, too long"));
}
if data.len() < len {
return Err(Error::ParseError("Address too short"));
}
let mut bytes = [0; 16];
unsafe { ptr::copy_nonoverlapping(data.as_ptr(), bytes.as_mut_ptr(), len) };
Ok(Address{data: bytes, len: len as u8})
}
#[inline]
pub fn write_to(&self, data: &mut[u8]) -> usize {
assert!(data.len() >= self.len as usize + 1);
data[0] = self.len;
unsafe { ptr::copy_nonoverlapping(self.data.as_ptr(), data[1..].as_mut_ptr(), self.len as usize) };
self.len as usize + 1
}
}
impl PartialEq for Address { impl PartialEq for Address {
#[inline]
fn eq(&self, rhs: &Self) -> bool { fn eq(&self, rhs: &Self) -> bool {
if self.1 != rhs.1 { if self.len != rhs.len {
return false; return false;
} }
for i in 0..self.1 as usize { for i in 0..self.len as usize {
if self.0[i] != rhs.0[i] { if self.data[i] != rhs.data[i] {
return false; return false;
} }
} }
@ -27,20 +65,19 @@ impl PartialEq for Address {
impl fmt::Debug for Address { impl fmt::Debug for Address {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match self.1 { let d = &self.data;
4 => write!(formatter, "{}.{}.{}.{}", self.0[0], self.0[1], self.0[2], self.0[3]), match self.len {
4 => write!(formatter, "{}.{}.{}.{}", d[0], d[1], d[2], d[3]),
6 => write!(formatter, "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", 6 => write!(formatter, "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5]), d[0], d[1], d[2], d[3], d[4], d[5]),
8 => { 8 => {
let vlan = u16::from_be( *unsafe { as_obj(&self.0[0..2]) }); let vlan = u16::from_be( *unsafe { as_obj(&d[0..2]) });
write!(formatter, "vlan{}/{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", write!(formatter, "vlan{}/{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
vlan, self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7]) vlan, d[2], d[3], d[4], d[5], d[6], d[7])
}, },
16 => write!(formatter, "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}", 16 => write!(formatter, "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7], d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15]),
self.0[8], self.0[9], self.0[10], self.0[11], self.0[12], self.0[13], self.0[14], self.0[15] _ => d.fmt(formatter)
),
_ => self.0.fmt(formatter)
} }
} }
} }
@ -55,7 +92,7 @@ impl FromStr for Address {
unsafe { unsafe {
ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len()); ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len());
} }
return Ok(Address(res, 4)); return Ok(Address{data: res, len: 4});
} }
if let Ok(addr) = Ipv6Addr::from_str(text) { if let Ok(addr) = Ipv6Addr::from_str(text) {
let mut segments = addr.segments(); let mut segments = addr.segments();
@ -67,7 +104,7 @@ impl FromStr for Address {
unsafe { unsafe {
ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len());
} }
return Ok(Address(res, 16)); return Ok(Address{data: res, len: 16});
} }
let parts: Vec<&str> = text.split(':').collect(); let parts: Vec<&str> = text.split(':').collect();
if parts.len() == 6 { if parts.len() == 6 {
@ -75,7 +112,7 @@ impl FromStr for Address {
for i in 0..6 { for i in 0..6 {
bytes[i] = try!(u8::from_str_radix(&parts[i], 16).map_err(|_| Error::ParseError("Failed to parse mac"))); bytes[i] = try!(u8::from_str_radix(&parts[i], 16).map_err(|_| Error::ParseError("Failed to parse mac")));
} }
return Ok(Address(bytes, 6)); return Ok(Address{data: bytes, len: 6});
} }
return Err(Error::ParseError("Failed to parse address")) return Err(Error::ParseError("Failed to parse address"))
} }
@ -88,6 +125,26 @@ pub struct Range {
pub prefix_len: u8 pub prefix_len: u8
} }
impl Range {
#[inline]
pub fn read_from(data: &[u8]) -> Result<(Range, usize), Error> {
let (address, read) = try!(Address::read_from(data));
if data.len() < read + 1 {
return Err(Error::ParseError("Range too short"));
}
let prefix_len = data[read];
Ok((Range{base: address, prefix_len: prefix_len}, read + 1))
}
#[inline]
pub fn write_to(&self, data: &mut[u8]) -> usize {
let pos = self.base.write_to(data);
assert!(data.len() >= pos + 1);
data[pos] = self.prefix_len;
pos + 1
}
}
impl FromStr for Range { impl FromStr for Range {
type Err=Error; type Err=Error;
@ -165,10 +222,11 @@ impl fmt::Display for Error {
} }
} }
#[test] #[test]
fn address_fmt() { fn address_fmt() {
assert_eq!(format!("{:?}", Address([120,45,22,5,0,0,0,0,0,0,0,0,0,0,0,0], 4)), "120.45.22.5"); assert_eq!(format!("{:?}", Address{data: [120,45,22,5,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}), "120.45.22.5");
assert_eq!(format!("{:?}", Address([120,45,22,5,1,2,0,0,0,0,0,0,0,0,0,0], 6)), "78:2d:16:05:01:02"); assert_eq!(format!("{:?}", Address{data: [120,45,22,5,1,2,0,0,0,0,0,0,0,0,0,0], len: 6}), "78:2d:16:05:01:02");
assert_eq!(format!("{:?}", Address([3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], 8)), "vlan824/78:2d:16:05:01:02"); assert_eq!(format!("{:?}", Address{data: [3,56,120,45,22,5,1,2,0,0,0,0,0,0,0,0], len: 8}), "vlan824/78:2d:16:05:01:02");
assert_eq!(format!("{:?}", Address([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], 16)), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f"); assert_eq!(format!("{:?}", Address{data: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], len: 16}), "0001:0203:0405:0607:0809:0a0b:0c0d:0e0f");
} }

View File

@ -1,9 +1,8 @@
use std::{mem, ptr, fmt}; use std::{mem, fmt};
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr, SocketAddrV6, Ipv6Addr}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr, SocketAddrV6, Ipv6Addr};
use std::u16;
use super::types::{Error, NetworkId, Range, Address}; use super::types::{Error, NetworkId, Range, Address};
use super::util::{as_obj, as_bytes}; use super::util::{Encoder, memcopy};
use super::crypto::Crypto; use super::crypto::Crypto;
const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e];
@ -22,6 +21,34 @@ struct TopHeader {
msgtype: u8 msgtype: u8
} }
impl TopHeader {
pub fn read_from(data: &[u8]) -> Result<(TopHeader, usize), Error> {
if data.len() < 8 {
return Err(Error::ParseError("Empty message"));
}
let mut header = TopHeader::default();
for i in 0..3 {
header.magic[i] = data[i];
}
header.version = data[3];
header.crypto_method = data[4];
header.flags = data[6];
header.msgtype = data[7];
Ok((header, 8))
}
pub fn write_to(&self, data: &mut [u8]) -> usize {
for i in 0..3 {
data[i] = self.magic[i];
}
data[3] = self.version;
data[4] = self.crypto_method;
data[6] = self.flags;
data[7] = self.msgtype;
8
}
}
impl Default for TopHeader { impl Default for TopHeader {
fn default() -> Self { fn default() -> Self {
TopHeader{magic: MAGIC, version: VERSION, crypto_method: 0, _reserved: 0, flags: 0, msgtype: 0} TopHeader{magic: MAGIC, version: VERSION, crypto_method: 0, _reserved: 0, flags: 0, msgtype: 0}
@ -66,12 +93,7 @@ impl<'a> fmt::Debug for Message<'a> {
pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, Message<'a>), Error> { pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, Message<'a>), Error> {
let mut end = data.len(); let mut end = data.len();
if end < mem::size_of::<TopHeader>() { let (header, mut pos) = try!(TopHeader::read_from(&data[..end]));
return Err(Error::ParseError("Empty message"));
}
let mut pos = 0;
let header = unsafe { as_obj::<TopHeader>(&data[pos..]) }.clone();
pos += mem::size_of::<TopHeader>();
if header.magic != MAGIC { if header.magic != MAGIC {
return Err(Error::ParseError("Wrong protocol")); return Err(Error::ParseError("Wrong protocol"));
} }
@ -99,8 +121,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M
if end < pos + NETWORK_ID_BYTES { if end < pos + NETWORK_ID_BYTES {
return Err(Error::ParseError("Truncated options")); return Err(Error::ParseError("Truncated options"));
} }
let id = u64::from_be(*unsafe { as_obj::<u64>(&data[pos..pos+NETWORK_ID_BYTES]) }); options.network_id = Some(Encoder::read_u64(&data[pos..pos+NETWORK_ID_BYTES]));
options.network_id = Some(id);
pos += NETWORK_ID_BYTES; pos += NETWORK_ID_BYTES;
} }
let msg = match header.msgtype { let msg = match header.msgtype {
@ -117,14 +138,10 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M
return Err(Error::ParseError("Peer data too short")); return Err(Error::ParseError("Peer data too short"));
} }
for _ in 0..count { for _ in 0..count {
let (ip, port) = unsafe { let ip = &data[pos..];
let ip = as_obj::<[u8; 4]>(&data[pos..]);
pos += 4; pos += 4;
let port = *as_obj::<u16>(&data[pos..]); let port = Encoder::read_u16(&data[pos..]);
let port = u16::from_be(port);
pos += 2; pos += 2;
(ip, port)
};
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), port)); let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), port));
peers.push(addr); peers.push(addr);
} }
@ -135,17 +152,15 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M
return Err(Error::ParseError("Peer data too short")); return Err(Error::ParseError("Peer data too short"));
} }
for _ in 0..count { for _ in 0..count {
let (ip, port) = unsafe { let mut ip = [0u16; 8];
let ip = as_obj::<[u16; 8]>(&data[pos..]); for i in 0..8 {
pos += 16; ip[i] = Encoder::read_u16(&data[pos..]);
let port = *as_obj::<u16>(&data[pos..]);
let port = u16::from_be(port);
pos += 2; pos += 2;
(ip, port) }
}; let port = Encoder::read_u16(&data[pos..]);
let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(u16::from_be(ip[0]), pos += 2;
u16::from_be(ip[1]), u16::from_be(ip[2]), u16::from_be(ip[3]), u16::from_be(ip[4]), let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(ip[0], ip[1], ip[2],
u16::from_be(ip[5]), u16::from_be(ip[6]), u16::from_be(ip[7])), port, 0, 0)); ip[3], ip[4], ip[5], ip[6], ip[7]), port, 0, 0));
peers.push(addr); peers.push(addr);
} }
Message::Peers(peers) Message::Peers(peers)
@ -158,20 +173,9 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M
pos += 1; pos += 1;
let mut addrs = Vec::with_capacity(count); let mut addrs = Vec::with_capacity(count);
for _ in 0..count { for _ in 0..count {
if end < pos + 1 { let (range, read) = try!(Range::read_from(&data[pos..end]));
return Err(Error::ParseError("Init data too short")); pos += read;
} addrs.push(range);
let len = data[pos] as usize;
pos += 1;
if end < pos + len + 1 {
return Err(Error::ParseError("Init data too short"));
}
let mut addr = [0; 16];
unsafe { ptr::copy_nonoverlapping(data[pos..].as_ptr(), addr.as_mut_ptr(), len) };
let base = Address(addr, len as u8);
let prefix_len = data[pos+len];
pos += len + 1;
addrs.push(Range{base: base, prefix_len: prefix_len});
} }
Message::Init(addrs) Message::Init(addrs)
}, },
@ -182,7 +186,6 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M
} }
pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Crypto) -> usize { pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Crypto) -> usize {
assert!(buf.len() >= mem::size_of::<TopHeader>());
let mut pos = 0; let mut pos = 0;
let mut header = TopHeader::default(); let mut header = TopHeader::default();
header.msgtype = match msg { header.msgtype = match msg {
@ -195,22 +198,16 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry
if options.network_id.is_some() { if options.network_id.is_some() {
header.flags |= 0x01; header.flags |= 0x01;
} }
let header_dat = unsafe { as_bytes(&header) }; pos += header.write_to(&mut buf[pos..]);
unsafe { ptr::copy_nonoverlapping(header_dat.as_ptr(), buf[pos..].as_mut_ptr(), header_dat.len()) };
pos += header_dat.len();
pos += crypto.nonce_bytes(); pos += crypto.nonce_bytes();
if let Some(id) = options.network_id { if let Some(id) = options.network_id {
assert!(buf.len() >= pos + NETWORK_ID_BYTES); assert!(buf.len() >= pos + NETWORK_ID_BYTES);
unsafe { Encoder::write_u64(id, &mut buf[pos..]);
let id_dat = mem::transmute::<u64, [u8; NETWORK_ID_BYTES]>(id.to_be());
ptr::copy_nonoverlapping(id_dat.as_ptr(), buf[pos..pos+NETWORK_ID_BYTES].as_mut_ptr(), id_dat.len());
}
pos += NETWORK_ID_BYTES; pos += NETWORK_ID_BYTES;
} }
match msg { match msg {
&Message::Data(ref data) => { &Message::Data(ref data) => {
assert!(buf.len() >= pos + data.len()); memcopy(data, &mut buf[pos..]);
unsafe { ptr::copy_nonoverlapping(data.as_ptr(), buf[pos..].as_mut_ptr(), data.len()) };
pos += data.len(); pos += data.len();
}, },
&Message::Peers(ref peers) => { &Message::Peers(ref peers) => {
@ -229,30 +226,23 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry
pos += 1; pos += 1;
for addr in v4addrs { for addr in v4addrs {
let ip = addr.ip().octets(); let ip = addr.ip().octets();
let port = addr.port(); for i in 0..4 {
unsafe { buf[pos+i] = ip[i];
ptr::copy_nonoverlapping(ip.as_ptr(), buf[pos..].as_mut_ptr(), ip.len());
pos += ip.len();
let port = mem::transmute::<u16, [u8; 2]>(port.to_be());
ptr::copy_nonoverlapping(port.as_ptr(), buf[pos..].as_mut_ptr(), port.len());
pos += port.len();
} }
pos += 4;
Encoder::write_u16(addr.port(), &mut buf[pos..]);
pos += 2;
}; };
buf[pos] = v6addrs.len() as u8; buf[pos] = v6addrs.len() as u8;
pos += 1; pos += 1;
for addr in v6addrs { for addr in v6addrs {
let mut ip = addr.ip().segments(); let ip = addr.ip().segments();
for i in 0..ip.len() { for i in 0..8 {
ip[i] = ip[i].to_be(); Encoder::write_u16(ip[i], &mut buf[pos..]);
} pos += 2;
let port = addr.port();
unsafe {
ptr::copy_nonoverlapping(ip.as_ptr() as *const u8, buf[pos..].as_mut_ptr(), 16);
pos += ip.len();
let port = mem::transmute::<u16, [u8; 2]>(port.to_be());
ptr::copy_nonoverlapping(port.as_ptr(), buf[pos..].as_mut_ptr(), port.len());
pos += port.len();
} }
Encoder::write_u16(addr.port(), &mut buf[pos..]);
pos += 2;
}; };
}, },
&Message::Init(ref ranges) => { &Message::Init(ref ranges) => {
@ -261,15 +251,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry
buf[pos] = ranges.len() as u8; buf[pos] = ranges.len() as u8;
pos += 1; pos += 1;
for range in ranges { for range in ranges {
let base = &range.base; pos += range.write_to(&mut buf[pos..]);
let len = base.1 as usize;
assert!(buf.len() >= pos + 1 + len + 1);
buf[pos] = len as u8;
pos += 1;
unsafe { ptr::copy_nonoverlapping(base.0.as_ptr(), buf[pos..].as_mut_ptr(), len) };
pos += len;
buf[pos] = range.prefix_len;
pos += 1;
} }
}, },
&Message::Close => { &Message::Close => {
@ -351,8 +333,8 @@ fn encode_option_network_id() {
fn encode_message_init() { fn encode_message_init() {
let mut options = Options::default(); let mut options = Options::default();
let mut crypto = Crypto::None; let mut crypto = Crypto::None;
let addrs = vec![Range{base: Address([0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], 4), prefix_len: 24}, let addrs = vec![Range{base: Address{data: [0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}, prefix_len: 24},
Range{base: Address([0,1,2,3,4,5,0,0,0,0,0,0,0,0,0,0], 6), prefix_len: 16}]; Range{base: Address{data: [0,1,2,3,4,5,0,0,0,0,0,0,0,0,0,0], len: 6}, prefix_len: 16}];
let msg = Message::Init(addrs); let msg = Message::Init(addrs);
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); let size = encode(&mut options, &msg, &mut buf[..], &mut crypto);

View File

@ -1,4 +1,4 @@
use std::{mem, slice}; use std::{mem, slice, ptr};
use libc; use libc;
pub type Duration = u32; pub type Duration = u32;
@ -30,6 +30,49 @@ pub unsafe fn as_obj<T>(data: &[u8]) -> &T {
mem::transmute(data.as_ptr()) mem::transmute(data.as_ptr())
} }
#[inline(always)]
pub fn memcopy(src: &[u8], dst: &mut[u8]) {
assert!(dst.len() >= src.len());
unsafe { ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), src.len()) };
}
pub struct Encoder;
impl Encoder {
#[inline(always)]
pub fn read_u16(data: &[u8]) -> u16 {
((data[0] as u16) << 8) | data[1] as u16
}
#[inline(always)]
pub fn write_u16(val: u16, data: &mut [u8]) {
data[0] = ((val >> 8) & 0xff) as u8;
data[1] = (val & 0xff) as u8;
}
#[inline(always)]
pub fn read_u64(data: &[u8]) -> u64 {
((data[0] as u64) << 56) | ((data[1] as u64) << 48) |
((data[2] as u64) << 40) | ((data[3] as u64) << 32) |
((data[4] as u64) << 24) | ((data[5] as u64) << 16) |
((data[6] as u64) << 8) | data[7] as u64
}
#[inline(always)]
pub fn write_u64(val: u64, data: &mut [u8]) {
data[0] = ((val >> 56) & 0xff) as u8;
data[1] = ((val >> 48) & 0xff) as u8;
data[2] = ((val >> 40) & 0xff) as u8;
data[3] = ((val >> 32) & 0xff) as u8;
data[4] = ((val >> 24) & 0xff) as u8;
data[5] = ((val >> 16) & 0xff) as u8;
data[6] = ((val >> 8) & 0xff) as u8;
data[7] = (val & 0xff) as u8;
}
}
macro_rules! fail { macro_rules! fail {
($format:expr) => ( { ($format:expr) => ( {
use std::process; use std::process;