From 87571feef62e799aa844a5e3ed44b3b8e1703d6a Mon Sep 17 00:00:00 2001 From: Dennis Schwerdel Date: Sun, 22 Nov 2015 18:05:15 +0100 Subject: [PATCH] Simplification --- src/cloud.rs | 47 +++++++++++---------- src/ethernet.rs | 101 +++++++++------------------------------------- src/ip.rs | 40 ++---------------- src/udpmessage.rs | 44 +++++++++----------- 4 files changed, 65 insertions(+), 167 deletions(-) diff --git a/src/cloud.rs b/src/cloud.rs index 331a09d..2d98925 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -19,21 +19,18 @@ pub type NetworkId = u64; pub trait Table { type Address; fn learn(&mut self, Self::Address, SocketAddr); - fn lookup(&self, Self::Address) -> Option; + fn lookup(&self, &Self::Address) -> Option; fn housekeep(&mut self); } -pub trait InterfaceMessage: fmt::Debug + Sized { +pub trait Protocol: Sized { type Address; - fn src(&self) -> Self::Address; - fn dst(&self) -> Self::Address; - fn encode_to(&self, &[u8], &mut [u8]) -> usize; - fn parse_from(&[u8]) -> Result<(Self, &[u8]), Error>; + fn parse(&[u8]) -> Result<(Self::Address, Self::Address), Error>; } pub trait VirtualInterface: AsRawFd { - fn read<'a, T: InterfaceMessage>(&mut self, &'a mut [u8]) -> Result<(T, &'a[u8]), Error>; - fn write(&mut self, &T, &[u8]) -> Result<(), Error>; + fn read(&mut self, &mut [u8]) -> Result; + fn write(&mut self, &[u8]) -> Result<(), Error>; } @@ -114,7 +111,7 @@ impl PeerList { } -pub struct GenericCloud, M: InterfaceMessage, I: VirtualInterface> { +pub struct GenericCloud, M: Protocol, I: VirtualInterface> { peers: PeerList, reconnect_peers: Vec, table: T, @@ -128,7 +125,7 @@ pub struct GenericCloud, M: InterfaceMessage, _dummy_m: PhantomData, } -impl, M: InterfaceMessage, I: VirtualInterface> GenericCloud { +impl, M: Protocol, I: VirtualInterface> GenericCloud { pub fn new(device: I, listen: String, network_id: Option, table: T, peer_timeout: Duration) -> Self { let socket = match UdpSocket::bind(&listen as &str) { Ok(socket) => socket, @@ -149,7 +146,7 @@ impl, M: InterfaceMessage, I: Virt } } - fn send_msg(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { + fn send_msg(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { debug!("Sending {:?} to {}", msg, addr); let mut options = Options::default(); options.network_id = self.network_id; @@ -206,16 +203,17 @@ impl, M: InterfaceMessage, I: Virt Ok(()) } - fn handle_interface_data(&mut self, header: M, payload: &[u8]) -> Result<(), Error> { - debug!("Read data from interface {:?}, {} bytes", header, payload.len()); - match self.table.lookup(header.dst()) { + fn handle_interface_data(&mut self, payload: &[u8]) -> Result<(), Error> { + let (src, dst) = try!(M::parse(payload)); + debug!("Read data from interface: src: {:?}, dst: {:?}, {} bytes", src, dst, payload.len()); + match self.table.lookup(&dst) { Some(addr) => { - debug!("Found destination for {:?} => {}", header.dst(), addr); - try!(self.send_msg(addr, &Message::Frame(header, payload))) + debug!("Found destination for {:?} => {}", dst, addr); + try!(self.send_msg(addr, &Message::Frame(payload))) }, None => { - debug!("No destination for {:?} found, broadcasting", header.dst()); - let msg = Message::Frame(header, payload); + debug!("No destination for {:?} found, broadcasting", dst); + let msg = Message::Frame(payload); for addr in &self.peers.as_vec() { try!(self.send_msg(addr, &msg)); } @@ -224,7 +222,7 @@ impl, M: InterfaceMessage, I: Virt Ok(()) } - fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message) -> Result<(), Error> { + fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message) -> Result<(), Error> { if let Some(id) = self.network_id { if options.network_id != Some(id) { info!("Ignoring message from {} with wrong token {:?}", peer, options.network_id); @@ -233,9 +231,10 @@ impl, M: InterfaceMessage, I: Virt } debug!("Recieved {:?} from {}", msg, peer); match msg { - Message::Frame(header, payload) => { - debug!("Writing data to device: {:?}, {} bytes", header, payload.len()); - match self.device.write(&header, &payload) { + Message::Frame(payload) => { + let (src, _dst) = try!(M::parse(payload)); + debug!("Writing data to device: {} bytes", payload.len()); + match self.device.write(&payload) { Ok(()) => (), Err(e) => { error!("Failed to send via tap device {:?}", e); @@ -243,7 +242,7 @@ impl, M: InterfaceMessage, I: Virt } } self.peers.add(&peer); - self.table.learn(header.src(), peer); + self.table.learn(src, peer); }, Message::Peers(peers) => { self.peers.add(&peer); @@ -290,7 +289,7 @@ impl, M: InterfaceMessage, I: Virt Err(_error) => panic!("Failed to read from network socket") }, &1 => match self.device.read(&mut buffer) { - Ok((header, payload)) => match self.handle_interface_data(header, payload) { + Ok(size) => match self.handle_interface_data(&buffer[..size]) { Ok(_) => (), Err(e) => error!("Error: {:?}", e) }, diff --git a/src/ethernet.rs b/src/ethernet.rs index d0df77d..caaa150 100644 --- a/src/ethernet.rs +++ b/src/ethernet.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::os::unix::io::{AsRawFd, RawFd}; use std::io::{Result as IoResult, Error as IoError, Read, Write}; -use super::cloud::{Error, Table, InterfaceMessage, VirtualInterface}; +use super::cloud::{Error, Table, Protocol, VirtualInterface}; use super::util::{as_bytes, as_obj}; use time::{Duration, SteadyTime}; @@ -34,52 +34,12 @@ pub struct EthAddr { #[derive(PartialEq)] -pub struct Frame { - pub src: EthAddr, - pub dst: EthAddr -} +pub struct Frame; -impl fmt::Debug for Frame { - fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(formatter, "src: {:?}, dst: {:?}, vlan: {:?}", self.src.mac, self.dst.mac, self.src.vlan) - } -} - -impl InterfaceMessage for Frame { +impl Protocol for Frame { type Address = EthAddr; - fn src(&self) -> Self::Address { - self.src - } - - fn dst(&self) -> Self::Address { - self.dst - } - - fn encode_to(&self, payload: &[u8], data: &mut [u8]) -> usize { - assert!(data.len() >= 16 + payload.len()); - let mut pos = 0; - unsafe { - let dst_dat = as_bytes::(&self.dst.mac); - ptr::copy_nonoverlapping(dst_dat.as_ptr(), data[pos..].as_mut_ptr(), dst_dat.len()); - pos += dst_dat.len(); - let src_dat = as_bytes::(&self.src.mac); - ptr::copy_nonoverlapping(src_dat.as_ptr(), data[pos..].as_mut_ptr(), src_dat.len()); - pos += src_dat.len(); - if let Some(vlan) = self.src.vlan { - data[pos] = 0x81; data[pos+1] = 0x00; - pos += 2; - let vlan_dat = mem::transmute::(vlan.to_be()); - ptr::copy_nonoverlapping(vlan_dat.as_ptr(), data[pos..].as_mut_ptr(), vlan_dat.len()); - pos += vlan_dat.len(); - } - ptr::copy_nonoverlapping(payload.as_ptr(), data[pos..].as_mut_ptr(), payload.len()); - } - pos += payload.len(); - pos - } - - fn parse_from(data: &[u8]) -> Result<(Frame, &[u8]), Error> { + fn parse(data: &[u8]) -> Result<(EthAddr, EthAddr), Error> { if data.len() < 14 { return Err(Error::ParseError("Frame is too short")); } @@ -89,17 +49,14 @@ impl InterfaceMessage for Frame { let src = *unsafe { as_obj::(&data[pos..]) }; pos += mem::size_of::(); let mut vlan = None; - let mut payload = &data[pos..]; if data[pos] == 0x81 && data[pos+1] == 0x00 { pos += 2; if data.len() < pos + 2 { return Err(Error::ParseError("Vlan frame is too short")); } vlan = Some(u16::from_be(* unsafe { as_obj::(&data[pos..]) })); - pos += 2; - payload = &data[pos..]; } - Ok((Frame{src: EthAddr{mac: src, vlan: vlan}, dst: EthAddr{mac: dst, vlan: vlan}}, payload)) + Ok((EthAddr{mac: src, vlan: vlan}, EthAddr{mac: dst, vlan: vlan})) } } @@ -136,18 +93,12 @@ impl AsRawFd for TapDevice { } impl VirtualInterface for TapDevice { - fn read<'a, T: InterfaceMessage>(&mut self, mut buffer: &'a mut [u8]) -> Result<(T, &'a[u8]), Error> { - let size = match self.fd.read(&mut buffer) { - Ok(size) => size, - Err(_) => return Err(Error::TunTapDevError("Read error")) - }; - T::parse_from(&buffer[..size]) + fn read(&mut self, mut buffer: &mut [u8]) -> Result { + self.fd.read(&mut buffer).map_err(|_| Error::TunTapDevError("Read error")) } - fn write(&mut self, msg: &T, payload: &[u8]) -> Result<(), Error> { - let mut buffer = [0u8; 64*1024]; - let size = msg.encode_to(payload, &mut buffer); - match self.fd.write_all(&buffer[..size]) { + fn write(&mut self, data: &[u8]) -> Result<(), Error> { + match self.fd.write_all(&data) { Ok(_) => Ok(()), Err(_) => Err(Error::TunTapDevError("Write error")) } @@ -196,8 +147,8 @@ impl Table for MacTable { } } - fn lookup(&self, key: Self::Address) -> Option { - match self.table.get(&key) { + fn lookup(&self, key: &Self::Address) -> Option { + match self.table.get(key) { Some(value) => Some(value.address), None => None } @@ -207,30 +158,16 @@ impl Table for MacTable { #[test] fn without_vlan() { - let src = Mac([1,2,3,4,5,6]); - let dst = Mac([6,5,4,3,2,1]); - let payload = [1,2,3,4,5,6,7,8]; - let mut buf = [0u8; 1024]; - let frame = Frame{src: EthAddr{mac: src, vlan: None}, dst: EthAddr{mac: dst, vlan: None}}; - let size = frame.encode_to(&payload, &mut buf); - assert_eq!(size, 20); - assert_eq!(&buf[..size], &[6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8]); - let (frame2, payload2) = Frame::parse_from(&buf[..size]).unwrap(); - assert_eq!(frame, frame2); - assert_eq!(payload, payload2); + let data = [6,5,4,3,2,1,1,2,3,4,5,6,1,2,3,4,5,6,7,8]; + let (src, dst) = Frame::parse(&data).unwrap(); + assert_eq!(src, EthAddr{mac: Mac([1,2,3,4,5,6]), vlan: None}); + assert_eq!(dst, EthAddr{mac: Mac([6,5,4,3,2,1]), vlan: None}); } #[test] fn with_vlan() { - let src = Mac([1,2,3,4,5,6]); - let dst = Mac([6,5,4,3,2,1]); - let payload = [1,2,3,4,5,6,7,8]; - let mut buf = [0u8; 1024]; - let frame = Frame{src: EthAddr{mac: src, vlan: Some(1234)}, dst: EthAddr{mac: dst, vlan: Some(1234)}}; - let size = frame.encode_to(&payload, &mut buf); - assert_eq!(size, 24); - assert_eq!(&buf[..size], &[6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8]); - let (frame2, payload2) = Frame::parse_from(&buf[..size]).unwrap(); - assert_eq!(frame, frame2); - assert_eq!(payload, payload2); + let data = [6,5,4,3,2,1,1,2,3,4,5,6,0x81,0,4,210,1,2,3,4,5,6,7,8]; + let (src, dst) = Frame::parse(&data).unwrap(); + assert_eq!(src, EthAddr{mac: Mac([1,2,3,4,5,6]), vlan: Some(1234)}); + assert_eq!(dst, EthAddr{mac: Mac([6,5,4,3,2,1]), vlan: Some(1234)}); } diff --git a/src/ip.rs b/src/ip.rs index 24b58e9..2ff7de2 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -1,8 +1,7 @@ use std::net::{SocketAddr, Ipv4Addr, Ipv6Addr}; use std::collections::{hash_map, HashMap}; -use std::fmt; -use super::cloud::{InterfaceMessage, Error}; +use super::cloud::{Protocol, Error}; use super::util::{as_obj, as_bytes}; @@ -12,43 +11,12 @@ pub enum IpAddress { V6(Ipv6Addr) } -#[derive(PartialEq)] -pub enum IpHeader { - V4{src: Ipv4Addr, dst: Ipv4Addr}, - V6{src: Ipv6Addr, dst: Ipv6Addr} -} +pub struct InternetProtocol; -impl fmt::Debug for IpHeader { - fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match self { - &IpHeader::V4{src, dst} => write!(formatter, "src: {}, dst: {}", src, dst), - &IpHeader::V6{src, dst} => write!(formatter, "src: {}, dst: {}", src, dst) - } - } -} - -impl InterfaceMessage for IpHeader { +impl Protocol for InternetProtocol { type Address = IpAddress; - fn src(&self) -> Self::Address { - match self { - &IpHeader::V4{src, dst: _} => IpAddress::V4(src), - &IpHeader::V6{src, dst: _} => IpAddress::V6(src) - } - } - - fn dst(&self) -> Self::Address { - match self { - &IpHeader::V4{src: _, dst} => IpAddress::V4(dst), - &IpHeader::V6{src: _, dst} => IpAddress::V6(dst) - } - } - - fn encode_to(&self, payload: &[u8], data: &mut [u8]) -> usize { - unimplemented!() - } - - fn parse_from(data: &[u8]) -> Result<(Self, &[u8]), Error> { + fn parse(data: &[u8]) -> Result<(IpAddress, IpAddress), Error> { unimplemented!() } } diff --git a/src/udpmessage.rs b/src/udpmessage.rs index 6c66174..44b8462 100644 --- a/src/udpmessage.rs +++ b/src/udpmessage.rs @@ -2,8 +2,7 @@ use std::{mem, ptr, fmt}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::u16; -use super::cloud::{Error, NetworkId, InterfaceMessage}; -use super::ethernet; +use super::cloud::{Error, NetworkId}; use super::util::{as_obj, as_bytes}; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; @@ -31,17 +30,17 @@ pub struct Options { #[derive(PartialEq)] -pub enum Message<'a, M: InterfaceMessage> { - Frame(M, &'a[u8]), +pub enum Message<'a> { + Frame(&'a[u8]), Peers(Vec), GetPeers, Close, } -impl<'a, M: InterfaceMessage> fmt::Debug for Message<'a, M> { +impl<'a> fmt::Debug for Message<'a> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self { - &Message::Frame(ref frame, ref payload) => write!(formatter, "Frame({:?}, payload: {})", frame, payload.len()), + &Message::Frame(ref data) => write!(formatter, "Frame(data: {} bytes)", data.len()), &Message::Peers(ref peers) => { try!(write!(formatter, "Peers [")); let mut first = true; @@ -60,7 +59,7 @@ impl<'a, M: InterfaceMessage> fmt::Debug for Message<'a, M> { } } -pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { +pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { if data.len() < mem::size_of::() { return Err(Error::ParseError("Empty message")); } @@ -83,10 +82,7 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), pos += 8; } let msg = match header.msgtype { - 0 => { - let (header, payload) = try!(M::parse_from(&data[pos..])); - Message::Frame(header, payload) - }, + 0 => Message::Frame(&data[pos..]), 1 => { if data.len() < pos + 1 { return Err(Error::ParseError("Empty peers")); @@ -119,12 +115,12 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), Ok((options, msg)) } -pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize { +pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize { assert!(buf.len() >= mem::size_of::()); let mut pos = 0; let mut header = TopHeader::default(); header.msgtype = match msg { - &Message::Frame(_, _) => 0, + &Message::Frame(_) => 0, &Message::Peers(_) => 1, &Message::GetPeers => 2, &Message::Close => 3 @@ -144,8 +140,10 @@ pub fn encode(options: &Options, msg: &Message, buf: &mu pos += 8; } match msg { - &Message::Frame(ref header, ref payload) => { - pos += header.encode_to(&payload, &mut buf[pos..]) + &Message::Frame(ref data) => { + assert!(buf.len() >= pos + data.len()); + unsafe { ptr::copy_nonoverlapping(data.as_ptr(), buf[pos..].as_mut_ptr(), data.len()) }; + pos += data.len(); }, &Message::Peers(ref peers) => { let count_pos = pos; @@ -184,16 +182,12 @@ pub fn encode(options: &Options, msg: &Message, buf: &mu #[test] fn encode_message_packet() { - use super::ethernet::Mac; let options = Options::default(); - let src = Mac([1,2,3,4,5,6]); - let dst = Mac([7,8,9,10,11,12]); let payload = [1,2,3,4,5]; - let frame = ethernet::Frame{src: ethernet::EthAddr{mac: src, vlan: None}, dst: ethernet::EthAddr{mac: dst, vlan: None}}; - let msg = Message::Frame(frame, &payload); + let msg = Message::Frame(&payload); let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); - assert_eq!(size, 25); + assert_eq!(size, 13); assert_eq!(&buf[..8], &[118,112,110,0,0,0,0,0]); let (options2, msg2) = decode(&buf[..size]).unwrap(); assert_eq!(options, options2); @@ -204,7 +198,7 @@ fn encode_message_packet() { fn encode_message_peers() { use std::str::FromStr; let options = Options::default(); - let msg: Message = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]); + let msg: Message = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]); let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 22); @@ -218,7 +212,7 @@ fn encode_message_peers() { fn encode_option_network_id() { let mut options = Options::default(); options.network_id = Some(134); - let msg: Message = Message::GetPeers; + let msg: Message = Message::GetPeers; let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 16); @@ -231,7 +225,7 @@ fn encode_option_network_id() { #[test] fn encode_message_getpeers() { let options = Options::default(); - let msg: Message = Message::GetPeers; + let msg: Message = Message::GetPeers; let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 8); @@ -244,7 +238,7 @@ fn encode_message_getpeers() { #[test] fn encode_message_close() { let options = Options::default(); - let msg: Message = Message::Close; + let msg: Message = Message::Close; let mut buf = [0; 1024]; let size = encode(&options, &msg, &mut buf[..]); assert_eq!(size, 8);