diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index 77ded0c..b8ad1bb 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -1,50 +1,119 @@ -use std::{marker::PhantomData, sync::Arc}; +use crate::{ + engine::{addr_nice, Hash, PeerData}, + messages::MESSAGE_TYPE_DATA, + net::Socket, + table::ClaimTable, + traffic::TrafficStats, + Protocol +}; +use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc}; -use super::SPACE_BEFORE; -use super::shared::SharedData; +use super::{shared::SharedData, SPACE_BEFORE}; use crate::{ device::Device, error::Error, util::{MsgBuffer, Time, TimeSource} }; - -const SYNC_INTERVAL: Time = 1; - -pub struct DeviceThread { - shared: Arc, +pub struct DeviceThread { + // Read-only fields + _dummy_ts: PhantomData, + _dummy_p: PhantomData

, + broadcast: bool, + // Device-only fields + socket: S, device: D, - next_sync: Time, - _dummy: PhantomData + next_housekeep: Time, + // Shared fields + shared: Arc, + peers: HashMap, + traffic: TrafficStats, + table: ClaimTable } -impl DeviceThread { - fn sync(&mut self) { - // TODO sync +impl DeviceThread { + #[inline] + fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { + debug!("Sending msg with {} bytes to {}", msg.len(), addr); + self.traffic.count_out_traffic(addr, msg.len()); + match self.socket.send(msg.message(), addr) { + Ok(written) if written == msg.len() => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)) + } } - fn read_device_packet(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { - // TODO: read data - // use 5sec timeout - unimplemented!(); + #[inline] + fn send_msg(&mut self, addr: SocketAddr, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { + debug!("Sending msg with {} bytes to {}", msg.len(), addr); + let peer = match self.peers.get_mut(&addr) { + Some(peer) => peer, + None => return Err(Error::Message("Sending to node that is not a peer")) + }; + peer.crypto.send_message(type_, msg)?; + self.send_to(addr, msg) } - fn forward_packet(&mut self, buffer: &mut MsgBuffer) -> Result<(), Error> { - // TODO: handle data + #[inline] + fn broadcast_msg(&mut self, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { + debug!("Broadcasting message type {}, {:?} bytes to {} peers", type_, msg.len(), self.peers.len()); + let mut msg_data = MsgBuffer::new(100); + for (addr, peer) in &mut self.peers { + msg_data.set_start(msg.get_start()); + msg_data.set_length(msg.len()); + msg_data.message_mut().clone_from_slice(msg.message()); + peer.crypto.send_message(type_, &mut msg_data)?; + self.traffic.count_out_traffic(*addr, msg_data.len()); + match self.socket.send(msg_data.message(), *addr) { + Ok(written) if written == msg_data.len() => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)) + }? + } + Ok(()) + } + + fn forward_packet(&mut self, data: &mut MsgBuffer) -> Result<(), Error> { + let (src, dst) = P::parse(data.message())?; + debug!("Read data from interface: src: {}, dst: {}, {} bytes", src, dst, data.len()); + self.traffic.count_out_payload(dst, src, data.len()); + match self.table.lookup(dst) { + Some(addr) => { + // Peer found for destination + debug!("Found destination for {} => {}", dst, addr); + self.send_msg(addr, MESSAGE_TYPE_DATA, data)?; + } + None => { + if self.broadcast { + debug!("No destination for {} found, broadcasting", dst); + self.broadcast_msg(MESSAGE_TYPE_DATA, data)?; + } else { + debug!("No destination for {} found, dropping", dst); + self.traffic.count_dropped_payload(data.len()); + } + } + } + Ok(()) + } + + fn housekeep(&mut self) -> Result<(), Error> { + // TODO: sync unimplemented!(); } pub fn run(mut self) { let mut buffer = MsgBuffer::new(SPACE_BEFORE); loop { - try_fail!(self.read_device_packet(&mut buffer), "Failed to read from device: {}"); + try_fail!(self.device.read(&mut buffer), "Failed to read from device: {}"); if let Err(e) = self.forward_packet(&mut buffer) { error!("{}", e); } - let now = T::now(); - if self.next_sync < now { - self.sync(); - self.next_sync = now + SYNC_INTERVAL + let now = TS::now(); + if self.next_housekeep < TS::now() { + if let Err(e) = self.housekeep() { + error!("{}", e) + } + self.next_housekeep = TS::now() + 1 } } } diff --git a/src/engine/socket_thread.rs b/src/engine/socket_thread.rs index 42faa62..9f99dc9 100644 --- a/src/engine/socket_thread.rs +++ b/src/engine/socket_thread.rs @@ -1,50 +1,307 @@ -use crate::error::Error; -use std::{marker::PhantomData, net::SocketAddr, sync::Arc}; - use super::{shared::SharedData, SPACE_BEFORE}; use crate::{ - net::Socket, - util::{MsgBuffer, Time, TimeSource} + config::DEFAULT_PEER_TIMEOUT, + crypto::{is_init_message, MessageResult, PeerCrypto}, + engine::{addr_nice, resolve, Hash, PeerData}, + error::Error, + messages::{AddrList, NodeInfo, PeerInfo}, + net::{mapped_addr, Socket}, + table::ClaimTable, + traffic::TrafficStats, + types::{NodeId, RangeList}, + util::{MsgBuffer, Time, TimeSource}, + Config, Crypto, Device, Protocol +}; +use rand::{random, seq::SliceRandom, thread_rng}; +use smallvec::{smallvec, SmallVec}; +use std::{ + collections::HashMap, + fmt, + io::Cursor, + marker::PhantomData, + net::{SocketAddr, ToSocketAddrs}, + sync::Arc }; - -const SYNC_INTERVAL: Time = 1; - -pub struct SocketThread { - shared: Arc, +pub struct SocketThread { + // Read-only fields + node_id: NodeId, + claims: RangeList, + config: Config, + peer_timeout_publish: u16, + learning: bool, + _dummy_ts: PhantomData, + _dummy_p: PhantomData

, + // Socket-only fields socket: S, - next_sync: Time, - _dummy: PhantomData + device: D, + next_housekeep: Time, + own_addresses: AddrList, + pending_inits: HashMap, Hash>, + // Shared fields + shared: Arc, + traffic: TrafficStats, + peers: HashMap, + crypto: Crypto, + table: ClaimTable } -impl SocketThread { - fn sync(&mut self) { +impl SocketThread { + #[inline] + fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { + debug!("Sending msg with {} bytes to {}", msg.len(), addr); + self.traffic.count_out_traffic(addr, msg.len()); + match self.socket.send(msg.message(), addr) { + Ok(written) if written == msg.len() => Ok(()), + Ok(_) => Err(Error::Socket("Sent out truncated packet")), + Err(e) => Err(Error::SocketIo("IOError when sending", e)) + } + } + + fn connect_sock(&mut self, addr: SocketAddr) -> Result<(), Error> { + let addr = mapped_addr(addr); + if self.peers.contains_key(&addr) + || self.own_addresses.contains(&addr) + || self.pending_inits.contains_key(&addr) + { + return Ok(()) + } + debug!("Connecting to {:?}", addr); + let payload = self.create_node_info(); + let mut peer_crypto = self.crypto.peer_instance(payload); + let mut msg = MsgBuffer::new(SPACE_BEFORE); + peer_crypto.initialize(&mut msg)?; + self.pending_inits.insert(addr, peer_crypto); + self.send_to(addr, &mut msg) + } + + pub fn connect(&mut self, addr: Addr) -> Result<(), Error> { + let addrs = resolve(&addr)?.into_iter().map(mapped_addr).collect::>(); + for addr in &addrs { + if self.own_addresses.contains(addr) + || self.peers.contains_key(addr) + || self.pending_inits.contains_key(addr) + { + return Ok(()) + } + } + // Send a message to each resolved address + for a in addrs { + // Ignore error this time + self.connect_sock(a).ok(); + } + Ok(()) + } + + fn create_node_info(&self) -> NodeInfo { + let mut peers = smallvec![]; + for peer in self.peers.values() { + peers.push(PeerInfo { node_id: Some(peer.node_id), addrs: peer.addrs.clone() }) + } + if peers.len() > 20 { + let mut rng = rand::thread_rng(); + peers.partial_shuffle(&mut rng, 20); + peers.truncate(20); + } + NodeInfo { + node_id: self.node_id, + peers, + claims: self.claims.clone(), + peer_timeout: Some(self.peer_timeout_publish), + addrs: self.own_addresses.clone() + } + } + + fn update_peer_info(&mut self, addr: SocketAddr, info: Option) -> Result<(), Error> { + if let Some(peer) = self.peers.get_mut(&addr) { + peer.last_seen = TS::now(); + peer.timeout = TS::now() + self.config.peer_timeout as Time + } else { + error!("Received peer update from non peer {}", addr_nice(addr)); + return Ok(()) + } + if let Some(info) = info { + debug!("Adding claims of peer {}: {:?}", addr_nice(addr), info.claims); + self.table.set_claims(addr, info.claims); + debug!("Received {} peers from {}: {:?}", info.peers.len(), addr_nice(addr), info.peers); + self.connect_to_peers(&info.peers)?; + } + Ok(()) + } + + fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo) -> Result<(), Error> { + info!("Added peer {}", addr_nice(addr)); + if let Some(init) = self.pending_inits.remove(&addr) { + self.peers.insert(addr, PeerData { + addrs: info.addrs.clone(), + crypto: init, + node_id: info.node_id, + peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT), + last_seen: TS::now(), + timeout: TS::now() + self.config.peer_timeout as Time + }); + self.update_peer_info(addr, Some(info))?; + } else { + error!("No init for new peer {}", addr_nice(addr)); + } + Ok(()) + } + + fn connect_to_peers(&mut self, peers: &[PeerInfo]) -> Result<(), Error> { + 'outer: for peer in peers { + for addr in &peer.addrs { + if self.peers.contains_key(addr) { + continue 'outer + } + } + if let Some(node_id) = peer.node_id { + if self.node_id == node_id { + continue 'outer + } + for p in self.peers.values() { + if p.node_id == node_id { + continue 'outer + } + } + } + self.connect(&peer.addrs as &[SocketAddr])?; + } + Ok(()) + } + + fn remove_peer(&mut self, addr: SocketAddr) { + if let Some(_peer) = self.peers.remove(&addr) { + info!("Closing connection to {}", addr_nice(addr)); + self.table.remove_claims(addr); + } + } + + fn handle_payload_from(&mut self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { + let (src, dst) = P::parse(data.message())?; + let len = data.len(); + debug!("Writing data to device: {} bytes", len); + self.traffic.count_in_payload(src, dst, len); + if let Err(e) = self.device.write(data) { + error!("Failed to send via device: {}", e); + return Err(e) + } + if self.learning { + // Learn single address + self.table.cache(src, peer); + } + Ok(()) + } + + fn process_message( + &mut self, src: SocketAddr, msg_result: MessageResult, data: &mut MsgBuffer + ) -> Result<(), Error> { + match msg_result { + MessageResult::Message(type_) => { + match type_ { + MESSAGE_TYPE_DATA => self.handle_payload_from(src, data)?, + MESSAGE_TYPE_NODE_INFO => { + let info = match NodeInfo::decode(Cursor::new(data.message())) { + Ok(val) => val, + Err(err) => { + self.traffic.count_invalid_protocol(data.len()); + return Err(err) + } + }; + self.update_peer_info(src, Some(info))? + } + MESSAGE_TYPE_KEEPALIVE => self.update_peer_info(src, None)?, + MESSAGE_TYPE_CLOSE => self.remove_peer(src), + _ => { + self.traffic.count_invalid_protocol(data.len()); + return Err(Error::Message("Unknown message type")) + } + } + } + MessageResult::Initialized(info) => self.add_new_peer(src, info)?, + MessageResult::InitializedWithReply(info) => { + self.add_new_peer(src, info)?; + self.send_to(src, data)? + } + MessageResult::Reply => self.send_to(src, data)?, + MessageResult::None => () + } + Ok(()) + } + + fn handle_message(&mut self, src: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { + let src = mapped_addr(src); + debug!("Received {} bytes from {}", data.len(), src); + let msg_result = if let Some(init) = self.pending_inits.get_mut(&src) { + init.handle_message(data) + } else if is_init_message(data.message()) { + let mut result = None; + if let Some(peer) = self.peers.get_mut(&src) { + if peer.crypto.has_init() { + result = Some(peer.crypto.handle_message(data)) + } + } + if let Some(result) = result { + result + } else { + let mut init = self.crypto.peer_instance(self.create_node_info()); + let msg_result = init.handle_message(data); + match msg_result { + Ok(res) => { + self.pending_inits.insert(src, init); + Ok(res) + } + Err(err) => { + self.traffic.count_invalid_protocol(data.len()); + return Err(err) + } + } + } + } else if let Some(peer) = self.peers.get_mut(&src) { + peer.crypto.handle_message(data) + } else { + info!("Ignoring non-init message from unknown peer {}", addr_nice(src)); + self.traffic.count_invalid_protocol(data.len()); + return Ok(()) + }; + match msg_result { + Ok(val) => self.process_message(src, val, data), + Err(err) => { + self.traffic.count_invalid_protocol(data.len()); + Err(err) + } + } + } + + fn housekeep(&mut self) -> Result<(), Error> { // TODO: sync unimplemented!(); } - fn read_socket_data(&mut self, buffer: &mut MsgBuffer) -> Result { - // TODO: read data - // use 5sec timeout - unimplemented!(); - } - - fn handle_message(&mut self, src: SocketAddr, buffer: &mut MsgBuffer) -> Result<(), Error> { - // TODO: handle data - unimplemented!(); - } - pub fn run(mut self) { let mut buffer = MsgBuffer::new(SPACE_BEFORE); loop { - let addr = try_fail!(self.read_socket_data(&mut buffer), "Failed to read from socket: {}"); - if let Err(e) = self.handle_message(addr, &mut buffer) { - error!("{}", e); + let src = try_fail!(self.socket.receive(&mut buffer), "Failed to read from network socket: {}"); + match self.handle_message(src, &mut buffer) { + Err(e @ Error::CryptoInitFatal(_)) => { + debug!("Fatal crypto init error from {}: {}", src, e); + info!("Closing pending connection to {} due to error in crypto init", addr_nice(src)); + self.pending_inits.remove(&src); + } + Err(e @ Error::CryptoInit(_)) => { + debug!("Recoverable init error from {}: {}", src, e); + info!("Ignoring invalid init message from peer {}", addr_nice(src)); + } + Err(e) => { + error!("{}", e); + } + Ok(_) => {} } - let now = T::now(); - if self.next_sync < now { - self.sync(); - self.next_sync = now + SYNC_INTERVAL + let now = TS::now(); + if self.next_housekeep < TS::now() { + if let Err(e) = self.housekeep() { + error!("{}", e) + } + self.next_housekeep = TS::now() + 1 } } }