diff --git a/src/udpmessage.rs b/src/udpmessage.rs index c94fd17..b3d686b 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -1,4 +1,4 @@ -use std::{mem, ptr, fmt, slice}; +use std::{mem, ptr, fmt}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::u16; @@ -9,6 +9,9 @@ use super::crypto::Crypto; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; pub const VERSION: u8 = 1; +const NETWORK_ID_BYTES: usize = 8; + +#[derive(Clone)] #[repr(packed)] struct TopHeader { magic: [u8; 3], @@ -66,7 +69,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M return Err(Error::ParseError("Empty message")); } let mut pos = 0; - let header = unsafe { as_obj::(&data[pos..]) }; + let header = unsafe { as_obj::(&data[pos..]) }.clone(); pos += mem::size_of::(); if header.magic != MAGIC { return Err(Error::ParseError("Wrong protocol")); @@ -78,25 +81,23 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M return Err(Error::CryptoError("Wrong crypto method")); } if crypto.method() > 0 { - if data.len() < pos + crypto.nonce_bytes() + crypto.auth_bytes() { + let len = crypto.nonce_bytes() + crypto.auth_bytes(); + if data.len() < pos + len { return Err(Error::ParseError("Truncated crypto header")); } - let nonce = &data[pos..pos+crypto.nonce_bytes()]; - pos += crypto.nonce_bytes(); - let hash = &data[pos..pos+crypto.auth_bytes()]; - pos += crypto.auth_bytes(); - // Cheat data mutable to make the borrow checker happy - let data = unsafe { slice::from_raw_parts_mut(mem::transmute(data[pos..].as_ptr()), data.len()-pos) }; - try!(crypto.decrypt(data, nonce, hash)); + let (crypto_header, crypto_data) = data[pos..].split_at_mut(len); + pos += len; + let (nonce, hash) = crypto_header.split_at(crypto.nonce_bytes()); + try!(crypto.decrypt(crypto_data, nonce, hash)); } let mut options = Options::default(); if header.flags & 0x01 > 0 { - if data.len() < pos + 8 { + if data.len() < pos + NETWORK_ID_BYTES { return Err(Error::ParseError("Truncated options")); } - let id = u64::from_be(*unsafe { as_obj::(&data[pos..]) }); + let id = u64::from_be(*unsafe { as_obj::(&data[pos..pos+NETWORK_ID_BYTES]) }); options.network_id = Some(id); - pos += 8; + pos += NETWORK_ID_BYTES; } let msg = match header.msgtype { 0 => Message::Data(&data[pos..]), @@ -181,12 +182,12 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry pos += crypto.auth_bytes(); let crypto_pos = pos; if let Some(id) = options.network_id { - assert!(buf.len() >= pos + 8); + assert!(buf.len() >= pos + NETWORK_ID_BYTES); unsafe { - let id_dat = mem::transmute::(id.to_be()); - ptr::copy_nonoverlapping(id_dat.as_ptr(), buf[pos..].as_mut_ptr(), id_dat.len()); + let id_dat = mem::transmute::(id.to_be()); + ptr::copy_nonoverlapping(id_dat.as_ptr(), buf[pos..pos+NETWORK_ID_BYTES].as_mut_ptr(), id_dat.len()); } - pos += 8; + pos += NETWORK_ID_BYTES; } match msg { &Message::Data(ref data) => {