Lots of changes

This commit is contained in:
Dennis Schwerdel 2015-11-22 22:00:34 +01:00
parent f559e9bf09
commit d9dedcdd0d
6 changed files with 189 additions and 130 deletions

View File

@ -10,7 +10,6 @@ docopt = "0.6"
rustc-serialize = "0.3" rustc-serialize = "0.3"
log = "0.3" log = "0.3"
epoll = "0.2" epoll = "0.2"
regex = "0.1"
[build-dependencies] [build-dependencies]
gcc = "0.3" gcc = "0.3"

View File

@ -18,15 +18,21 @@ use super::ip::{InternetProtocol, IpAddress, RoutingTable};
pub type NetworkId = u64; pub type NetworkId = u64;
pub trait Address: Sized + fmt::Debug + Clone {
fn from_bytes(&[u8]) -> Result<Self, Error>;
fn to_bytes(&self) -> Vec<u8>;
}
pub trait Table { pub trait Table {
type Address; type Address: Address;
fn learn(&mut self, Self::Address, SocketAddr); fn learn(&mut self, Self::Address, SocketAddr);
fn lookup(&self, &Self::Address) -> Option<SocketAddr>; fn lookup(&self, &Self::Address) -> Option<SocketAddr>;
fn housekeep(&mut self); fn housekeep(&mut self);
fn remove_all(&mut self, SocketAddr);
} }
pub trait Protocol: Sized { pub trait Protocol: Sized {
type Address; type Address: Address;
fn parse(&[u8]) -> Result<(Self::Address, Self::Address), Error>; fn parse(&[u8]) -> Result<(Self::Address, Self::Address), Error>;
} }
@ -113,8 +119,10 @@ impl PeerList {
} }
pub struct GenericCloud<A, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInterface> { pub struct GenericCloud<A: Address, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInterface> {
peers: PeerList, peers: PeerList,
addresses: Vec<A>,
learning: bool,
reconnect_peers: Vec<SocketAddr>, reconnect_peers: Vec<SocketAddr>,
table: T, table: T,
socket: UdpSocket, socket: UdpSocket,
@ -127,14 +135,17 @@ pub struct GenericCloud<A, T: Table<Address=A>, M: Protocol<Address=A>, I: Virtu
_dummy_m: PhantomData<M>, _dummy_m: PhantomData<M>,
} }
impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInterface> GenericCloud<A, T, M, I> { impl<A: Address, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInterface> GenericCloud<A, T, M, I> {
pub fn new(device: I, listen: String, network_id: Option<NetworkId>, table: T, peer_timeout: Duration) -> Self { pub fn new(device: I, listen: String, network_id: Option<NetworkId>, table: T,
peer_timeout: Duration, learning: bool, addresses: Vec<A>) -> Self {
let socket = match UdpSocket::bind(&listen as &str) { let socket = match UdpSocket::bind(&listen as &str) {
Ok(socket) => socket, Ok(socket) => socket,
_ => panic!("Failed to open socket") _ => panic!("Failed to open socket")
}; };
GenericCloud{ GenericCloud{
peers: PeerList::new(peer_timeout), peers: PeerList::new(peer_timeout),
addresses: addresses,
learning: learning,
reconnect_peers: Vec::new(), reconnect_peers: Vec::new(),
table: table, table: table,
socket: socket, socket: socket,
@ -148,7 +159,7 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
} }
} }
fn send_msg<Addr: ToSocketAddrs+fmt::Display>(&mut self, addr: Addr, msg: &Message) -> Result<(), Error> { fn send_msg<Addr: ToSocketAddrs+fmt::Display>(&mut self, addr: Addr, msg: &Message<A>) -> Result<(), Error> {
debug!("Sending {:?} to {}", msg, addr); debug!("Sending {:?} to {}", msg, addr);
let mut options = Options::default(); let mut options = Options::default();
options.network_id = self.network_id; options.network_id = self.network_id;
@ -176,7 +187,8 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
let addr = addr.to_socket_addrs().unwrap().next().unwrap(); let addr = addr.to_socket_addrs().unwrap().next().unwrap();
self.reconnect_peers.push(addr); self.reconnect_peers.push(addr);
} }
self.send_msg(addr, &Message::GetPeers) let addrs = self.addresses.clone();
self.send_msg(addr, &Message::Init(addrs))
} }
fn housekeep(&mut self) -> Result<(), Error> { fn housekeep(&mut self) -> Result<(), Error> {
@ -211,11 +223,15 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
match self.table.lookup(&dst) { match self.table.lookup(&dst) {
Some(addr) => { Some(addr) => {
debug!("Found destination for {:?} => {}", dst, addr); debug!("Found destination for {:?} => {}", dst, addr);
try!(self.send_msg(addr, &Message::Frame(payload))) if self.peers.contains(&addr) {
try!(self.send_msg(addr, &Message::Data(payload)))
} else {
warn!("Destination for {:?} not found in peers: {}", dst, addr);
}
}, },
None => { None => {
debug!("No destination for {:?} found, broadcasting", dst); debug!("No destination for {:?} found, broadcasting", dst);
let msg = Message::Frame(payload); let msg = Message::Data(payload);
for addr in &self.peers.as_vec() { for addr in &self.peers.as_vec() {
try!(self.send_msg(addr, &msg)); try!(self.send_msg(addr, &msg));
} }
@ -224,7 +240,7 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
Ok(()) Ok(())
} }
fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message) -> Result<(), Error> { fn handle_net_message(&mut self, peer: SocketAddr, options: Options, msg: Message<A>) -> Result<(), Error> {
if let Some(id) = self.network_id { if let Some(id) = self.network_id {
if options.network_id != Some(id) { if options.network_id != Some(id) {
info!("Ignoring message from {} with wrong token {:?}", peer, options.network_id); info!("Ignoring message from {} with wrong token {:?}", peer, options.network_id);
@ -233,18 +249,20 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
} }
debug!("Recieved {:?} from {}", msg, peer); debug!("Recieved {:?} from {}", msg, peer);
match msg { match msg {
Message::Frame(payload) => { Message::Data(payload) => {
let (src, _dst) = try!(M::parse(payload)); let (src, _dst) = try!(M::parse(payload));
debug!("Writing data to device: {} bytes", payload.len()); debug!("Writing data to device: {} bytes", payload.len());
match self.device.write(&payload) { match self.device.write(&payload) {
Ok(()) => (), Ok(()) => (),
Err(e) => { Err(e) => {
error!("Failed to send via tap device {:?}", e); error!("Failed to send via device {:?}", e);
return Err(Error::TunTapDevError("Failed to write to tap device")); return Err(Error::TunTapDevError("Failed to write to device"));
} }
} }
self.peers.add(&peer); self.peers.add(&peer);
self.table.learn(src, peer); if self.learning {
self.table.learn(src, peer);
}
}, },
Message::Peers(peers) => { Message::Peers(peers) => {
self.peers.add(&peer); self.peers.add(&peer);
@ -254,10 +272,13 @@ impl<A: fmt::Debug, T: Table<Address=A>, M: Protocol<Address=A>, I: VirtualInter
} }
} }
}, },
Message::GetPeers => { Message::Init(addrs) => {
self.peers.add(&peer); self.peers.add(&peer);
let peers = self.peers.as_vec(); let peers = self.peers.as_vec();
try!(self.send_msg(peer, &Message::Peers(peers))); try!(self.send_msg(peer, &Message::Peers(peers)));
for addr in addrs {
self.table.learn(addr, peer.clone());
}
}, },
Message::Close => { Message::Close => {
self.peers.remove(&peer); self.peers.remove(&peer);
@ -323,7 +344,7 @@ impl TapCloud {
}; };
info!("Opened tap device {}", device.ifname()); info!("Opened tap device {}", device.ifname());
let table = MacTable::new(mac_timeout); let table = MacTable::new(mac_timeout);
Self::new(device, listen, network_id, table, peer_timeout) Self::new(device, listen, network_id, table, peer_timeout, true, vec![])
} }
} }
@ -331,12 +352,14 @@ impl TapCloud {
pub type TunCloud = GenericCloud<IpAddress, RoutingTable, InternetProtocol, TunDevice>; pub type TunCloud = GenericCloud<IpAddress, RoutingTable, InternetProtocol, TunDevice>;
impl TunCloud { impl TunCloud {
pub fn new_tun_cloud(device: &str, listen: String, network_id: Option<NetworkId>, table: RoutingTable, peer_timeout: Duration) -> Self { pub fn new_tun_cloud(device: &str, listen: String, network_id: Option<NetworkId>, subnet: String, peer_timeout: Duration) -> Self {
let device = match TunDevice::new(device) { let device = match TunDevice::new(device) {
Ok(device) => device, Ok(device) => device,
_ => panic!("Failed to open tun device") _ => panic!("Failed to open tun device")
}; };
info!("Opened tun device {}", device.ifname()); info!("Opened tun device {}", device.ifname());
Self::new(device, listen, network_id, table, peer_timeout) let table = RoutingTable::new();
let subnet = IpAddress::from_str(&subnet).expect("Invalid subnet");
Self::new(device, listen, network_id, table, peer_timeout, false, vec![subnet])
} }
} }

View File

@ -3,7 +3,7 @@ use std::net::SocketAddr;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Write; use std::io::Write;
use super::cloud::{Error, Table, Protocol}; use super::cloud::{Error, Table, Protocol, Address};
use super::util::as_obj; use super::util::as_obj;
use time::{Duration, SteadyTime}; use time::{Duration, SteadyTime};
@ -27,6 +27,16 @@ pub struct EthAddr {
pub vlan: Option<VlanId> pub vlan: Option<VlanId>
} }
impl Address for EthAddr {
fn from_bytes(_bytes: &[u8]) -> Result<Self, Error> {
unimplemented!()
}
fn to_bytes(&self) -> Vec<u8> {
unimplemented!()
}
}
#[derive(PartialEq)] #[derive(PartialEq)]
pub struct Frame; pub struct Frame;
@ -103,6 +113,10 @@ impl Table for MacTable {
None => None None => None
} }
} }
fn remove_all(&mut self, _addr: SocketAddr) {
unimplemented!()
}
} }

151
src/ip.rs
View File

@ -1,25 +1,23 @@
use std::net::{SocketAddr, Ipv4Addr, Ipv6Addr, AddrParseError, ToSocketAddrs}; use std::net::{SocketAddr, Ipv4Addr, Ipv6Addr};
use std::collections::{hash_map, HashMap}; use std::collections::{hash_map, HashMap};
use std::ptr; use std::ptr;
use std::path::Path; use std::io::Read;
use std::fs::File;
use std::io::{Result as IoResult, Read, BufRead, BufReader};
use std::str::FromStr; use std::str::FromStr;
use regex::Regex; use super::cloud::{Protocol, Error, Table, Address};
use super::cloud::{Protocol, Error, Table};
use super::util::{as_obj, as_bytes}; use super::util::{as_obj, as_bytes};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpAddress { pub enum IpAddress {
V4(Ipv4Addr), V4(Ipv4Addr),
V6(Ipv6Addr) V6(Ipv6Addr),
V4Net(Ipv4Addr, u8),
V6Net(Ipv6Addr, u8),
} }
impl IpAddress { impl Address for IpAddress {
pub fn to_bytes(&self) -> Vec<u8> { fn to_bytes(&self) -> Vec<u8> {
match self { match self {
&IpAddress::V4(addr) => { &IpAddress::V4(addr) => {
let ip = addr.octets(); let ip = addr.octets();
@ -30,6 +28,11 @@ impl IpAddress {
} }
res res
}, },
&IpAddress::V4Net(addr, prefix_len) => {
let mut bytes = IpAddress::V4(addr).to_bytes();
bytes.push(prefix_len);
bytes
},
&IpAddress::V6(addr) => { &IpAddress::V6(addr) => {
let mut segments = addr.segments(); let mut segments = addr.segments();
for i in 0..8 { for i in 0..8 {
@ -42,14 +45,56 @@ impl IpAddress {
ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len()); ptr::copy_nonoverlapping(bytes.as_ptr(), res.as_mut_ptr(), bytes.len());
} }
res res
},
&IpAddress::V6Net(addr, prefix_len) => {
let mut bytes = IpAddress::V6(addr).to_bytes();
bytes.push(prefix_len);
bytes
} }
} }
} }
pub fn from_str(addr: &str) -> Result<Self, AddrParseError> { fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
let ipv4 = Ipv4Addr::from_str(addr).map(|addr| IpAddress::V4(addr)); match bytes.len() {
let ipv6 = Ipv6Addr::from_str(addr).map(|addr| IpAddress::V6(addr)); 4 => Ok(IpAddress::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]))),
ipv4.or(ipv6) 5 => Ok(IpAddress::V4Net(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]), bytes[4])),
16 => {
let data = unsafe { as_obj::<[u16; 8]>(&bytes) };
Ok(IpAddress::V6(Ipv6Addr::new(
u16::from_be(data[0]), u16::from_be(data[1]),
u16::from_be(data[2]), u16::from_be(data[3]),
u16::from_be(data[4]), u16::from_be(data[5]),
u16::from_be(data[6]), u16::from_be(data[7]),
)))
},
17 => {
let data = unsafe { as_obj::<[u16; 8]>(&bytes) };
Ok(IpAddress::V6Net(Ipv6Addr::new(
u16::from_be(data[0]), u16::from_be(data[1]),
u16::from_be(data[2]), u16::from_be(data[3]),
u16::from_be(data[4]), u16::from_be(data[5]),
u16::from_be(data[6]), u16::from_be(data[7]),
), bytes[16]))
}
_ => Err(Error::ParseError("Invalid address size"))
}
}
}
impl IpAddress {
pub fn from_str(addr: &str) -> Result<Self, Error> {
if let Some(pos) = addr.find("/") {
let prefix_len = try!(u8::from_str(&addr[pos+1..])
.map_err(|_| Error::ParseError("Failed to parse prefix length")));
let addr = &addr[..pos];
let ipv4 = Ipv4Addr::from_str(addr).map(|addr| IpAddress::V4Net(addr, prefix_len));
let ipv6 = Ipv6Addr::from_str(addr).map(|addr| IpAddress::V6Net(addr, prefix_len));
ipv4.or(ipv6).map_err(|_| Error::ParseError("Failed to parse address"))
} else {
let ipv4 = Ipv4Addr::from_str(addr).map(|addr| IpAddress::V4(addr));
let ipv6 = Ipv6Addr::from_str(addr).map(|addr| IpAddress::V6(addr));
ipv4.or(ipv6).map_err(|_| Error::ParseError("Failed to parse address"))
}
} }
} }
@ -69,31 +114,13 @@ impl Protocol for InternetProtocol {
if data.len() < 20 { if data.len() < 20 {
return Err(Error::ParseError("Truncated header")); return Err(Error::ParseError("Truncated header"));
} }
let src_data = unsafe { as_obj::<[u8; 4]>(&data[12..]) }; Ok((try!(IpAddress::from_bytes(&data[12..16])), try!(IpAddress::from_bytes(&data[16..20]))))
let src = Ipv4Addr::new(src_data[0], src_data[1], src_data[2], src_data[3]);
let dst_data = unsafe { as_obj::<[u8; 4]>(&data[16..]) };
let dst = Ipv4Addr::new(dst_data[0], dst_data[1], dst_data[2], dst_data[3]);
Ok((IpAddress::V4(src), IpAddress::V4(dst)))
}, },
6 => { 6 => {
if data.len() < 40 { if data.len() < 40 {
return Err(Error::ParseError("Truncated header")); return Err(Error::ParseError("Truncated header"));
} }
let src_data = unsafe { as_obj::<[u16; 8]>(&data[8..]) }; Ok((try!(IpAddress::from_bytes(&data[8..24])), try!(IpAddress::from_bytes(&data[24..40]))))
let src = Ipv6Addr::new(
u16::from_be(src_data[0]), u16::from_be(src_data[1]),
u16::from_be(src_data[2]), u16::from_be(src_data[3]),
u16::from_be(src_data[4]), u16::from_be(src_data[5]),
u16::from_be(src_data[6]), u16::from_be(src_data[7]),
);
let dst_data = unsafe { as_obj::<[u16; 8]>(&data[24..]) };
let dst = Ipv6Addr::new(
u16::from_be(dst_data[0]), u16::from_be(dst_data[1]),
u16::from_be(dst_data[2]), u16::from_be(dst_data[3]),
u16::from_be(dst_data[4]), u16::from_be(dst_data[5]),
u16::from_be(dst_data[6]), u16::from_be(dst_data[7]),
);
Ok((IpAddress::V6(src), IpAddress::V6(dst)))
}, },
_ => Err(Error::ParseError("Invalid version")) _ => Err(Error::ParseError("Invalid version"))
} }
@ -125,49 +152,6 @@ impl RoutingTable {
} }
} }
pub fn load_from(&mut self, path: &Path) -> IoResult<()> {
let pattern = Regex::new(r"(?P<base>[^/]+)/(?P<prefix>\d+)\s=>\s(?P<peer>.+)").unwrap();
let file = try!(File::open(path));
let mut reader = BufReader::new(file);
loop {
let mut s = String::new();
let res = try!(reader.read_line(&mut s));
if res == 0 {
break;
}
let captures = match pattern.captures(&s) {
Some(captures) => captures,
None => {
error!("Failed to parse routing table entry: {}", s);
continue
}
};
let base = match IpAddress::from_str(captures.name("base").unwrap()) {
Ok(addr) => addr.to_bytes(),
Err(e) => {
error!("Failed to parse base address: {}", e);
continue
}
};
let prefix_len = match u8::from_str(captures.name("prefix").unwrap()) {
Ok(num) => num,
Err(e) => {
error!("Failed to parse prefix length: {}", e);
continue
}
};
let peer = match captures.name("peer").unwrap().to_socket_addrs().map(|mut r| r.next()) {
Ok(Some(addr)) => addr,
_ => {
error!("Failed to parse peer address");
continue
}
};
self.add(base, prefix_len, peer);
}
Ok(())
}
pub fn lookup_bytes(&self, bytes: &[u8]) -> Option<SocketAddr> { pub fn lookup_bytes(&self, bytes: &[u8]) -> Option<SocketAddr> {
let len = bytes.len()/2 * 2; let len = bytes.len()/2 * 2;
for i in 0..len/2 { for i in 0..len/2 {
@ -199,8 +183,13 @@ impl RoutingTable {
impl Table for RoutingTable { impl Table for RoutingTable {
type Address = IpAddress; type Address = IpAddress;
fn learn(&mut self, _src: Self::Address, _addr: SocketAddr) { fn learn(&mut self, src: Self::Address, addr: SocketAddr) {
//nothing to do match src {
IpAddress::V4(_) => (),
IpAddress::V4Net(base, prefix_len) => self.add(IpAddress::V4(base).to_bytes(), prefix_len, addr),
IpAddress::V6(_) => (),
IpAddress::V6Net(base, prefix_len) => self.add(IpAddress::V6(base).to_bytes(), prefix_len, addr)
}
} }
fn lookup(&self, dst: &Self::Address) -> Option<SocketAddr> { fn lookup(&self, dst: &Self::Address) -> Option<SocketAddr> {
@ -210,4 +199,8 @@ impl Table for RoutingTable {
fn housekeep(&mut self) { fn housekeep(&mut self) {
//nothin to do //nothin to do
} }
fn remove_all(&mut self, _addr: SocketAddr) {
unimplemented!()
}
} }

View File

@ -3,7 +3,6 @@ extern crate time;
extern crate docopt; extern crate docopt;
extern crate rustc_serialize; extern crate rustc_serialize;
extern crate epoll; extern crate epoll;
extern crate regex;
mod util; mod util;
mod udpmessage; mod udpmessage;
@ -16,9 +15,7 @@ use time::Duration;
use docopt::Docopt; use docopt::Docopt;
use std::hash::{Hash, SipHasher, Hasher}; use std::hash::{Hash, SipHasher, Hasher};
use std::path::PathBuf;
use ip::RoutingTable;
use cloud::{Error, TapCloud, TunCloud}; use cloud::{Error, TapCloud, TunCloud};
@ -52,7 +49,7 @@ Options:
-c <connect>, --connect <connect> List of peers (addr:port) to connect to -c <connect>, --connect <connect> List of peers (addr:port) to connect to
--network-id <network_id> Optional token that identifies the network --network-id <network_id> Optional token that identifies the network
--peer-timeout <peer_timeout> Peer timeout in seconds [default: 1800] --peer-timeout <peer_timeout> Peer timeout in seconds [default: 1800]
--table <file> The file containing the routing table (only for tun) --subnet <subnet> The local subnet to use (only for tun)
--mac-timeout <mac_timeout> Mac table entry timeout in seconds (only for tap) [default: 300] --mac-timeout <mac_timeout> Mac table entry timeout in seconds (only for tap) [default: 300]
-v, --verbose Log verbosely -v, --verbose Log verbosely
-q, --quiet Only print error messages -q, --quiet Only print error messages
@ -68,7 +65,7 @@ enum Type {
#[derive(RustcDecodable, Debug)] #[derive(RustcDecodable, Debug)]
struct Args { struct Args {
flag_type: Type, flag_type: Type,
flag_table: PathBuf, flag_subnet: String,
flag_device: String, flag_device: String,
flag_listen: String, flag_listen: String,
flag_network_id: Option<String>, flag_network_id: Option<String>,
@ -98,8 +95,6 @@ fn tap_cloud(args: Args) {
} }
fn tun_cloud(args: Args) { fn tun_cloud(args: Args) {
let mut table = RoutingTable::new();
table.load_from(&args.flag_table).unwrap();
let mut tuncloud = TunCloud::new_tun_cloud( let mut tuncloud = TunCloud::new_tun_cloud(
&args.flag_device, &args.flag_device,
args.flag_listen, args.flag_listen,
@ -108,7 +103,7 @@ fn tun_cloud(args: Args) {
name.hash(&mut s); name.hash(&mut s);
s.finish() s.finish()
}), }),
table, args.flag_subnet,
Duration::seconds(args.flag_peer_timeout as i64) Duration::seconds(args.flag_peer_timeout as i64)
); );
for addr in args.flag_connect { for addr in args.flag_connect {

View File

@ -2,8 +2,9 @@ use std::{mem, ptr, fmt};
use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr}; use std::net::{SocketAddr, SocketAddrV4, Ipv4Addr};
use std::u16; use std::u16;
use super::cloud::{Error, NetworkId}; use super::cloud::{Error, NetworkId, Address};
use super::util::{as_obj, as_bytes}; use super::util::{as_obj, as_bytes};
use super::ethernet;
const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e]; const MAGIC: [u8; 3] = [0x76, 0x70, 0x6e];
const VERSION: u8 = 0; const VERSION: u8 = 0;
@ -30,17 +31,17 @@ pub struct Options {
#[derive(PartialEq)] #[derive(PartialEq)]
pub enum Message<'a> { pub enum Message<'a, A: Address> {
Frame(&'a[u8]), Data(&'a[u8]),
Peers(Vec<SocketAddr>), Peers(Vec<SocketAddr>),
GetPeers, Init(Vec<A>),
Close, Close,
} }
impl<'a> fmt::Debug for Message<'a> { impl<'a, A: Address> fmt::Debug for Message<'a, A> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> { fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match self { match self {
&Message::Frame(ref data) => write!(formatter, "Frame(data: {} bytes)", data.len()), &Message::Data(ref data) => write!(formatter, "Data(data: {} bytes)", data.len()),
&Message::Peers(ref peers) => { &Message::Peers(ref peers) => {
try!(write!(formatter, "Peers [")); try!(write!(formatter, "Peers ["));
let mut first = true; let mut first = true;
@ -53,13 +54,13 @@ impl<'a> fmt::Debug for Message<'a> {
} }
write!(formatter, "]") write!(formatter, "]")
}, },
&Message::GetPeers => write!(formatter, "GetPeers"), &Message::Init(ref data) => write!(formatter, "Init(data: {} bytes)", data.len()),
&Message::Close => write!(formatter, "Close"), &Message::Close => write!(formatter, "Close"),
} }
} }
} }
pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> { pub fn decode<A: Address>(data: &[u8]) -> Result<(Options, Message<A>), Error> {
if data.len() < mem::size_of::<TopHeader>() { if data.len() < mem::size_of::<TopHeader>() {
return Err(Error::ParseError("Empty message")); return Err(Error::ParseError("Empty message"));
} }
@ -82,7 +83,7 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> {
pos += 8; pos += 8;
} }
let msg = match header.msgtype { let msg = match header.msgtype {
0 => Message::Frame(&data[pos..]), 0 => Message::Data(&data[pos..]),
1 => { 1 => {
if data.len() < pos + 1 { if data.len() < pos + 1 {
return Err(Error::ParseError("Empty peers")); return Err(Error::ParseError("Empty peers"));
@ -108,21 +109,41 @@ pub fn decode(data: &[u8]) -> Result<(Options, Message), Error> {
} }
Message::Peers(peers) Message::Peers(peers)
}, },
2 => Message::GetPeers, 2 => {
if data.len() < pos + 1 {
return Err(Error::ParseError("Init data too short"));
}
let count = data[pos] as usize;
pos += 1;
let mut addrs = Vec::with_capacity(count);
for _ in 0..count {
if data.len() < pos + 1 {
return Err(Error::ParseError("Init data too short"));
}
let len = data[pos] as usize;
pos += 1;
if data.len() < pos + len {
return Err(Error::ParseError("Init data too short"));
}
addrs.push(try!(A::from_bytes(&data[pos..pos+len])));
pos += len;
}
Message::Init(addrs)
},
3 => Message::Close, 3 => Message::Close,
_ => return Err(Error::ParseError("Unknown message type")) _ => return Err(Error::ParseError("Unknown message type"))
}; };
Ok((options, msg)) Ok((options, msg))
} }
pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize { pub fn encode<A: Address>(options: &Options, msg: &Message<A>, buf: &mut [u8]) -> usize {
assert!(buf.len() >= mem::size_of::<TopHeader>()); assert!(buf.len() >= mem::size_of::<TopHeader>());
let mut pos = 0; let mut pos = 0;
let mut header = TopHeader::default(); let mut header = TopHeader::default();
header.msgtype = match msg { header.msgtype = match msg {
&Message::Frame(_) => 0, &Message::Data(_) => 0,
&Message::Peers(_) => 1, &Message::Peers(_) => 1,
&Message::GetPeers => 2, &Message::Init(_) => 2,
&Message::Close => 3 &Message::Close => 3
}; };
if options.network_id.is_some() { if options.network_id.is_some() {
@ -140,7 +161,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize {
pos += 8; pos += 8;
} }
match msg { match msg {
&Message::Frame(ref data) => { &Message::Data(ref data) => {
assert!(buf.len() >= pos + data.len()); assert!(buf.len() >= pos + data.len());
unsafe { ptr::copy_nonoverlapping(data.as_ptr(), buf[pos..].as_mut_ptr(), data.len()) }; unsafe { ptr::copy_nonoverlapping(data.as_ptr(), buf[pos..].as_mut_ptr(), data.len()) };
pos += data.len(); pos += data.len();
@ -171,7 +192,20 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize {
buf[pos] = 0; buf[pos] = 0;
pos += 1; pos += 1;
}, },
&Message::GetPeers => { &Message::Init(ref addrs) => {
assert!(buf.len() >= pos + 1);
assert!(addrs.len() <= 255);
buf[pos] = addrs.len() as u8;
pos += 1;
for addr in addrs {
let bytes = addr.to_bytes();
assert!(bytes.len() <= 255);
assert!(buf.len() >= pos + 1 + bytes.len());
buf[pos] = bytes.len() as u8;
pos += 1;
unsafe { ptr::copy_nonoverlapping(bytes.as_ptr(), buf[pos..].as_mut_ptr(), bytes.len()) };
pos += bytes.len();
}
}, },
&Message::Close => { &Message::Close => {
} }
@ -184,7 +218,7 @@ pub fn encode(options: &Options, msg: &Message, buf: &mut [u8]) -> usize {
fn encode_message_packet() { fn encode_message_packet() {
let options = Options::default(); let options = Options::default();
let payload = [1,2,3,4,5]; let payload = [1,2,3,4,5];
let msg = Message::Frame(&payload); let msg: Message<ethernet::EthAddr> = Message::Data(&payload);
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&options, &msg, &mut buf[..]); let size = encode(&options, &msg, &mut buf[..]);
assert_eq!(size, 13); assert_eq!(size, 13);
@ -198,7 +232,7 @@ fn encode_message_packet() {
fn encode_message_peers() { fn encode_message_peers() {
use std::str::FromStr; use std::str::FromStr;
let options = Options::default(); let options = Options::default();
let msg: Message = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]); let msg: Message<ethernet::EthAddr> = Message::Peers(vec![SocketAddr::from_str("1.2.3.4:123").unwrap(), SocketAddr::from_str("5.6.7.8:12345").unwrap()]);
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&options, &msg, &mut buf[..]); let size = encode(&options, &msg, &mut buf[..]);
assert_eq!(size, 22); assert_eq!(size, 22);
@ -212,7 +246,7 @@ fn encode_message_peers() {
fn encode_option_network_id() { fn encode_option_network_id() {
let mut options = Options::default(); let mut options = Options::default();
options.network_id = Some(134); options.network_id = Some(134);
let msg: Message = Message::GetPeers; let msg: Message<ethernet::EthAddr> = Message::Close;
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&options, &msg, &mut buf[..]); let size = encode(&options, &msg, &mut buf[..]);
assert_eq!(size, 16); assert_eq!(size, 16);
@ -223,13 +257,14 @@ fn encode_option_network_id() {
} }
#[test] #[test]
fn encode_message_getpeers() { fn encode_message_init() {
let options = Options::default(); let options = Options::default();
let msg: Message = Message::GetPeers; let addrs = vec![];
let msg: Message<ethernet::EthAddr> = Message::Init(addrs);
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&options, &msg, &mut buf[..]); let size = encode(&options, &msg, &mut buf[..]);
assert_eq!(size, 8); assert_eq!(size, 13);
assert_eq!(&buf[..size], &[118,112,110,0,0,0,0,2]); assert_eq!(&buf[..size], &[118,112,110,0,0,0,0,2,1,2,3,4,5]);
let (options2, msg2) = decode(&buf[..size]).unwrap(); let (options2, msg2) = decode(&buf[..size]).unwrap();
assert_eq!(options, options2); assert_eq!(options, options2);
assert_eq!(msg, msg2); assert_eq!(msg, msg2);
@ -238,7 +273,7 @@ fn encode_message_getpeers() {
#[test] #[test]
fn encode_message_close() { fn encode_message_close() {
let options = Options::default(); let options = Options::default();
let msg: Message = Message::Close; let msg: Message<ethernet::EthAddr> = Message::Close;
let mut buf = [0; 1024]; let mut buf = [0; 1024];
let size = encode(&options, &msg, &mut buf[..]); let size = encode(&options, &msg, &mut buf[..]);
assert_eq!(size, 8); assert_eq!(size, 8);