From 28a79f2d47a8e142bf412b99145f73b45d1d94d9 Mon Sep 17 00:00:00 2001 From: Dennis Schwerdel Date: Mon, 23 Nov 2015 00:49:58 +0100 Subject: [PATCH] Rework --- src/cloud.rs | 136 ++++++++++++++++++++++++++++++----------- src/ethernet.rs | 97 ++++++++++++----------------- src/ip.rs | 153 ++++++---------------------------------------- src/udpmessage.rs | 56 +++++++++-------- src/util.rs | 11 +++- 5 files changed, 204 insertions(+), 249 deletions(-) diff --git a/src/cloud.rs b/src/cloud.rs index 2206301..def8a14 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -1,44 +1,111 @@ -use std::net::{SocketAddr, ToSocketAddrs}; +use std::net::{SocketAddr, ToSocketAddrs, Ipv4Addr, Ipv6Addr}; use std::collections::HashMap; use std::hash::Hasher; use std::net::UdpSocket; use std::io::Read; -use std::fmt; +use std::{fmt, ptr}; use std::os::unix::io::AsRawFd; use std::marker::PhantomData; +use std::str::FromStr; use time::{Duration, SteadyTime, precise_time_ns}; use epoll; use super::device::{TunDevice, TapDevice}; use super::udpmessage::{encode, decode, Options, Message}; -use super::ethernet::{Frame, EthAddr, MacTable}; -use super::ip::{InternetProtocol, IpAddress, RoutingTable}; +use super::ethernet::{Frame, MacTable}; +use super::ip::{InternetProtocol, RoutingTable}; +use super::util::as_bytes; pub type NetworkId = u64; +#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Clone)] +pub struct Address(pub Vec); + +impl fmt::Debug for Address { + fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match self.0.len() { + 4 => write!(formatter, "{}.{}.{}.{}", self.0[0], self.0[1], self.0[2], self.0[3]), + 6 => write!(formatter, "{:x}:{:x}:{:x}:{:x}:{:x}:{:x}", + self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5]), + 16 => write!(formatter, "{:x}{:x}:{:x}{:x}:{:x}{:x}:{:x}{:x}:{:x}{:x}:{:x}{:x}:{:x}{:x}:{:x}{:x}", + self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5], self.0[6], self.0[7], + self.0[8], self.0[9], self.0[10], self.0[11], self.0[12], self.0[13], self.0[14], self.0[15] + ), + _ => self.0.fmt(formatter) + } + } +} + +impl FromStr for Address { + type Err=Error; + + fn from_str(text: &str) -> Result { + if let Ok(addr) = Ipv4Addr::from_str(text) { + let ip = addr.octets(); + let mut res = Vec::with_capacity(4); + unsafe { + res.set_len(4); + ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len()); + } + return Ok(Address(res)); + } + if let Ok(addr) = Ipv6Addr::from_str(text) { + let mut segments = addr.segments(); + for i in 0..8 { + segments[i] = segments[i].to_be(); + } + let bytes = unsafe { as_bytes(&segments) }; + let mut res = Vec::with_capacity(16); + unsafe { + res.set_len(16); + ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); + } + return Ok(Address(res)); + } + //FIXME: implement for mac addresses + return Err(Error::ParseError("Failed to parse address")) + } +} + + +#[derive(Debug, PartialEq, PartialOrd, Eq, Ord, Hash, Clone)] +pub struct Range { + pub base: Address, + pub prefix_len: u8 +} + +impl FromStr for Range { + type Err=Error; + + fn from_str(text: &str) -> Result { + let pos = match text.find("/") { + Some(pos) => pos, + None => return Err(Error::ParseError("Invalid range format")) + }; + let prefix_len = try!(u8::from_str(&text[pos+1..]) + .map_err(|_| Error::ParseError("Failed to parse prefix length"))); + let base = try!(Address::from_str(&text[..pos])); + Ok(Range{base: base, prefix_len: prefix_len}) + } +} + + #[derive(RustcDecodable, Debug)] pub enum Behavior { Normal, Hub, Switch, Router } -pub trait Address: Sized + fmt::Debug + Clone { - fn from_bytes(&[u8]) -> Result; - fn to_bytes(&self) -> Vec; -} - pub trait Table { - type Address: Address; - fn learn(&mut self, Self::Address, SocketAddr); - fn lookup(&self, &Self::Address) -> Option; + fn learn(&mut self, Address, Option, SocketAddr); + fn lookup(&self, &Address) -> Option; fn housekeep(&mut self); fn remove_all(&mut self, SocketAddr); } pub trait Protocol: Sized { - type Address: Address; - fn parse(&[u8]) -> Result<(Self::Address, Self::Address), Error>; + fn parse(&[u8]) -> Result<(Address, Address), Error>; } pub trait VirtualInterface: AsRawFd { @@ -124,9 +191,9 @@ impl PeerList { } -pub struct GenericCloud, M: Protocol, I: VirtualInterface> { +pub struct GenericCloud { peers: PeerList, - addresses: Vec, + addresses: Vec, learning: bool, broadcast: bool, reconnect_peers: Vec, @@ -138,12 +205,12 @@ pub struct GenericCloud, M: Protocol, update_freq: Duration, buffer_out: [u8; 64*1024], next_housekeep: SteadyTime, - _dummy_m: PhantomData, + _dummy_m: PhantomData

, } -impl, M: Protocol, I: VirtualInterface> GenericCloud { +impl GenericCloud { pub fn new(device: I, listen: String, network_id: Option, table: T, - peer_timeout: Duration, learning: bool, broadcast: bool, addresses: Vec) -> Self { + peer_timeout: Duration, learning: bool, broadcast: bool, addresses: Vec) -> Self { let socket = match UdpSocket::bind(&listen as &str) { Ok(socket) => socket, _ => panic!("Failed to open socket") @@ -166,7 +233,7 @@ impl, M: Protocol, I: VirtualInterfac } } - fn send_msg(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { + fn send_msg(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { debug!("Sending {:?} to {}", msg, addr); let mut options = Options::default(); options.network_id = self.network_id; @@ -224,7 +291,7 @@ impl, M: Protocol, I: VirtualInterfac } fn handle_interface_data(&mut self, payload: &[u8]) -> Result<(), Error> { - let (src, dst) = try!(M::parse(payload)); + let (src, dst) = try!(P::parse(payload)); debug!("Read data from interface: src: {:?}, dst: {:?}, {} bytes", src, dst, payload.len()); match self.table.lookup(&dst) { Some(addr) => { @@ -250,7 +317,7 @@ impl, M: Protocol, I: VirtualInterfac Ok(()) } - fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message) -> Result<(), Error> { + fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message) -> Result<(), Error> { if let Some(id) = self.network_id { if options.network_id != Some(id) { info!("Ignoring message from {} with wrong token {:?}", peer, options.network_id); @@ -260,7 +327,7 @@ impl, M: Protocol, I: VirtualInterfac debug!("Recieved {:?} from {}", msg, peer); match msg { Message::Data(payload) => { - let (src, _dst) = try!(M::parse(payload)); + let (src, _dst) = try!(P::parse(payload)); debug!("Writing data to device: {} bytes", payload.len()); match self.device.write(&payload) { Ok(()) => (), @@ -271,7 +338,8 @@ impl, M: Protocol, I: VirtualInterfac } self.peers.add(&peer); if self.learning { - self.table.learn(src, peer); + //learn single address + self.table.learn(src, None, peer); } }, Message::Peers(peers) => { @@ -282,7 +350,7 @@ impl, M: Protocol, I: VirtualInterfac } } }, - Message::Init(addrs) => { + Message::Init(ranges) => { if self.peers.contains(&peer) { return Ok(()); } @@ -291,8 +359,8 @@ impl, M: Protocol, I: VirtualInterfac let own_addrs = self.addresses.clone(); try!(self.send_msg(peer, &Message::Init(own_addrs))); try!(self.send_msg(peer, &Message::Peers(peers))); - for addr in addrs { - self.table.learn(addr, peer.clone()); + for range in ranges { + self.table.learn(range.base, Some(range.prefix_len), peer.clone()); } }, Message::Close => { @@ -349,7 +417,7 @@ impl, M: Protocol, I: VirtualInterfac } -pub type TapCloud = GenericCloud; +pub type TapCloud = GenericCloud; impl TapCloud { pub fn new_tap_cloud(device: &str, listen: String, behavior: Behavior, network_id: Option, mac_timeout: Duration, peer_timeout: Duration) -> Self { @@ -370,19 +438,19 @@ impl TapCloud { } -pub type TunCloud = GenericCloud; +pub type TunCloud = GenericCloud; impl TunCloud { - pub fn new_tun_cloud(device: &str, listen: String, behavior: Behavior, network_id: Option, subnets: Vec, peer_timeout: Duration) -> Self { + pub fn new_tun_cloud(device: &str, listen: String, behavior: Behavior, network_id: Option, range_strs: Vec, peer_timeout: Duration) -> Self { let device = match TunDevice::new(device) { Ok(device) => device, _ => panic!("Failed to open tun device") }; info!("Opened tun device {}", device.ifname()); let table = RoutingTable::new(); - let mut addrs = Vec::with_capacity(subnets.len()); - for s in subnets { - addrs.push(IpAddress::from_str(&s).expect("Invalid subnet")); + let mut ranges = Vec::with_capacity(range_strs.len()); + for s in range_strs { + ranges.push(Range::from_str(&s).expect("Invalid subnet")); } let (learning, broadcasting) = match behavior { Behavior::Normal => (false, false), @@ -390,6 +458,6 @@ impl TunCloud { Behavior::Hub => (false, true), Behavior::Router => (false, false) }; - Self::new(device, listen, network_id, table, peer_timeout, learning, broadcasting, addrs) + Self::new(device, listen, network_id, table, peer_timeout, learning, broadcasting, ranges) } } diff --git a/src/ethernet.rs b/src/ethernet.rs index 3e3c4ec..2fbce06 100644 --- a/src/ethernet.rs +++ b/src/ethernet.rs @@ -1,67 +1,52 @@ -use std::{mem, fmt}; +use std::ptr; use std::net::SocketAddr; use std::collections::HashMap; -use std::io::Write; use super::cloud::{Error, Table, Protocol, Address}; -use super::util::as_obj; use time::{Duration, SteadyTime}; -#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub struct Mac(pub [u8; 6]); - -impl fmt::Debug for Mac { - fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(formatter, "{:x}:{:x}:{:x}:{:x}:{:x}:{:x}", - self.0[0], self.0[1], self.0[2], self.0[3], self.0[4], self.0[5]) - } -} - -pub type VlanId = u16; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct EthAddr { - pub mac: Mac, - pub vlan: Option -} - -impl Address for EthAddr { - fn from_bytes(_bytes: &[u8]) -> Result { - unimplemented!() - } - - fn to_bytes(&self) -> Vec { - unimplemented!() - } -} - - #[derive(PartialEq)] pub struct Frame; impl Protocol for Frame { - type Address = EthAddr; - - fn parse(data: &[u8]) -> Result<(EthAddr, EthAddr), Error> { + fn parse(data: &[u8]) -> Result<(Address, Address), Error> { if data.len() < 14 { return Err(Error::ParseError("Frame is too short")); } let mut pos = 0; - let dst = *unsafe { as_obj::(&data[pos..]) }; - pos += mem::size_of::(); - let src = *unsafe { as_obj::(&data[pos..]) }; - pos += mem::size_of::(); - let mut vlan = None; + let dst_data = &data[pos..pos+6]; + pos += 6; + let src_data = &data[pos..pos+6]; + pos += 6; if data[pos] == 0x81 && data[pos+1] == 0x00 { pos += 2; if data.len() < pos + 2 { return Err(Error::ParseError("Vlan frame is too short")); } - vlan = Some(u16::from_be(* unsafe { as_obj::(&data[pos..]) })); + let mut src = Vec::with_capacity(8); + let mut dst = Vec::with_capacity(8); + unsafe { + ptr::copy_nonoverlapping(data.as_ptr(), src.as_mut_ptr(), 2); + ptr::copy_nonoverlapping(src_data.as_ptr(), src[2..].as_mut_ptr(), 6); + src.set_len(8); + ptr::copy_nonoverlapping(data.as_ptr(), dst.as_mut_ptr(), 2); + ptr::copy_nonoverlapping(dst_data.as_ptr(), dst[2..].as_mut_ptr(), 6); + dst.set_len(8); + } + Ok((Address(src), Address(dst))) + } else { + let mut src = Vec::with_capacity(6); + let mut dst = Vec::with_capacity(6); + unsafe { + ptr::copy_nonoverlapping(src_data.as_ptr(), src.as_mut_ptr(), 6); + src.set_len(6); + ptr::copy_nonoverlapping(dst_data.as_ptr(), dst.as_mut_ptr(), 6); + dst.set_len(6); + } + Ok((Address(src), Address(dst))) } - Ok((EthAddr{mac: src, vlan: vlan}, EthAddr{mac: dst, vlan: vlan})) } } @@ -73,7 +58,7 @@ struct MacTableValue { pub struct MacTable { - table: HashMap, + table: HashMap, timeout: Duration } @@ -84,30 +69,28 @@ impl MacTable { } impl Table for MacTable { - type Address = EthAddr; - fn housekeep(&mut self) { let now = SteadyTime::now(); - let mut del: Vec = Vec::new(); - for (&key, val) in &self.table { + let mut del: Vec

= Vec::new(); + for (key, val) in &self.table { if val.timeout < now { - del.push(key); + del.push(key.clone()); } } for key in del { - info!("Forgot mac: {:?} (vlan {:?})", key.mac, key.vlan); + info!("Forgot address {:?}", key); self.table.remove(&key); } } - fn learn(&mut self, key: Self::Address, addr: SocketAddr) { + fn learn(&mut self, key: Address, _prefix_len: Option, addr: SocketAddr) { let value = MacTableValue{address: addr, timeout: SteadyTime::now()+self.timeout}; - if self.table.insert(key, value).is_none() { - info!("Learned mac: {:?} (vlan {:?}) => {}", key.mac, key.vlan, addr); + if self.table.insert(key.clone(), value).is_none() { + info!("Learned address {:?} => {}", key, addr); } } - fn lookup(&self, key: &Self::Address) -> Option { + fn lookup(&self, key: &Address) -> Option { match self.table.get(key) { Some(value) => Some(value.address), None => None @@ -124,14 +107,14 @@ impl Table for MacTable { fn without_vlan() { let data = [6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8]; let (src, dst) = Frame::parse(&data).unwrap(); - assert_eq!(src, EthAddr{mac: Mac([1,2,3,4,5,6]), vlan: None}); - assert_eq!(dst, EthAddr{mac: Mac([6,5,4,3,2,1]), vlan: None}); + assert_eq!(src, Address(vec![1,2,3,4,5,6])); + assert_eq!(dst, Address(vec![6,5,4,3,2,1])); } #[test] fn with_vlan() { let data = [6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8]; let (src, dst) = Frame::parse(&data).unwrap(); - assert_eq!(src, EthAddr{mac: Mac([1,2,3,4,5,6]), vlan: Some(1234)}); - assert_eq!(dst, EthAddr{mac: Mac([6,5,4,3,2,1]), vlan: Some(1234)}); + assert_eq!(src, Address(vec![4,210,1,2,3,4,5,6])); + assert_eq!(dst, Address(vec![4,210,6,5,4,3,2,1])); } diff --git a/src/ip.rs b/src/ip.rs index 3f466d2..767b683 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -1,110 +1,15 @@ -use std::net::{SocketAddr, Ipv4Addr, Ipv6Addr}; +use std::net::SocketAddr; use std::collections::{hash_map, HashMap}; -use std::ptr; use std::io::Read; -use std::str::FromStr; use super::cloud::{Protocol, Error, Table, Address}; -use super::util::{as_obj, as_bytes}; - - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum IpAddress { - V4(Ipv4Addr), - V6(Ipv6Addr), - V4Net(Ipv4Addr, u8), - V6Net(Ipv6Addr, u8), -} - -impl Address for IpAddress { - fn to_bytes(&self) -> Vec { - match self { - &IpAddress::V4(addr) => { - let ip = addr.octets(); - let mut res = Vec::with_capacity(4); - unsafe { - res.set_len(4); - ptr::copy_nonoverlapping(ip.as_ptr(), res.as_mut_ptr(), ip.len()); - } - res - }, - &IpAddress::V4Net(addr, prefix_len) => { - let mut bytes = IpAddress::V4(addr).to_bytes(); - bytes.push(prefix_len); - bytes - }, - &IpAddress::V6(addr) => { - let mut segments = addr.segments(); - for i in 0..8 { - segments[i] = segments[i].to_be(); - } - let bytes = unsafe { as_bytes(&segments) }; - let mut res = Vec::with_capacity(16); - unsafe { - res.set_len(16); - ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); - } - res - }, - &IpAddress::V6Net(addr, prefix_len) => { - let mut bytes = IpAddress::V6(addr).to_bytes(); - bytes.push(prefix_len); - bytes - } - } - } - - fn from_bytes(bytes: &[u8]) -> Result { - match bytes.len() { - 4 => Ok(IpAddress::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]))), - 5 => Ok(IpAddress::V4Net(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]), bytes[4])), - 16 => { - let data = unsafe { as_obj::<[u16; 8]>(&bytes) }; - Ok(IpAddress::V6(Ipv6Addr::new( - u16::from_be(data[0]), u16::from_be(data[1]), - u16::from_be(data[2]), u16::from_be(data[3]), - u16::from_be(data[4]), u16::from_be(data[5]), - u16::from_be(data[6]), u16::from_be(data[7]), - ))) - }, - 17 => { - let data = unsafe { as_obj::<[u16; 8]>(&bytes) }; - Ok(IpAddress::V6Net(Ipv6Addr::new( - u16::from_be(data[0]), u16::from_be(data[1]), - u16::from_be(data[2]), u16::from_be(data[3]), - u16::from_be(data[4]), u16::from_be(data[5]), - u16::from_be(data[6]), u16::from_be(data[7]), - ), bytes[16])) - } - _ => Err(Error::ParseError("Invalid address size")) - } - } -} - -impl IpAddress { - pub fn from_str(addr: &str) -> Result { - if let Some(pos) = addr.find("/") { - let prefix_len = try!(u8::from_str(&addr[pos+1..]) - .map_err(|_| Error::ParseError("Failed to parse prefix length"))); - let addr = &addr[..pos]; - let ipv4 = Ipv4Addr::from_str(addr).map(|addr| IpAddress::V4Net(addr, prefix_len)); - let ipv6 = Ipv6Addr::from_str(addr).map(|addr| IpAddress::V6Net(addr, prefix_len)); - ipv4.or(ipv6).map_err(|_| Error::ParseError("Failed to parse address")) - } else { - let ipv4 = Ipv4Addr::from_str(addr).map(|addr| IpAddress::V4(addr)); - let ipv6 = Ipv6Addr::from_str(addr).map(|addr| IpAddress::V6(addr)); - ipv4.or(ipv6).map_err(|_| Error::ParseError("Failed to parse address")) - } - } -} +use super::util::to_vec; #[allow(dead_code)] pub struct InternetProtocol; impl Protocol for InternetProtocol { - type Address = IpAddress; - - fn parse(data: &[u8]) -> Result<(IpAddress, IpAddress), Error> { + fn parse(data: &[u8]) -> Result<(Address, Address), Error> { if data.len() < 1 { return Err(Error::ParseError("Empty header")); } @@ -114,13 +19,13 @@ impl Protocol for InternetProtocol { if data.len() < 20 { return Err(Error::ParseError("Truncated header")); } - Ok((try!(IpAddress::from_bytes(&data[12..16])), try!(IpAddress::from_bytes(&data[16..20])))) + Ok((Address(to_vec(&data[12..16])), Address(to_vec(&data[16..20])))) }, 6 => { if data.len() < 40 { return Err(Error::ParseError("Truncated header")); } - Ok((try!(IpAddress::from_bytes(&data[8..24])), try!(IpAddress::from_bytes(&data[24..40])))) + Ok((Address(to_vec(&data[8..24])), Address(to_vec(&data[24..40])))) }, _ => Err(Error::ParseError("Invalid version")) } @@ -128,7 +33,6 @@ impl Protocol for InternetProtocol { } - struct RoutingEntry { address: SocketAddr, bytes: Vec, @@ -141,28 +45,34 @@ impl RoutingTable { pub fn new() -> Self { RoutingTable(HashMap::new()) } +} - pub fn add(&mut self, bytes: Vec, prefix_len: u8, address: SocketAddr) { +impl Table for RoutingTable { + fn learn(&mut self, addr: Address, prefix_len: Option, address: SocketAddr) { + let prefix_len = match prefix_len { + Some(val) => val, + None => addr.0.len() as u8 * 8 + }; let group_len = (prefix_len as usize / 16) * 2; - let group_bytes: Vec = bytes[..group_len].iter().map(|b| *b).collect(); - let routing_entry = RoutingEntry{address: address, bytes: bytes, prefix_len: prefix_len}; + let group_bytes: Vec = addr.0[..group_len].iter().map(|b| *b).collect(); + let routing_entry = RoutingEntry{address: address, bytes: addr.0, prefix_len: prefix_len}; match self.0.entry(group_bytes) { hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(routing_entry), hash_map::Entry::Vacant(entry) => { entry.insert(vec![routing_entry]); () } } } - pub fn lookup_bytes(&self, bytes: &[u8]) -> Option { - let len = bytes.len()/2 * 2; + fn lookup(&self, addr: &Address) -> Option { + let len = addr.0.len()/2 * 2; for i in 0..len/2 { - if let Some(group) = self.0.get(&bytes[0..len-2*i]) { + if let Some(group) = self.0.get(&addr.0[0..len-2*i]) { for entry in group { - if entry.bytes.len() != bytes.len() { + if entry.bytes.len() != addr.0.len() { continue; } let mut match_len = 0; - for i in 0..bytes.len() { - let b = bytes[i] ^ entry.bytes[i]; + for i in 0..addr.0.len() { + let b = addr.0[i] ^ entry.bytes[i]; if b == 0 { match_len += 8; } else { @@ -178,29 +88,6 @@ impl RoutingTable { } None } -} - -impl Table for RoutingTable { - type Address = IpAddress; - - fn learn(&mut self, src: Self::Address, addr: SocketAddr) { - match src { - IpAddress::V4(_) => (), - IpAddress::V4Net(base, prefix_len) => { - info!("Adding to routing table: {}/{} => {}", base, prefix_len, addr); - self.add(IpAddress::V4(base).to_bytes(), prefix_len, addr); - }, - IpAddress::V6(_) => (), - IpAddress::V6Net(base, prefix_len) => { - info!("Adding to routing table: {}/{} => {}", base, prefix_len, addr); - self.add(IpAddress::V6(base).to_bytes(), prefix_len, addr); - } - } - } - - fn lookup(&self, dst: &Self::Address) -> Option { - self.lookup_bytes(&dst.to_bytes()) - } fn housekeep(&mut self) { //nothin to do diff --git a/src/udpmessage.rs b/src/udpmessage.rs index 1723dfa..500abff 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -2,9 +2,8 @@ use std::{mem, ptr, fmt}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::u16; -use super::cloud::{Error, NetworkId, Address}; -use super::util::{as_obj, as_bytes}; -#[cfg(test)] use super::ethernet; +use super::cloud::{Error, NetworkId, Range, Address}; +use super::util::{as_obj, as_bytes, to_vec}; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; const VERSION: u8 = 0; @@ -31,14 +30,14 @@ pub struct Options { #[derive(PartialEq)] -pub enum Message<'a, A: Address> { +pub enum Message<'a> { Data(&'a[u8]), Peers(Vec), - Init(Vec), + Init(Vec), Close, } -impl<'a, A: Address> fmt::Debug for Message<'a, A> { +impl<'a> fmt::Debug for Message<'a> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self { &Message::Data(ref data) => write!(formatter, "Data(data: {} bytes)", data.len()), @@ -60,7 +59,7 @@ impl<'a, A: Address> fmt::Debug for Message<'a, A> { } } -pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { +pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { if data.len() < mem::size_of::() { return Err(Error::ParseError("Empty message")); } @@ -125,8 +124,14 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { if data.len() < pos + len { return Err(Error::ParseError("Init data too short")); } - addrs.push(try!(A::from_bytes(&data[pos..pos+len]))); + let base = Address(to_vec(&data[pos..pos+len])); pos += len; + if data.len() < pos + 1 { + return Err(Error::ParseError("Init data too short")); + } + let prefix_len = data[pos]; + pos += 1; + addrs.push(Range{base: base, prefix_len: prefix_len}); } Message::Init(addrs) }, @@ -136,7 +141,7 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { Ok((options, msg)) } -pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize { +pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize { assert!(buf.len() >= mem::size_of::()); let mut pos = 0; let mut header = TopHeader::default(); @@ -192,19 +197,22 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) - buf[pos] = 0; pos += 1; }, - &Message::Init(ref addrs) => { + &Message::Init(ref ranges) => { assert!(buf.len() >= pos + 1); - assert!(addrs.len() <= 255); - buf[pos] = addrs.len() as u8; + assert!(ranges.len() <= 255); + buf[pos] = ranges.len() as u8; pos += 1; - for addr in addrs { - let bytes = addr.to_bytes(); - assert!(bytes.len() <= 255); - assert!(buf.len() >= pos + 1 + bytes.len()); - buf[pos] = bytes.len() as u8; + for range in ranges { + let base = &range.base; + let len = base.0.len(); + assert!(len <= 255); + assert!(buf.len() >= pos + 1 + len + 1); + buf[pos] = len as u8; + pos += 1; + unsafe { ptr::copy_nonoverlapping(base.0.as_ptr(), buf[pos..].as_mut_ptr(), base.0.len()) }; + pos += base.0.len(); + buf[pos] = range.prefix_len; pos += 1; - unsafe { ptr::copy_nonoverlapping(bytes.as_ptr(), buf[pos..].as_mut_ptr(), bytes.len()) }; - pos += bytes.len(); } }, &Message::Close => { @@ -218,7 +226,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) - fn encode_message_packet() { let options = Options::default(); let payload = [1,2,3,4,5]; - let msg: Message = Message::Data(&payload); + let msg = Message::Data(&payload); let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 13); @@ -232,7 +240,7 @@ fn encode_message_packet() { fn encode_message_peers() { use std::str::FromStr; let options = Options::default(); - let msg: Message = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]); + let msg = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]); let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 22); @@ -246,7 +254,7 @@ fn encode_message_peers() { fn encode_option_network_id() { let mut options = Options::default(); options.network_id = Some(134); - let msg: Message = Message::Close; + let msg = Message::Close; let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 16); @@ -260,7 +268,7 @@ fn encode_option_network_id() { fn encode_message_init() { let options = Options::default(); let addrs = vec![]; - let msg: Message = Message::Init(addrs); + let msg = Message::Init(addrs); let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 13); @@ -273,7 +281,7 @@ fn encode_message_init() { #[test] fn encode_message_close() { let options = Options::default(); - let msg: Message = Message::Close; + let msg = Message::Close; let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 8); diff --git a/src/util.rs b/src/util.rs index 3f886c4..8a481a4 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,4 @@ -use std::{mem, slice}; +use std::{mem, slice, ptr}; #[inline(always)] pub unsafe fn as_bytes(obj: &T) -> &[u8] { @@ -10,3 +10,12 @@ pub unsafe fn as_obj(data: &[u8]) -> &T { assert!(data.len() >= mem::size_of::()); mem::transmute(data.as_ptr()) } + +pub fn to_vec(data: &[u8]) -> Vec { + let mut v = Vec::with_capacity(data.len()); + unsafe { + ptr::copy_nonoverlapping(data.as_ptr(), v.as_mut_ptr(), data.len()); + v.set_len(data.len()); + } + v +}