diff --git a/src/benches.rs b/src/benches.rs index 724246b..988cc39 100644 --- a/src/benches.rs +++ b/src/benches.rs @@ -77,7 +77,7 @@ fn message_decode(b: &mut Bencher) { #[bench] fn switch_learn(b: &mut Bencher) { - let mut table = SwitchTable::new(10); + let mut table = SwitchTable::new(10, 0); let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); b.iter(|| { @@ -88,7 +88,7 @@ fn switch_learn(b: &mut Bencher) { #[bench] fn switch_lookup(b: &mut Bencher) { - let mut table = SwitchTable::new(10); + let mut table = SwitchTable::new(10, 0); let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); table.learn(addr.clone(), None, peer); @@ -152,7 +152,7 @@ fn epoll_wait(b: &mut Bencher) { fn handle_interface_data(b: &mut Bencher) { let mut node = GenericCloud::::new( MAGIC, Device::dummy("vpncloud0", "/dev/null", Type::Tap).unwrap(), 0, - Box::new(SwitchTable::new(300)), 1800, true, true, vec![], Crypto::None, None + Box::new(SwitchTable::new(300, 10)), 1800, true, true, vec![], Crypto::None, None ); let mut data = [0; 1500]; data[105] = 45; @@ -166,7 +166,7 @@ fn handle_interface_data(b: &mut Bencher) { fn handle_net_message(b: &mut Bencher) { let mut node = GenericCloud::::new( MAGIC, Device::dummy("vpncloud0", "/dev/null", Type::Tap).unwrap(), 0, - Box::new(SwitchTable::new(300)), 1800, true, true, vec![], Crypto::None, None + Box::new(SwitchTable::new(300, 10)), 1800, true, true, vec![], Crypto::None, None ); let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1)); let mut data = [0; 1500]; diff --git a/src/ethernet.rs b/src/ethernet.rs index de8b8e8..6155442 100644 --- a/src/ethernet.rs +++ b/src/ethernet.rs @@ -71,13 +71,15 @@ pub struct SwitchTable { /// The table storing the actual mapping table: HashMap, /// Timeout period for forgetting learnt addresses - timeout: Duration + timeout: Duration, + // Timeout period for not overwriting learnt addresses + protection_period: Duration, } impl SwitchTable { /// Creates a new switch table - pub fn new(timeout: Duration) -> Self { - SwitchTable{table: HashMap::default(), timeout: timeout} + pub fn new(timeout: Duration, protection_period: Duration) -> Self { + SwitchTable{table: HashMap::default(), timeout: timeout, protection_period: protection_period} } } @@ -108,7 +110,7 @@ impl Table for SwitchTable { }, Entry::Occupied(mut entry) => { let mut entry = entry.get_mut(); - if entry.timeout + 10 >= deadline { + if entry.timeout + self.protection_period as Time > deadline { // Do not override recently learnt entries return } diff --git a/src/main.rs b/src/main.rs index ee7248c..ec347d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -150,12 +150,12 @@ fn run (config: Config) { let peer_timeout = config.peer_timeout; let (learning, broadcasting, table): (bool, bool, Box) = match config.mode { Mode::Normal => match config.device_type { - Type::Tap => (true, true, Box::new(SwitchTable::new(dst_timeout))), + Type::Tap => (true, true, Box::new(SwitchTable::new(dst_timeout, 10))), Type::Tun => (false, false, Box::new(RoutingTable::new())) }, Mode::Router => (false, false, Box::new(RoutingTable::new())), - Mode::Switch => (true, true, Box::new(SwitchTable::new(dst_timeout))), - Mode::Hub => (false, true, Box::new(SwitchTable::new(dst_timeout))) + Mode::Switch => (true, true, Box::new(SwitchTable::new(dst_timeout, 10))), + Mode::Hub => (false, true, Box::new(SwitchTable::new(dst_timeout, 10))) }; let magic = config.get_magic(); Crypto::init(); diff --git a/src/tests.rs b/src/tests.rs index 2ca307a..0b191eb 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -4,6 +4,8 @@ use std::net::{ToSocketAddrs, SocketAddr}; use std::str::FromStr; +use std::thread; +use std::time::Duration; use super::MAGIC; use super::ethernet::{Frame, SwitchTable}; @@ -237,12 +239,19 @@ fn decode_invalid_packet() { #[test] fn switch() { - let mut table = SwitchTable::new(10); + let mut table = SwitchTable::new(10, 1); let addr = Address::from_str("12:34:56:78:90:ab").unwrap(); let peer = "1.2.3.4:5678".to_socket_addrs().unwrap().next().unwrap(); + let peer2 = "1.2.3.5:7890".to_socket_addrs().unwrap().next().unwrap(); assert!(table.lookup(&addr).is_none()); table.learn(addr.clone(), None, peer.clone()); assert_eq!(table.lookup(&addr), Some(peer)); + // Do not override within 1 seconds + table.learn(addr.clone(), None, peer2.clone()); + assert_eq!(table.lookup(&addr), Some(peer)); + thread::sleep(Duration::from_secs(1)); + table.learn(addr.clone(), None, peer2.clone()); + assert_eq!(table.lookup(&addr), Some(peer2)); } #[test] @@ -493,6 +502,7 @@ fn config_merge() { flag_listen: Some(3211), flag_peer_timeout: Some(1801), flag_dst_timeout: Some(301), + flag_mode: Some(Mode::Switch), flag_subnet: vec![], flag_connect: vec!["another:3210".to_string()], flag_no_port_forwarding: true, @@ -514,7 +524,7 @@ fn config_merge() { peers: vec!["remote.machine.foo:3210".to_string(), "remote.machine.bar:3210".to_string(), "another:3210".to_string()], peer_timeout: 1801, dst_timeout: 301, - mode: Mode::Normal, + mode: Mode::Switch, port_forwarding: false, subnets: vec!["10.0.1.0/24".to_string()], user: Some("root".to_string()),