diff --git a/src/main.rs b/src/main.rs index 05ca9ef..fb62e7e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,6 +39,7 @@ pub mod wizard; #[cfg(feature = "websocket")] pub mod wsproxy; +use net::SocketBuilder; use structopt::StructOpt; use tokio::runtime::Runtime; @@ -176,7 +177,7 @@ fn setup_device(config: &Config) -> TunTapDevice { } #[allow(clippy::cognitive_complexity)] -fn run(config: Config, socket: NetSocket) { +fn run(config: Config, socket: S) { let device = setup_device(&config); let port_forwarding = if config.port_forwarding { socket.create_port_forwarding() } else { None }; let stats_file = match config.stats_file { @@ -226,9 +227,9 @@ fn run(config: Config, socket: NetSocket) { rt.block_on(async move { // Warning: no async code outside this block, or it will break on daemonize let device = AsyncTunTapDevice::from_sync(device); - let socket = try_fail!(AsyncNetSocket::from_sync(socket), "Failed to create async socket: {}"); + let socket = try_fail!(socket.build(), "Failed to create async socket: {}"); let mut cloud = try_fail!( - GenericCloud::::new( + GenericCloud::::new( &config, socket, device, @@ -345,27 +346,18 @@ fn main() { error!("Either password or private key must be set in config or given as parameter"); return; } - /* #[cfg(feature = "websocket")] if config.listen.starts_with("ws://") { - let socket = { - let rt = Runtime::new().unwrap(); - try_fail!( - rt.block_on(ProxyConnection::listen(&config.listen)), - "Failed to open socket {}: {}", - config.listen - ) - }; + let socket = try_fail!(ProxyConnection::listen(&config.listen), "Failed to open socket {}: {}", config.listen); match config.device_type { Type::Tap => run::(config, socket), Type::Tun => run::(config, socket), } return; } - */ let socket = try_fail!(NetSocket::listen(&config.listen), "Failed to open socket {}: {}", config.listen); match config.device_type { - Type::Tap => run::(config, socket), - Type::Tun => run::(config, socket), + Type::Tap => run::(config, socket), + Type::Tun => run::(config, socket), } } diff --git a/src/net.rs b/src/net.rs index 313570d..14aa3b4 100644 --- a/src/net.rs +++ b/src/net.rs @@ -32,6 +32,12 @@ pub fn get_ip() -> IpAddr { s.local_addr().unwrap().ip() } +pub trait SocketBuilder { + type SocketType: Socket; + fn build(self) -> Result; + fn create_port_forwarding(&self) -> Option; +} + #[async_trait] pub trait Socket: Sized + Clone + Send + Sync + 'static { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result; @@ -60,11 +66,18 @@ impl NetSocket { let addr = parse_listen(addr, DEFAULT_PORT); Ok(Self(UdpSocket::bind(addr)?)) } +} - pub fn create_port_forwarding(&self) -> Option { +impl SocketBuilder for NetSocket { + type SocketType = AsyncNetSocket; + + fn create_port_forwarding(&self) -> Option { PortForwarding::new(self.0.local_addr().unwrap().port()) } + fn build(self) -> Result { + Ok(AsyncNetSocket(Arc::new(AsyncUdpSocket::from_std(self.0)?))) + } } pub struct AsyncNetSocket(Arc); @@ -75,13 +88,6 @@ impl Clone for AsyncNetSocket { } } -impl AsyncNetSocket { - pub fn from_sync(sock: NetSocket) -> Result { - Ok(Self(Arc::new(AsyncUdpSocket::from_std(sock.0)?))) - } -} - - #[async_trait] impl Socket for AsyncNetSocket { async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { diff --git a/src/wsproxy.rs b/src/wsproxy.rs index f31df64..26c43fb 100644 --- a/src/wsproxy.rs +++ b/src/wsproxy.rs @@ -3,7 +3,7 @@ // This software is licensed under GPL-3 or newer (see LICENSE.md) use super::{ - net::{get_ip, mapped_addr, parse_listen, Socket}, + net::{get_ip, mapped_addr, parse_listen, Socket, SocketBuilder}, poll::{WaitImpl, WaitResult}, port_forwarding::PortForwarding, util::MsgBuffer, @@ -115,6 +115,17 @@ pub struct ProxyConnection { } impl ProxyConnection { + pub fn listen(url: &str) -> Result { + let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; + let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; + socket.get_mut().set_nodelay(true)?; + let addr = "0.0.0.0:0".parse::().unwrap(); + let mut con = ProxyConnection { addr, socket: Arc::new(socket) }; + let addr_data = con.read_message()?; + con.addr = read_addr(Cursor::new(&addr_data))?; + Ok(con) + } + fn read_message(&mut self) -> Result, io::Error> { loop { unimplemented!(); @@ -127,21 +138,20 @@ impl ProxyConnection { } } +impl SocketBuilder for ProxyConnection { + type SocketType = ProxyConnection; + + fn build(self) -> Result { + Ok(self) + } + + fn create_port_forwarding(&self) -> Option { + None + } +} + #[async_trait] impl Socket for ProxyConnection { - /* - async fn listen(url: &str) -> Result { - let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; - let (mut socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; - socket.get_mut().set_nodelay(true)?; - let addr = "0.0.0.0:0".parse::().unwrap(); - let mut con = ProxyConnection { addr, socket: Arc::new(socket) }; - let addr_data = con.read_message()?; - con.addr = read_addr(Cursor::new(&addr_data))?; - Ok(con) - } - */ - async fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { buffer.clear(); let data = self.read_message()?;