diff --git a/src/crypto/common.rs b/src/crypto/common.rs index cb2402d..785686f 100644 --- a/src/crypto/common.rs +++ b/src/crypto/common.rs @@ -241,7 +241,7 @@ impl PeerCrypto { } } - fn encrypt_message(&mut self, buffer: &mut MsgBuffer) { + pub fn encrypt_message(&mut self, buffer: &mut MsgBuffer) { // HOT PATH if let PeerCrypto::Encrypted { core, .. } = self { core.encrypt(buffer) @@ -283,13 +283,6 @@ impl PeerCrypto { } } - pub fn send_message(&mut self, type_: u8, buffer: &mut MsgBuffer) { - // HOT PATH - assert_ne!(type_, MESSAGE_TYPE_ROTATION); - buffer.prepend_byte(type_); - self.encrypt_message(buffer); - } - pub fn every_second(&mut self, out: &mut MsgBuffer) -> MessageResult { out.clear(); if let PeerCrypto::Encrypted { core, rotation, rotate_counter, algorithm, .. } = self { @@ -367,7 +360,8 @@ mod tests { buffer.set_length(1000); rng.fill(buffer.message_mut()).unwrap(); for _ in 0..1000 { - node1.send_message(1, &mut buffer); + buffer.prepend_byte(1); + node1.encrypt_message(&mut buffer); let res = node2.handle_message(&mut buffer).unwrap(); assert_eq!(res, MessageResult::Message(1)); diff --git a/src/crypto/init.rs b/src/crypto/init.rs index def486b..0741a4a 100644 --- a/src/crypto/init.rs +++ b/src/crypto/init.rs @@ -552,7 +552,7 @@ impl InitState

{ out.set_length(len); } - fn repeat_last_message(&self, out: &mut MsgBuffer) { + pub fn repeat_last_message(&self, out: &mut MsgBuffer) { if let Some(ref bytes) = self.last_message { debug!("Repeating last init message"); let buffer = out.buffer(); diff --git a/src/engine/device_thread.rs b/src/engine/device_thread.rs index 70ecb59..07e4ced 100644 --- a/src/engine/device_thread.rs +++ b/src/engine/device_thread.rs @@ -42,11 +42,9 @@ impl DeviceThread Result<(), Error> { debug!("Sending msg with {} bytes to {}", msg.len(), addr); - if self.peer_crypto.send_message(addr, type_, msg)? { - self.send_to(addr, msg) - } else { - Err(Error::Message("Sending to node that is not a peer")) - } + msg.prepend_byte(type_); + self.peer_crypto.encrypt_for(addr, msg)?; + self.send_to(addr, msg) } #[inline] @@ -57,7 +55,10 @@ impl DeviceThread Ok(()), diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 69eb995..ae36975 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -26,7 +26,7 @@ use smallvec::{smallvec, SmallVec}; use crate::{ beacon::BeaconSerializer, config::{Config, DEFAULT_PEER_TIMEOUT, DEFAULT_PORT}, - crypto::{is_init_message, Crypto, MessageResult, PeerCrypto}, + crypto::{is_init_message, Crypto, MessageResult, PeerCrypto, InitState, InitResult}, device::{Device, Type}, error::Error, messages::{ @@ -58,7 +58,7 @@ struct PeerData { timeout: Time, peer_timeout: u16, node_id: NodeId, - crypto: PeerCrypto + crypto: PeerCrypto } #[derive(Clone)] @@ -190,7 +190,8 @@ impl GenericCloud Ok(()), @@ -221,7 +222,8 @@ impl GenericCloud peer, None => return Err(Error::Message("Sending to node that is not a peer")) }; - peer.crypto.send_message(type_, msg)?; + msg.prepend_byte(type_); + peer.crypto.encrypt_message(msg); self.send_to(addr, msg) } @@ -330,7 +332,7 @@ impl GenericCloud, Hash>>> + peers: Arc>, Hash>>> } impl SharedPeerCrypto { @@ -20,20 +20,19 @@ impl SharedPeerCrypto { // TODO sync if needed } - pub fn send_message(&mut self, peer: SocketAddr, type_: u8, data: &mut MsgBuffer) -> Result { + pub fn encrypt_for(&mut self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { let mut peers = self.peers.lock(); - if let Some(peer) = peers.get_mut(&peer) { - peer.send_message(type_, data); - Ok(true) - } else { - Ok(false) - } + match peers.get_mut(&peer) { + None => Err(Error::InvalidCryptoState("No crypto found for peer")), + Some(None) => Ok(()), + Some(Some(crypto)) => Ok(crypto.encrypt(data)) + } } - pub fn for_each(&mut self, mut callback: impl FnMut(SocketAddr, &mut CryptoCore) -> Result<(), Error>) -> Result<(), Error> { + pub fn for_each(&mut self, mut callback: impl FnMut(SocketAddr, Option>) -> Result<(), Error>) -> Result<(), Error> { let mut peers = self.peers.lock(); for (k, v) in peers.iter_mut() { - callback(*k, v)? + callback(*k, v.clone())? } Ok(()) } diff --git a/src/engine/socket_thread.rs b/src/engine/socket_thread.rs index b439583..f8352a0 100644 --- a/src/engine/socket_thread.rs +++ b/src/engine/socket_thread.rs @@ -5,7 +5,7 @@ use super::{ use crate::{ config::DEFAULT_PEER_TIMEOUT, - crypto::{is_init_message, MessageResult, PeerCrypto}, + crypto::{is_init_message, MessageResult, PeerCrypto, InitState, InitResult}, engine::{addr_nice, resolve, Hash, PeerData}, error::Error, messages::{AddrList, NodeInfo, PeerInfo}, @@ -38,7 +38,7 @@ pub struct SocketThread { device: D, next_housekeep: Time, own_addresses: AddrList, - pending_inits: HashMap, Hash>, + pending_inits: HashMap, Hash>, crypto: Crypto, peers: HashMap, // Shared fields @@ -48,7 +48,7 @@ pub struct SocketThread { impl SocketThread { #[inline] - fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { + fn send_to(&mut self, addr: SocketAddr, msg: &MsgBuffer) -> Result<(), Error> { debug!("Sending msg with {} bytes to {}", msg.len(), addr); self.traffic.count_out_traffic(addr, msg.len()); match self.socket.send(msg.message(), addr) { @@ -68,10 +68,10 @@ impl SocketThread SocketThread Result<(), Error> { + fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo, msg: &mut MsgBuffer) -> Result<(), Error> { info!("Added peer {}", addr_nice(addr)); if let Some(init) = self.pending_inits.remove(&addr) { + msg.clear(); + let crypto = init.finish(&mut msg); self.peers.insert(addr, PeerData { addrs: info.addrs.clone(), - crypto: init, + crypto, node_id: info.node_id, peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT), last_seen: TS::now(), timeout: TS::now() + self.config.peer_timeout as Time }); self.update_peer_info(addr, Some(info))?; + if !msg.is_empty() { + self.send_to(addr, msg)?; + } } else { error!("No init for new peer {}", addr_nice(addr)); } @@ -193,7 +198,7 @@ impl SocketThread, data: &mut MsgBuffer + &mut self, src: SocketAddr, msg_result: MessageResult, data: &mut MsgBuffer ) -> Result<(), Error> { match msg_result { MessageResult::Message(type_) => { @@ -217,11 +222,6 @@ impl SocketThread self.add_new_peer(src, info)?, - MessageResult::InitializedWithReply(info) => { - self.add_new_peer(src, info)?; - self.send_to(src, data)? - } MessageResult::Reply => self.send_to(src, data)?, MessageResult::None => () } @@ -231,43 +231,48 @@ impl SocketThread Result<(), Error> { let src = mapped_addr(src); debug!("Received {} bytes from {}", data.len(), src); - let msg_result = if let Some(init) = self.pending_inits.get_mut(&src) { - init.handle_message(data) - } else if is_init_message(data.message()) { - let mut result = None; - if let Some(peer) = self.peers.get_mut(&src) { - if peer.crypto.has_init() { - result = Some(peer.crypto.handle_message(data)) - } - } - if let Some(result) = result { - result - } else { - let mut init = self.crypto.peer_instance(self.create_node_info()); - let msg_result = init.handle_message(data); - match msg_result { - Ok(res) => { - self.pending_inits.insert(src, init); - Ok(res) - } - Err(err) => { - self.traffic.count_invalid_protocol(data.len()); - return Err(err) - } - } - } - } else if let Some(peer) = self.peers.get_mut(&src) { + if let Some(result) = self.peers.get_mut(&src).map(|peer| { peer.crypto.handle_message(data) - } else { + }) { + return self.process_message(src, result?, data) + } + let is_init = is_init_message(data.message()); + if let Some(result) = self.pending_inits.get_mut(&src).map(|init| { + if is_init { + init.handle_init(data) + } else { + data.clear(); + init.repeat_last_message(data); + Ok(InitResult::Continue) + } + }) { + match result? { + InitResult::Continue => { + if !data.is_empty() { + self.send_to(src, data)? + } + }, + InitResult::Success { peer_payload, is_initiator } => { + self.add_new_peer(src, peer_payload, data)? + } + } + return Ok(()) + } + if !is_init_message(data.message()) { info!("Ignoring non-init message from unknown peer {}", addr_nice(src)); self.traffic.count_invalid_protocol(data.len()); return Ok(()) - }; + } + let mut init = self.crypto.peer_instance(self.create_node_info()); + let msg_result = init.handle_init(data); match msg_result { - Ok(val) => self.process_message(src, val, data), + Ok(res) => { + self.pending_inits.insert(src, init); + self.send_to(src, data) + } Err(err) => { self.traffic.count_invalid_protocol(data.len()); - Err(err) + return Err(err) } } }