From 8a5b2cc5ee6cfd647973d279328c0a2ba43ef911 Mon Sep 17 00:00:00 2001 From: Dennis Schwerdel Date: Sun, 13 Dec 2015 22:03:06 +0100 Subject: [PATCH] Removed last payload memcopy --- CHANGELOG.md | 1 + src/cloud.rs | 41 ++++++------- src/tests.rs | 142 ++++++++++++++++++++++++++++++++-------------- src/udpmessage.rs | 91 ++++++++++++++++------------- src/util.rs | 7 --- 5 files changed, 172 insertions(+), 110 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1516226..1747d7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ This project follows [semantic versioning](http://semver.org). ### UNRELEASED +- [changed] Removed last payload memcopy - [changed] Using RNG to select peers for peers list exchange - [changed] Updated dependency versions - [changed] Updated documentation diff --git a/src/cloud.rs b/src/cloud.rs index b3f921c..b2d39c0 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -135,11 +135,11 @@ impl GenericCloud

{ self.device.ifname() } - fn send_msg(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { + fn send_msg(&mut self, addr: Addr, msg: &mut Message) -> Result<(), Error> { debug!("Sending {:?} to {}", msg, addr); - let size = encode(&mut self.options, msg, &mut self.buffer_out, &mut self.crypto); - match self.socket.send_to(&self.buffer_out[..size], addr) { - Ok(written) if written == size => Ok(()), + let msg_data = encode(&mut self.options, msg, &mut self.buffer_out, &mut self.crypto); + match self.socket.send_to(msg_data, addr) { + Ok(written) if written == msg_data.len() => Ok(()), Ok(_) => Err(Error::SocketError("Sent out truncated packet")), Err(e) => { error!("Failed to send via network {:?}", e); @@ -166,7 +166,7 @@ impl GenericCloud

{ } let addrs = self.addresses.clone(); let node_id = self.node_id.clone(); - self.send_msg(addr, &Message::Init(0, node_id, addrs)) + self.send_msg(addr, &mut Message::Init(0, node_id, addrs)) } fn housekeep(&mut self) -> Result<(), Error> { @@ -182,9 +182,9 @@ impl GenericCloud

{ } } let peers = self.peers.subset(peer_num); - let msg = Message::Peers(peers); + let mut msg = Message::Peers(peers); for addr in &self.peers.as_vec() { - try!(self.send_msg(addr, &msg)); + try!(self.send_msg(addr, &mut msg)); } self.next_peerlist = now() + self.update_freq as Time; } @@ -194,14 +194,14 @@ impl GenericCloud

{ Ok(()) } - fn handle_interface_data(&mut self, payload: &[u8]) -> Result<(), Error> { + fn handle_interface_data(&mut self, payload: &mut [u8], start: usize, end: usize) -> Result<(), Error> { let (src, dst) = try!(P::parse(payload)); debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, payload.len()); match self.table.lookup(&dst) { Some(addr) => { debug!("Found destination for {} => {}", dst, addr); if self.peers.contains(&addr) { - try!(self.send_msg(addr, &Message::Data(payload))) + try!(self.send_msg(addr, &mut Message::Data(payload, start, end))) } else { warn!("Destination for {} not found in peers: {}", dst, addr); } @@ -212,9 +212,9 @@ impl GenericCloud

{ return Ok(()); } debug!("No destination for {} found, broadcasting", dst); - let msg = Message::Data(payload); + let mut msg = Message::Data(payload, start, end); for addr in &self.peers.as_vec() { - try!(self.send_msg(addr, &msg)); + try!(self.send_msg(addr, &mut msg)); } } } @@ -228,10 +228,10 @@ impl GenericCloud

{ } debug!("Received {:?} from {}", msg, peer); match msg { - Message::Data(payload) => { - let (src, _dst) = try!(P::parse(payload)); - debug!("Writing data to device: {} bytes", payload.len()); - match self.device.write(&payload) { + Message::Data(payload, start, end) => { + let (src, _dst) = try!(P::parse(&payload[start..end])); + debug!("Writing data to device: {} bytes", end-start); + match self.device.write(&payload[start..end]) { Ok(()) => (), Err(e) => { error!("Failed to send via device: {}", e); @@ -265,8 +265,8 @@ impl GenericCloud

{ let peers = self.peers.as_vec(); let own_addrs = self.addresses.clone(); let own_node_id = self.node_id.clone(); - try!(self.send_msg(peer, &Message::Init(stage+1, own_node_id, own_addrs))); - try!(self.send_msg(peer, &Message::Peers(peers))); + try!(self.send_msg(peer, &mut Message::Init(stage+1, own_node_id, own_addrs))); + try!(self.send_msg(peer, &mut Message::Peers(peers))); } }, Message::Close => { @@ -301,8 +301,9 @@ impl GenericCloud

{ } }, &1 => { - let size = try_fail!(self.device.read(&mut buffer), "Failed to read from tap device: {}"); - match self.handle_interface_data(&buffer[..size]) { + let start = 64; + let size = try_fail!(self.device.read(&mut buffer[start..]), "Failed to read from tap device: {}"); + match self.handle_interface_data(&mut buffer, start, start+size) { Ok(_) => (), Err(e) => error!("Error: {}", e) } @@ -325,7 +326,7 @@ impl GenericCloud

{ } info!("Shutting down..."); for p in &self.peers.as_vec() { - let _ = self.send_msg(p, &Message::Close); + let _ = self.send_msg(p, &mut Message::Close); } } } diff --git a/src/tests.rs b/src/tests.rs index 2b055ac..2adf857 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -8,34 +8,82 @@ use super::udpmessage::{Options, Message, decode, encode}; use super::crypto::{Crypto, CryptoMethod}; +impl<'a> PartialEq for Message<'a> { + fn eq(&self, other: &Message) -> bool { + match self { + &Message::Data(ref data1, start1, end1) => if let &Message::Data(ref data2, start2, end2) = other { + data1[start1..end1] == data2[start2..end2] + } else { false }, + &Message::Peers(ref peers1) => if let &Message::Peers(ref peers2) = other { + peers1 == peers2 + } else { false }, + &Message::Init(step1, node_id1, ref ranges1) => if let &Message::Init(step2, node_id2, ref ranges2) = other { + step1 == step2 && node_id1 == node_id2 && ranges1 == ranges2 + } else { false }, + &Message::Close => if let &Message::Close = other { + true + } else { false } + } + } +} + #[test] +#[allow(unused_assignments)] fn udpmessage_packet() { let mut options = Options::default(); let mut crypto = Crypto::None; - let payload = [1,2,3,4,5]; - let msg = Message::Data(&payload); + let mut payload = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 1,2,3,4,5, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]; + let mut msg = Message::Data(&mut payload, 64, 69); let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 13); - assert_eq!(&buf[..8], &[118,112,110,1,0,0,0,0]); - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let mut len = 0; + { + let res = encode(&mut options, &mut msg, &mut [], &mut crypto); + assert_eq!(res.len(), 13); + assert_eq!(&res[..8], &[118,112,110,1,0,0,0,0]); + for i in 0..res.len() { + buf[i] = res[i]; + } + len = res.len(); + } + let (options2, msg2) = decode(&mut buf[..len], &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); } #[test] +#[allow(unused_assignments)] fn udpmessage_encrypted() { let mut options = Options::default(); let mut crypto = Crypto::from_shared_key(CryptoMethod::ChaCha20, "test"); - let payload = [1,2,3,4,5]; - let msg = Message::Data(&payload); + let mut payload = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 1,2,3,4,5, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]; + let mut orig_payload = [0; 133]; + for i in 0..payload.len() { + orig_payload[i] = payload[i]; + } + let orig_msg = Message::Data(&mut orig_payload, 64, 69); + let mut msg = Message::Data(&mut payload, 64, 69); let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 41); - assert_eq!(&buf[..8], &[118,112,110,1,1,0,0,0]); - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let mut len = 0; + { + let res = encode(&mut options, &mut msg, &mut [], &mut crypto); + assert_eq!(res.len(), 41); + assert_eq!(&res[..8], &[118,112,110,1,1,0,0,0]); + for i in 0..res.len() { + buf[i] = res[i]; + } + len = res.len(); + } + let (options2, msg2) = decode(&mut buf[..len], &mut crypto).unwrap(); assert_eq!(options, options2); - assert_eq!(msg, msg2); + assert_eq!(orig_msg, msg2); } #[test] @@ -43,16 +91,18 @@ fn udpmessage_peers() { use std::str::FromStr; let mut options = Options::default(); let mut crypto = Crypto::None; - let msg = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap(), SocketAddr::from_str("[0001:0203:0405:0607:0809:0a0b:0c0d:0e0f]:6789").unwrap()]); - let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 40); - let should = [118,112,110,1,0,0,0,1,2,1,2,3,4,0,123,5,6,7,8,48,57,1,0,1,2,3,4,5,6,7, + let mut msg = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap(), SocketAddr::from_str("[0001:0203:0405:0607:0809:0a0b:0c0d:0e0f]:6789").unwrap()]); + let mut should = [118,112,110,1,0,0,0,1,2,1,2,3,4,0,123,5,6,7,8,48,57,1,0,1,2,3,4,5,6,7, 8,9,10,11,12,13,14,15,26,133]; - for i in 0..size { - assert_eq!(buf[i], should[i]); + { + let mut buf = [0; 1024]; + let res = encode(&mut options, &mut msg, &mut buf[..], &mut crypto); + assert_eq!(res.len(), 40); + for i in 0..res.len() { + assert_eq!(res[i], should[i]); + } } - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let (options2, msg2) = decode(&mut should, &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); } @@ -62,12 +112,15 @@ fn udpmessage_option_network_id() { let mut options = Options::default(); options.network_id = Some(134); let mut crypto = Crypto::None; - let msg = Message::Close; - let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 16); - assert_eq!(&buf[..size], &[118,112,110,1,0,0,1,3,0,0,0,0,0,0,0,134]); - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let mut msg = Message::Close; + let mut should = [118,112,110,1,0,0,1,3,0,0,0,0,0,0,0,134]; + { + let mut buf = [0; 1024]; + let res = encode(&mut options, &mut msg, &mut buf[..], &mut crypto); + assert_eq!(res.len(), 16); + assert_eq!(&res, &should); + } + let (options2, msg2) = decode(&mut should, &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); } @@ -80,15 +133,17 @@ fn udpmessage_init() { let addrs = vec![Range{base: Address{data: [0,1,2,3,0,0,0,0,0,0,0,0,0,0,0,0], len: 4}, prefix_len: 24}, Range{base: Address{data: [0,1,2,3,4,5,0,0,0,0,0,0,0,0,0,0], len: 6}, prefix_len: 16}]; let node_id = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]; - let msg = Message::Init(0, node_id, addrs); - let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 40); - let should = [118,112,110,1,0,0,0,2,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,2,4,0,1,2,3,24,6,0,1,2,3,4,5,16]; - for i in 0..size { - assert_eq!(buf[i], should[i]); + let mut msg = Message::Init(0, node_id, addrs); + let mut should = [118,112,110,1,0,0,0,2,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,2,4,0,1,2,3,24,6,0,1,2,3,4,5,16]; + { + let mut buf = [0; 1024]; + let res = encode(&mut options, &mut msg, &mut buf[..], &mut crypto); + assert_eq!(res.len(), 40); + for i in 0..res.len() { + assert_eq!(res[i], should[i]); + } } - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let (options2, msg2) = decode(&mut should, &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); } @@ -97,12 +152,15 @@ fn udpmessage_init() { fn udpmessage_close() { let mut options = Options::default(); let mut crypto = Crypto::None; - let msg = Message::Close; - let mut buf = [0; 1024]; - let size = encode(&mut options, &msg, &mut buf[..], &mut crypto); - assert_eq!(size, 8); - assert_eq!(&buf[..size], &[118,112,110,1,0,0,0,3]); - let (options2, msg2) = decode(&mut buf[..size], &mut crypto).unwrap(); + let mut msg = Message::Close; + let mut should = [118,112,110,1,0,0,0,3]; + { + let mut buf = [0; 1024]; + let res = encode(&mut options, &mut msg, &mut buf[..], &mut crypto); + assert_eq!(res.len(), 8); + assert_eq!(&res, &should); + } + let (options2, msg2) = decode(&mut should, &mut crypto).unwrap(); assert_eq!(options, options2); assert_eq!(msg, msg2); } @@ -242,7 +300,7 @@ fn address_parse_fmt() { #[test] fn message_fmt() { - assert_eq!(format!("{:?}", Message::Data(&[1,2,3,4,5])), "Data(5 bytes)"); + assert_eq!(format!("{:?}", Message::Data(&mut [1,2,3,4,5], 0, 5)), "Data(5 bytes)"); assert_eq!(format!("{:?}", Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap(), SocketAddr::from_str("[0001:0203:0405:0607:0809:0a0b:0c0d:0e0f]:6789").unwrap()])), diff --git a/src/udpmessage.rs b/src/udpmessage.rs index 311f706..96cc146 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -2,7 +2,7 @@ use std::fmt; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr, SocketAddrV6, Ipv6Addr}; use super::types::{NodeId, Error, NetworkId, Range, NODE_ID_BYTES}; -use super::util::{bytes_to_hex, Encoder, memcopy}; +use super::util::{bytes_to_hex, Encoder}; use super::crypto::Crypto; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; @@ -65,19 +65,17 @@ pub struct Options { pub network_id: Option, } - -#[derive(PartialEq)] pub enum Message<'a> { - Data(&'a[u8]), - Peers(Vec), - Init(u8, NodeId, Vec), + Data(&'a mut[u8], usize, usize), // data, start, end + Peers(Vec), // peers + Init(u8, NodeId, Vec), // step, node_id, ranges Close, } impl<'a> fmt::Debug for Message<'a> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self { - &Message::Data(ref data) => write!(formatter, "Data({} bytes)", data.len()), + &Message::Data(_, start, end) => write!(formatter, "Data({} bytes)", end-start), &Message::Peers(ref peers) => { try!(write!(formatter, "Peers [")); let mut first = true; @@ -130,7 +128,7 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M pos += NETWORK_ID_BYTES; } let msg = match header.msgtype { - 0 => Message::Data(&data[pos..end]), + 0 => Message::Data(data, pos, end), 1 => { if end < pos + 1 { return Err(Error::ParseError("Empty peers")); @@ -197,32 +195,16 @@ pub fn decode<'a>(data: &'a mut [u8], crypto: &mut Crypto) -> Result<(Options, M Ok((options, msg)) } -pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Crypto) -> usize { - let mut pos = 0; - let mut header = TopHeader::default(); - header.msgtype = match msg { - &Message::Data(_) => 0, - &Message::Peers(_) => 1, - &Message::Init(_, _, _) => 2, - &Message::Close => 3 - }; - header.crypto_method = crypto.method(); - if options.network_id.is_some() { - header.flags |= 0x01; - } - pos += header.write_to(&mut buf[pos..]); - pos += crypto.nonce_bytes(); - if let Some(id) = options.network_id { - assert!(buf.len() >= pos + NETWORK_ID_BYTES); - Encoder::write_u64(id, &mut buf[pos..]); - pos += NETWORK_ID_BYTES; - } +pub fn encode<'a>(options: &Options, msg: &'a mut Message, mut buf: &'a mut [u8], crypto: &mut Crypto) -> &'a mut [u8] { + let mut start = 64; + let mut end = 64; match msg { - &Message::Data(ref data) => { - memcopy(data, &mut buf[pos..]); - pos += data.len(); + &mut Message::Data(ref mut data, data_start, data_end) => { + buf = data; + start = data_start; + end = data_end; }, - &Message::Peers(ref peers) => { + &mut Message::Peers(ref peers) => { let mut v4addrs = Vec::new(); let mut v6addrs = Vec::new(); for p in peers { @@ -233,6 +215,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry }; assert!(v4addrs.len() <= 255); assert!(v6addrs.len() <= 255); + let mut pos = start; assert!(buf.len() >= pos + 2 + v4addrs.len() * 6 + v6addrs.len() * 18); buf[pos] = v4addrs.len() as u8; pos += 1; @@ -256,8 +239,10 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry Encoder::write_u16(addr.port(), &mut buf[pos..]); pos += 2; }; + end = pos; }, - &Message::Init(stage, ref node_id, ref ranges) => { + &mut Message::Init(stage, ref node_id, ref ranges) => { + let mut pos = start; assert!(buf.len() >= pos + 2 + NODE_ID_BYTES); buf[pos] = stage; pos += 1; @@ -271,16 +256,40 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8], crypto: &mut Cry for range in ranges { pos += range.write_to(&mut buf[pos..]); } + end = pos; }, - &Message::Close => { + &mut Message::Close => { } } - if crypto.method() > 0 { - let (header, rest) = buf.split_at_mut(TopHeader::size()); - let (nonce, rest) = rest.split_at_mut(crypto.nonce_bytes()); - let crypto_start = header.len() + nonce.len(); - assert!(rest.len() >= pos - crypto_start + crypto.additional_bytes()); - pos = crypto.encrypt(rest, pos-crypto_start, nonce, header) + crypto_start; + assert!(start >= 64); + assert!(buf.len() >= end + 64); + if let Some(id) = options.network_id { + assert!(start >= NETWORK_ID_BYTES); + Encoder::write_u64(id, &mut buf[start-NETWORK_ID_BYTES..]); + start -= NETWORK_ID_BYTES; } - pos + let crypto_start = start; + start -= crypto.nonce_bytes(); + let mut header = TopHeader::default(); + header.msgtype = match msg { + &mut Message::Data(_, _, _) => 0, + &mut Message::Peers(_) => 1, + &mut Message::Init(_, _, _) => 2, + &mut Message::Close => 3 + }; + header.crypto_method = crypto.method(); + if options.network_id.is_some() { + header.flags |= 0x01; + } + start -= TopHeader::size(); + header.write_to(&mut buf[start..]); + if crypto.method() > 0 { + let (junk_before, rest) = buf.split_at_mut(start); + let (header, rest) = rest.split_at_mut(TopHeader::size()); + let (nonce, rest) = rest.split_at_mut(crypto.nonce_bytes()); + debug_assert_eq!(junk_before.len() + header.len() + crypto.nonce_bytes(), crypto_start); + assert!(rest.len() >= end - crypto_start + crypto.additional_bytes()); + end = crypto.encrypt(rest, end-crypto_start, nonce, header) + crypto_start; + } + &mut buf[start..end] } diff --git a/src/util.rs b/src/util.rs index 010449b..b995ad6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,3 @@ -use std::ptr; use libc; pub type Duration = u32; @@ -11,12 +10,6 @@ pub fn now() -> Time { tv.tv_sec } -#[inline(always)] -pub fn memcopy(src: &[u8], dst: &mut[u8]) { - assert!(dst.len() >= src.len()); - unsafe { ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), src.len()) }; -} - const HEX_CHARS: &'static [u8] = b"0123456789abcdef"; pub fn bytes_to_hex(bytes: &[u8]) -> String {