More code

This commit is contained in:
Dennis Schwerdel 2021-02-10 20:47:25 +01:00
parent 21d58f25a3
commit 82c43ff496
6 changed files with 76 additions and 75 deletions

View File

@ -241,7 +241,7 @@ impl PeerCrypto {
} }
} }
fn encrypt_message(&mut self, buffer: &mut MsgBuffer) { pub fn encrypt_message(&mut self, buffer: &mut MsgBuffer) {
// HOT PATH // HOT PATH
if let PeerCrypto::Encrypted { core, .. } = self { if let PeerCrypto::Encrypted { core, .. } = self {
core.encrypt(buffer) core.encrypt(buffer)
@ -283,13 +283,6 @@ impl PeerCrypto {
} }
} }
pub fn send_message(&mut self, type_: u8, buffer: &mut MsgBuffer) {
// HOT PATH
assert_ne!(type_, MESSAGE_TYPE_ROTATION);
buffer.prepend_byte(type_);
self.encrypt_message(buffer);
}
pub fn every_second(&mut self, out: &mut MsgBuffer) -> MessageResult { pub fn every_second(&mut self, out: &mut MsgBuffer) -> MessageResult {
out.clear(); out.clear();
if let PeerCrypto::Encrypted { core, rotation, rotate_counter, algorithm, .. } = self { if let PeerCrypto::Encrypted { core, rotation, rotate_counter, algorithm, .. } = self {
@ -367,7 +360,8 @@ mod tests {
buffer.set_length(1000); buffer.set_length(1000);
rng.fill(buffer.message_mut()).unwrap(); rng.fill(buffer.message_mut()).unwrap();
for _ in 0..1000 { for _ in 0..1000 {
node1.send_message(1, &mut buffer); buffer.prepend_byte(1);
node1.encrypt_message(&mut buffer);
let res = node2.handle_message(&mut buffer).unwrap(); let res = node2.handle_message(&mut buffer).unwrap();
assert_eq!(res, MessageResult::Message(1)); assert_eq!(res, MessageResult::Message(1));

View File

@ -552,7 +552,7 @@ impl<P: Payload> InitState<P> {
out.set_length(len); out.set_length(len);
} }
fn repeat_last_message(&self, out: &mut MsgBuffer) { pub fn repeat_last_message(&self, out: &mut MsgBuffer) {
if let Some(ref bytes) = self.last_message { if let Some(ref bytes) = self.last_message {
debug!("Repeating last init message"); debug!("Repeating last init message");
let buffer = out.buffer(); let buffer = out.buffer();

View File

@ -42,11 +42,9 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> DeviceThread<S, D, P, TS
#[inline] #[inline]
fn send_msg(&mut self, addr: SocketAddr, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> { fn send_msg(&mut self, addr: SocketAddr, type_: u8, msg: &mut MsgBuffer) -> Result<(), Error> {
debug!("Sending msg with {} bytes to {}", msg.len(), addr); debug!("Sending msg with {} bytes to {}", msg.len(), addr);
if self.peer_crypto.send_message(addr, type_, msg)? { msg.prepend_byte(type_);
self.send_to(addr, msg) self.peer_crypto.encrypt_for(addr, msg)?;
} else { self.send_to(addr, msg)
Err(Error::Message("Sending to node that is not a peer"))
}
} }
#[inline] #[inline]
@ -57,7 +55,10 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> DeviceThread<S, D, P, TS
msg_data.set_start(msg.get_start()); msg_data.set_start(msg.get_start());
msg_data.set_length(msg.len()); msg_data.set_length(msg.len());
msg_data.message_mut().clone_from_slice(msg.message()); msg_data.message_mut().clone_from_slice(msg.message());
crypto.send_message(type_, &mut msg_data)?; msg_data.prepend_byte(type_);
if let Some(crypto) = crypto {
crypto.encrypt(&mut msg_data);
}
self.traffic.count_out_traffic(addr, msg_data.len()); self.traffic.count_out_traffic(addr, msg_data.len());
match self.socket.send(msg_data.message(), addr) { match self.socket.send(msg_data.message(), addr) {
Ok(written) if written == msg_data.len() => Ok(()), Ok(written) if written == msg_data.len() => Ok(()),

View File

@ -26,7 +26,7 @@ use smallvec::{smallvec, SmallVec};
use crate::{ use crate::{
beacon::BeaconSerializer, beacon::BeaconSerializer,
config::{Config, DEFAULT_PEER_TIMEOUT, DEFAULT_PORT}, config::{Config, DEFAULT_PEER_TIMEOUT, DEFAULT_PORT},
crypto::{is_init_message, Crypto, MessageResult, PeerCrypto}, crypto::{is_init_message, Crypto, MessageResult, PeerCrypto, InitState, InitResult},
device::{Device, Type}, device::{Device, Type},
error::Error, error::Error,
messages::{ messages::{
@ -58,7 +58,7 @@ struct PeerData {
timeout: Time, timeout: Time,
peer_timeout: u16, peer_timeout: u16,
node_id: NodeId, node_id: NodeId,
crypto: PeerCrypto<NodeInfo> crypto: PeerCrypto
} }
#[derive(Clone)] #[derive(Clone)]
@ -190,7 +190,8 @@ impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS
msg_data.set_start(msg.get_start()); msg_data.set_start(msg.get_start());
msg_data.set_length(msg.len()); msg_data.set_length(msg.len());
msg_data.message_mut().clone_from_slice(msg.message()); msg_data.message_mut().clone_from_slice(msg.message());
peer.crypto.send_message(type_, &mut msg_data)?; msg_data.prepend_byte(type_);
peer.crypto.encrypt_message(&mut msg_data);
self.traffic.count_out_traffic(*addr, msg_data.len()); self.traffic.count_out_traffic(*addr, msg_data.len());
match self.socket.send(msg_data.message(), *addr) { match self.socket.send(msg_data.message(), *addr) {
Ok(written) if written == msg_data.len() => Ok(()), Ok(written) if written == msg_data.len() => Ok(()),
@ -221,7 +222,8 @@ impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS
Some(peer) => peer, Some(peer) => peer,
None => return Err(Error::Message("Sending to node that is not a peer")) None => return Err(Error::Message("Sending to node that is not a peer"))
}; };
peer.crypto.send_message(type_, msg)?; msg.prepend_byte(type_);
peer.crypto.encrypt_message(msg);
self.send_to(addr, msg) self.send_to(addr, msg)
} }
@ -330,7 +332,7 @@ impl<D: Device, P: Protocol, S: Socket, TS: TimeSource> GenericCloud<D, P, S, TS
let payload = self.create_node_info(); let payload = self.create_node_info();
let mut peer_crypto = self.crypto.peer_instance(payload); let mut peer_crypto = self.crypto.peer_instance(payload);
let mut msg = MsgBuffer::new(SPACE_BEFORE); let mut msg = MsgBuffer::new(SPACE_BEFORE);
peer_crypto.initialize(&mut msg)?; peer_crypto.send_ping(&mut msg);
self.pending_inits.insert(addr, peer_crypto); self.pending_inits.insert(addr, peer_crypto);
self.send_to(addr, &mut msg) self.send_to(addr, &mut msg)
} }

View File

@ -12,7 +12,7 @@ use parking_lot::Mutex;
use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{collections::HashMap, net::SocketAddr, sync::Arc};
pub struct SharedPeerCrypto { pub struct SharedPeerCrypto {
peers: Arc<Mutex<HashMap<SocketAddr, Arc<CryptoCore>, Hash>>> peers: Arc<Mutex<HashMap<SocketAddr, Option<Arc<CryptoCore>>, Hash>>>
} }
impl SharedPeerCrypto { impl SharedPeerCrypto {
@ -20,20 +20,19 @@ impl SharedPeerCrypto {
// TODO sync if needed // TODO sync if needed
} }
pub fn send_message(&mut self, peer: SocketAddr, type_: u8, data: &mut MsgBuffer) -> Result<bool, Error> { pub fn encrypt_for(&mut self, peer: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> {
let mut peers = self.peers.lock(); let mut peers = self.peers.lock();
if let Some(peer) = peers.get_mut(&peer) { match peers.get_mut(&peer) {
peer.send_message(type_, data); None => Err(Error::InvalidCryptoState("No crypto found for peer")),
Ok(true) Some(None) => Ok(()),
} else { Some(Some(crypto)) => Ok(crypto.encrypt(data))
Ok(false) }
}
} }
pub fn for_each(&mut self, mut callback: impl FnMut(SocketAddr, &mut CryptoCore) -> Result<(), Error>) -> Result<(), Error> { pub fn for_each(&mut self, mut callback: impl FnMut(SocketAddr, Option<Arc<CryptoCore>>) -> Result<(), Error>) -> Result<(), Error> {
let mut peers = self.peers.lock(); let mut peers = self.peers.lock();
for (k, v) in peers.iter_mut() { for (k, v) in peers.iter_mut() {
callback(*k, v)? callback(*k, v.clone())?
} }
Ok(()) Ok(())
} }

View File

@ -5,7 +5,7 @@ use super::{
use crate::{ use crate::{
config::DEFAULT_PEER_TIMEOUT, config::DEFAULT_PEER_TIMEOUT,
crypto::{is_init_message, MessageResult, PeerCrypto}, crypto::{is_init_message, MessageResult, PeerCrypto, InitState, InitResult},
engine::{addr_nice, resolve, Hash, PeerData}, engine::{addr_nice, resolve, Hash, PeerData},
error::Error, error::Error,
messages::{AddrList, NodeInfo, PeerInfo}, messages::{AddrList, NodeInfo, PeerInfo},
@ -38,7 +38,7 @@ pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
device: D, device: D,
next_housekeep: Time, next_housekeep: Time,
own_addresses: AddrList, own_addresses: AddrList,
pending_inits: HashMap<SocketAddr, PeerCrypto<NodeInfo>, Hash>, pending_inits: HashMap<SocketAddr, InitState<NodeInfo>, Hash>,
crypto: Crypto, crypto: Crypto,
peers: HashMap<SocketAddr, PeerData, Hash>, peers: HashMap<SocketAddr, PeerData, Hash>,
// Shared fields // Shared fields
@ -48,7 +48,7 @@ pub struct SocketThread<S: Socket, D: Device, P: Protocol, TS: TimeSource> {
impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS> { impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS> {
#[inline] #[inline]
fn send_to(&mut self, addr: SocketAddr, msg: &mut MsgBuffer) -> Result<(), Error> { fn send_to(&mut self, addr: SocketAddr, msg: &MsgBuffer) -> Result<(), Error> {
debug!("Sending msg with {} bytes to {}", msg.len(), addr); debug!("Sending msg with {} bytes to {}", msg.len(), addr);
self.traffic.count_out_traffic(addr, msg.len()); self.traffic.count_out_traffic(addr, msg.len());
match self.socket.send(msg.message(), addr) { match self.socket.send(msg.message(), addr) {
@ -68,10 +68,10 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
debug!("Connecting to {:?}", addr); debug!("Connecting to {:?}", addr);
let payload = self.create_node_info(); let payload = self.create_node_info();
let mut peer_crypto = self.crypto.peer_instance(payload); let mut init = self.crypto.peer_instance(payload);
let mut msg = MsgBuffer::new(SPACE_BEFORE); let mut msg = MsgBuffer::new(SPACE_BEFORE);
peer_crypto.initialize(&mut msg)?; init.send_ping(&mut msg);
self.pending_inits.insert(addr, peer_crypto); self.pending_inits.insert(addr, init);
self.send_to(addr, &mut msg) self.send_to(addr, &mut msg)
} }
@ -129,18 +129,23 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
Ok(()) Ok(())
} }
fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo) -> Result<(), Error> { fn add_new_peer(&mut self, addr: SocketAddr, info: NodeInfo, msg: &mut MsgBuffer) -> Result<(), Error> {
info!("Added peer {}", addr_nice(addr)); info!("Added peer {}", addr_nice(addr));
if let Some(init) = self.pending_inits.remove(&addr) { if let Some(init) = self.pending_inits.remove(&addr) {
msg.clear();
let crypto = init.finish(&mut msg);
self.peers.insert(addr, PeerData { self.peers.insert(addr, PeerData {
addrs: info.addrs.clone(), addrs: info.addrs.clone(),
crypto: init, crypto,
node_id: info.node_id, node_id: info.node_id,
peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT), peer_timeout: info.peer_timeout.unwrap_or(DEFAULT_PEER_TIMEOUT),
last_seen: TS::now(), last_seen: TS::now(),
timeout: TS::now() + self.config.peer_timeout as Time timeout: TS::now() + self.config.peer_timeout as Time
}); });
self.update_peer_info(addr, Some(info))?; self.update_peer_info(addr, Some(info))?;
if !msg.is_empty() {
self.send_to(addr, msg)?;
}
} else { } else {
error!("No init for new peer {}", addr_nice(addr)); error!("No init for new peer {}", addr_nice(addr));
} }
@ -193,7 +198,7 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
fn process_message( fn process_message(
&mut self, src: SocketAddr, msg_result: MessageResult<NodeInfo>, data: &mut MsgBuffer &mut self, src: SocketAddr, msg_result: MessageResult, data: &mut MsgBuffer
) -> Result<(), Error> { ) -> Result<(), Error> {
match msg_result { match msg_result {
MessageResult::Message(type_) => { MessageResult::Message(type_) => {
@ -217,11 +222,6 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
} }
} }
} }
MessageResult::Initialized(info) => self.add_new_peer(src, info)?,
MessageResult::InitializedWithReply(info) => {
self.add_new_peer(src, info)?;
self.send_to(src, data)?
}
MessageResult::Reply => self.send_to(src, data)?, MessageResult::Reply => self.send_to(src, data)?,
MessageResult::None => () MessageResult::None => ()
} }
@ -231,43 +231,48 @@ impl<S: Socket, D: Device, P: Protocol, TS: TimeSource> SocketThread<S, D, P, TS
fn handle_message(&mut self, src: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> { fn handle_message(&mut self, src: SocketAddr, data: &mut MsgBuffer) -> Result<(), Error> {
let src = mapped_addr(src); let src = mapped_addr(src);
debug!("Received {} bytes from {}", data.len(), src); debug!("Received {} bytes from {}", data.len(), src);
let msg_result = if let Some(init) = self.pending_inits.get_mut(&src) { if let Some(result) = self.peers.get_mut(&src).map(|peer| {
init.handle_message(data)
} else if is_init_message(data.message()) {
let mut result = None;
if let Some(peer) = self.peers.get_mut(&src) {
if peer.crypto.has_init() {
result = Some(peer.crypto.handle_message(data))
}
}
if let Some(result) = result {
result
} else {
let mut init = self.crypto.peer_instance(self.create_node_info());
let msg_result = init.handle_message(data);
match msg_result {
Ok(res) => {
self.pending_inits.insert(src, init);
Ok(res)
}
Err(err) => {
self.traffic.count_invalid_protocol(data.len());
return Err(err)
}
}
}
} else if let Some(peer) = self.peers.get_mut(&src) {
peer.crypto.handle_message(data) peer.crypto.handle_message(data)
} else { }) {
return self.process_message(src, result?, data)
}
let is_init = is_init_message(data.message());
if let Some(result) = self.pending_inits.get_mut(&src).map(|init| {
if is_init {
init.handle_init(data)
} else {
data.clear();
init.repeat_last_message(data);
Ok(InitResult::Continue)
}
}) {
match result? {
InitResult::Continue => {
if !data.is_empty() {
self.send_to(src, data)?
}
},
InitResult::Success { peer_payload, is_initiator } => {
self.add_new_peer(src, peer_payload, data)?
}
}
return Ok(())
}
if !is_init_message(data.message()) {
info!("Ignoring non-init message from unknown peer {}", addr_nice(src)); info!("Ignoring non-init message from unknown peer {}", addr_nice(src));
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(data.len());
return Ok(()) return Ok(())
}; }
let mut init = self.crypto.peer_instance(self.create_node_info());
let msg_result = init.handle_init(data);
match msg_result { match msg_result {
Ok(val) => self.process_message(src, val, data), Ok(res) => {
self.pending_inits.insert(src, init);
self.send_to(src, data)
}
Err(err) => { Err(err) => {
self.traffic.count_invalid_protocol(data.len()); self.traffic.count_invalid_protocol(data.len());
Err(err) return Err(err)
} }
} }
} }