diff --git a/src/cloud.rs b/src/cloud.rs index 23100ea..ddb99f3 100644 --- a/src/cloud.rs +++ b/src/cloud.rs @@ -824,7 +824,7 @@ impl GenericCloud run::(config, socket), diff --git a/src/poll/epoll.rs b/src/poll/epoll.rs index 5d69d43..73679de 100644 --- a/src/poll/epoll.rs +++ b/src/poll/epoll.rs @@ -2,11 +2,9 @@ // Copyright (C) 2015-2020 Dennis Schwerdel // This software is licensed under GPL-3 or newer (see LICENSE.md) -use crate::device::Device; use std::{io, os::unix::io::RawFd}; use super::WaitResult; -use crate::net::Socket; pub struct EpollWait { poll_fd: RawFd, @@ -17,21 +15,21 @@ pub struct EpollWait { } impl EpollWait { - pub fn new(socket: &S, device: &dyn Device, timeout: u32) -> io::Result { + pub fn new(socket: RawFd, device: RawFd, timeout: u32) -> io::Result { Self::create(socket, device, timeout, libc::EPOLLIN as u32) } - pub fn testing(socket: &S, device: &dyn Device, timeout: u32) -> io::Result { + pub fn testing(socket: RawFd, device: RawFd, timeout: u32) -> io::Result { Self::create(socket, device, timeout, (libc::EPOLLIN | libc::EPOLLOUT) as u32) } - fn create(socket: &S, device: &dyn Device, timeout: u32, flags: u32) -> io::Result { + fn create(socket: RawFd, device: RawFd, timeout: u32, flags: u32) -> io::Result { let mut event = libc::epoll_event { u64: 0, events: 0 }; let poll_fd = unsafe { libc::epoll_create(3) }; if poll_fd == -1 { return Err(io::Error::last_os_error()) } - for fd in &[socket.as_raw_fd(), device.as_raw_fd()] { + for fd in &[socket, device] { event.u64 = *fd as u64; event.events = flags; let res = unsafe { libc::epoll_ctl(poll_fd, libc::EPOLL_CTL_ADD, *fd, &mut event) }; @@ -39,7 +37,7 @@ impl EpollWait { return Err(io::Error::last_os_error()) } } - Ok(Self { poll_fd, event, socket: socket.as_raw_fd(), device: device.as_raw_fd(), timeout }) + Ok(Self { poll_fd, event, socket, device, timeout }) } } diff --git a/src/wsproxy.rs b/src/wsproxy.rs index c387640..25f7c82 100644 --- a/src/wsproxy.rs +++ b/src/wsproxy.rs @@ -1,42 +1,125 @@ use super::{ net::{mapped_addr, Socket}, + poll::{WaitImpl, WaitResult}, port_forwarding::PortForwarding, util::MsgBuffer }; -use std::{ - io::{self, Write, Read, Cursor}, - net::{SocketAddr, TcpListener, SocketAddrV6, Ipv6Addr}, - os::unix::io::{AsRawFd, RawFd}, - thread::spawn, -}; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use std::{ + io::{self, Cursor, Read, Write}, + net::{Ipv6Addr, SocketAddr, SocketAddrV6, TcpListener, TcpStream, UdpSocket}, + os::unix::io::{AsRawFd, RawFd}, + thread::spawn +}; use tungstenite::{client::AutoStream, connect, protocol::WebSocket, server::accept, stream::Stream, Message}; use url::Url; -pub fn run_proxy() { - let server = TcpListener::bind("127.0.0.1:9001").unwrap(); - for stream in server.incoming() { - info!("connect"); - spawn(move || { - let mut websocket = accept(stream.unwrap()).unwrap(); - loop { - let msg = websocket.read_message().unwrap(); - // We do not want to send back ping/pong messages. - if msg.is_binary() || msg.is_text() { - info!("msg"); - websocket.write_message(msg).unwrap(); + +macro_rules! io_error { + ($val:expr, $format:expr) => ( { + $val.map_err(|err| io::Error::new(io::ErrorKind::Other, format!($format, err))) + } ); + ($val:expr, $format:expr, $( $arg:expr ),+) => ( { + $val.map_err(|err| io::Error::new(io::ErrorKind::Other, format!($format, $( $arg ),+, err))) + } ); +} + +fn write_addr(addr: SocketAddr, mut out: W) -> Result<(), io::Error> { + let addr = mapped_addr(addr); + match mapped_addr(addr) { + SocketAddr::V6(addr) => { + out.write_all(&addr.ip().octets())?; + out.write_u16::(addr.port())?; + } + _ => unreachable!() + } + Ok(()) +} + +fn read_addr(mut r: R) -> Result { + let mut ip = [0u8; 16]; + r.read_exact(&mut ip)?; + let port = r.read_u16::()?; + let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0)); + Ok(addr) +} + +fn serve_proxy_connection(stream: TcpStream) -> Result<(), io::Error> { + let peer = stream.peer_addr()?; + info!("WS client {} connected", peer); + let mut websocket = io_error!(accept(stream), "Failed to initialize websocket with {}: {}", peer)?; + let udpsocket = UdpSocket::bind("[::]:0")?; + let mut msg = Vec::with_capacity(18); + let addr = udpsocket.local_addr()?; + info!("Listening on {} for peer {}", addr, peer); + write_addr(addr, &mut msg)?; + io_error!(websocket.write_message(Message::Binary(msg)), "Failed to write to ws connection: {}")?; + let websocketfd = websocket.get_ref().as_raw_fd(); + let poll = WaitImpl::new(websocketfd, udpsocket.as_raw_fd(), 60)?; + let mut buffer = [0; 65535]; + for evt in poll { + match evt { + WaitResult::Socket => { + info!("WS -> UDP"); + let msg = io_error!(websocket.read_message(), "Failed to read message on websocket {}: {}", peer)?; + info!("MSG: {}", msg); + match msg { + Message::Binary(data) => { + let dst = read_addr(Cursor::new(&data))?; + udpsocket.send_to(&data[18..], dst)?; + } + Message::Close(_) => return Ok(()), + _ => {} } } + WaitResult::Device => { + info!("UDP -> WS"); + let (size, addr) = udpsocket.recv_from(&mut buffer)?; + let mut data = Vec::with_capacity(18 + size); + write_addr(addr, &mut data)?; + data.write_all(&buffer[..size])?; + io_error!(websocket.write_message(Message::Binary(data)), "Failed to write to {}: {}", peer)?; + } + WaitResult::Timeout => { + info!("Sending ping"); + io_error!(websocket.write_message(Message::Ping(vec![])), "Failed to send ping: {}")?; + } + WaitResult::Error(err) => return Err(err) + } + } + Ok(()) +} + +pub fn run_proxy() -> Result<(), io::Error> { + // TODO: configurable listen + let server = TcpListener::bind("127.0.0.1:9001")?; + for stream in server.incoming() { + let stream = stream?; + let peer = stream.peer_addr()?; + spawn(move || { + if let Err(err) = serve_proxy_connection(stream) { + error!("Error on connection {}: {}", peer, err); + } }); } + Ok(()) } pub struct ProxyConnection { - url: String, addr: SocketAddr, socket: WebSocket } +impl ProxyConnection { + fn read_message(&mut self) -> Result, io::Error> { + loop { + if let Message::Binary(data) = io_error!(self.socket.read_message(), "Failed to read from ws proxy: {}")? { + return Ok(data) + } + } + } +} + impl AsRawFd for ProxyConnection { fn as_raw_fd(&self) -> RawFd { match self.socket.get_ref() { @@ -48,39 +131,28 @@ impl AsRawFd for ProxyConnection { impl Socket for ProxyConnection { fn listen(url: &str) -> Result { - let (mut socket, _) = connect(Url::parse(url).unwrap()).unwrap(); + let parsed_url = io_error!(Url::parse(url), "Invalid URL {}: {}", url)?; + let (socket, _) = io_error!(connect(parsed_url), "Failed to connect to URL {}: {}", url)?; let addr = "0.0.0.0:0".parse::().unwrap(); - Ok(ProxyConnection { url: addr.to_string(), addr, socket }) + let mut con = ProxyConnection { addr, socket }; + let addr_data = con.read_message()?; + con.addr = read_addr(Cursor::new(&addr_data))?; + Ok(con) } fn receive(&mut self, buffer: &mut MsgBuffer) -> Result { buffer.clear(); - match self.socket.read_message().unwrap() { - Message::Binary(data) => { - let mut cursor = Cursor::new(&data); - let mut ip = [0u8; 16]; - cursor.read_exact(&mut ip)?; - let port = cursor.read_u16::()?; - let addr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::from(ip), port, 0, 0)); - buffer.clone_from(&data[18..]); - Ok(addr) - }, - _ => unimplemented!() - } + let data = self.read_message()?; + let addr = read_addr(Cursor::new(&data))?; + buffer.clone_from(&data[18..]); + Ok(addr) } fn send(&mut self, data: &[u8], addr: SocketAddr) -> Result { let mut msg = Vec::with_capacity(data.len() + 18); - let addr = mapped_addr(addr); - match mapped_addr(addr) { - SocketAddr::V6(addr) => { - msg.write_all(&addr.ip().octets())?; - msg.write_u16::(addr.port())?; - }, - _ => unreachable!() - } + write_addr(addr, &mut msg)?; msg.write_all(data)?; - self.socket.write_message(Message::Binary(msg)).unwrap(); + io_error!(self.socket.write_message(Message::Binary(msg)), "Failed to write to ws proxy: {}")?; Ok(data.len()) }