# Mirror of https://github.com/dswd/vpncloud.git (Python, 385 lines, 14 KiB)
import boto3
|
|
import atexit
|
|
import paramiko
|
|
import io
|
|
import time
|
|
import threading
|
|
import re
|
|
import json
|
|
import base64
|
|
import sys
|
|
import os
|
|
from datetime import date
|
|
|
|
# Polling budget: the wait loops sleep roughly one second per round and give
# up after this many rounds.
MAX_WAIT = 5 * 60

# Sentinel default meaning "create this resource" (see EC2Environment's
# subnet / keyname / privatekey parameters) instead of using an existing one.
CREATE = "***CREATE***"
|
|
|
|
def eprint(*args, **kwargs):
    """Behave like :func:`print`, but write to ``sys.stderr``.

    ``sys.stderr`` is looked up on every call, so redirections made after
    import are honoured.  All other ``print`` keyword arguments pass through.
    """
    target = sys.stderr
    print(*args, file=target, **kwargs)
|
|
|
|
def run_cmd(connection, cmd):
    """Run *cmd* through ``connection.exec_command`` and collect its output.

    Both output streams are read to completion (stdout first, then stderr)
    and decoded as UTF-8 before the exit status is queried.

    :returns: ``(stdout, stderr)`` as strings when the exit status is 0
    :raises Exception: ``("Command failed", status, stdout, stderr)`` otherwise
    """
    _, out_stream, err_stream = connection.exec_command(cmd)
    out = out_stream.read().decode('utf-8')
    err = err_stream.read().decode('utf-8')
    status = out_stream.channel.recv_exit_status()
    if not status:
        return out, err
    raise Exception("Command failed", status, out, err)
|
|
|
|
def upload(connection, local, remote):
    """Copy the local file *local* to path *remote* on the host behind *connection*.

    Opens an SFTP session via ``connection.open_sftp()``, transfers the file
    with ``put`` and always closes the session again.

    :raises: whatever ``open_sftp`` / ``put`` raise; the session is closed first
    """
    sftp = connection.open_sftp()
    try:
        sftp.put(local, remote)
    finally:
        # Close unconditionally so a failed transfer does not leak the SFTP channel.
        sftp.close()
|
|
|
|
class SpotInstanceRequest:
    """Minimal stand-in resource for an EC2 spot instance request.

    It carries only the request id in ``self.id``, so it can be tracked next
    to boto3 resources and later referred to by that id.
    """

    def __init__(self, id):
        # Named ``id`` to mirror the ``id`` attribute of boto3 resources.
        self.id = id

    def __str__(self):
        return "{!s}".format(self.id)
|
|
|
|
|
|
class Node:
    """One provisioned instance plus an open SSH connection to it.

    Wraps the remote commands used by the tests: starting/stopping vpncloud,
    ping and iperf3 measurements.
    """

    def __init__(self, instance, connection):
        self.instance = instance
        self.connection = connection
        # Copied from the instance object once, at construction time.
        self.private_ip = instance.private_ip_address
        self.public_ip = instance.public_ip_address

    def run_cmd(self, cmd):
        """Run *cmd* on this node; returns ``(stdout, stderr)``, raises on non-zero exit."""
        return run_cmd(self.connection, cmd)

    def start_vpncloud(self, ip=None, crypto=None, password="test", device_type="tun", listen="3210", mode="normal", peers=None, claims=None):
        """Start vpncloud as a daemon (``sudo vpncloud --daemon ...``) on this node.

        :param ip: optional value for ``--ip``
        :param crypto: optional value for ``--algo``
        :param password: shared password, shell-quoted before use
        :param device_type: value for ``-t``
        :param listen: value for ``-l``
        :param mode: value for ``-m``
        :param peers: iterable of peer addresses, one ``-c`` option each
        :param claims: iterable of claims, one ``--claim`` option each
        """
        import shlex

        args = [
            "--daemon",
            "--no-port-forwarding",
            "-t {}".format(device_type),
            "-m {}".format(mode),
            "-l {}".format(listen),
            # shlex.quote keeps passwords containing quotes or spaces intact;
            # wrapping the value in a literal '...' breaks on a single quote.
            "--password {}".format(shlex.quote(str(password))),
        ]
        if ip:
            args.append("--ip {}".format(ip))
        if crypto:
            args.append("--algo {}".format(crypto))
        for peer in peers or ():
            args.append("-c {}".format(peer))
        for claim in claims or ():
            args.append("--claim {}".format(claim))
        self.run_cmd("sudo vpncloud {}".format(" ".join(args)))

    def stop_vpncloud(self, wait=True):
        """Kill all vpncloud processes; if *wait*, sleep 3 s afterwards."""
        self.run_cmd("sudo killall vpncloud")
        if wait:
            time.sleep(3.0)

    def ping(self, dst, size=100, count=10, interval=0.001):
        """Ping *dst* from this node and return round-trip statistics.

        :returns: dict with ``rtt_min``, ``rtt_max``, ``rtt_avg`` (in the unit
            ping prints) and ``pkt_loss`` (percent, may be fractional)
        :raises Exception: from run_cmd if ping exits non-zero
        :raises ValueError: if the ping summary cannot be parsed
        """
        (out, _) = self.run_cmd('sudo ping {dst} -c {count} -i {interval} -s {size} -U -q'.format(dst=dst, size=size, count=count, interval=interval))
        # Summary line "rtt min/avg/max/mdev = a/b/c/d ms".
        rtt = re.search(r'([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)/([\d]*\.[\d]*)', out)
        # Loss can be fractional (e.g. "33.3333% packet loss"), so accept a decimal part.
        loss = re.search(r'(\d+(?:\.\d+)?)% packet loss', out)
        if rtt is None or loss is None:
            raise ValueError("Unexpected ping output: {!r}".format(out))
        return {
            "rtt_min": float(rtt.group(1)),
            "rtt_max": float(rtt.group(3)),
            "rtt_avg": float(rtt.group(2)),
            "pkt_loss": float(loss.group(1)),
        }

    def start_iperf_server(self):
        """Start an iperf3 server in daemon mode (``iperf3 -s -D``)."""
        self.run_cmd('iperf3 -s -D')
        # Short pause, presumably to let the daemon start listening.
        time.sleep(0.1)

    def stop_iperf_server(self):
        """Kill all iperf3 processes on this node."""
        self.run_cmd('killall iperf3')

    def run_iperf(self, dst, duration):
        """Run an iperf3 client against *dst* for *duration* seconds.

        :returns: dict with ``throughput`` (receiver bits per second of the
            first stream) and ``cpu_sender`` / ``cpu_receiver`` (iperf3's
            host_total / remote_total CPU utilisation percentages)
        """
        (out, _) = self.run_cmd('iperf3 -c {dst} -t {duration} --json'.format(dst=dst, duration=duration))
        data = json.loads(out)
        end = data['end']
        return {
            "throughput": end['streams'][0]['receiver']['bits_per_second'],
            "cpu_sender": end['cpu_utilization_percent']['host_total'],
            "cpu_receiver": end['cpu_utilization_percent']['remote_total'],
        }
|
|
|
|
|
|
def find_ami(region, owner, name_pattern, arch='x86_64'):
    """Return the id of the newest AMI matching *owner*, *name_pattern* and *arch*.

    :param region: AWS region to search in
    :param owner: image owner passed to ``Owners`` (e.g. ``'amazon'``)
    :param name_pattern: value for the EC2 ``name`` filter (wildcards like ``*`` allowed)
    :param arch: value for the EC2 ``architecture`` filter, default ``'x86_64'``
    :returns: the ``ImageId`` with the latest ``CreationDate``, or ``None`` if no image matches
    """
    ec2client = boto3.client('ec2', region_name=region)
    response = ec2client.describe_images(Owners=[owner], Filters=[
        {'Name': 'name', 'Values': [name_pattern]},
        # Use the caller's architecture rather than a fixed 'x86_64'.
        {'Name': 'architecture', 'Values': [arch]},
    ])
    # CreationDate is an ISO-8601 string, so string comparison orders it by time.
    newest = max(response['Images'], key=lambda image: image['CreationDate'], default=None)
    return newest['ImageId'] if newest is not None else None
|
|
|
|
|
|
class EC2Environment:
    """Provision a throw-away group of EC2 instances for vpncloud tests.

    The constructor creates (or reuses) a subnet, then creates a security group,
    an SSH key pair (unless one is given) and ``node_count`` instances (spot or
    on-demand). It waits until every instance accepts SSH and has finished
    cloud-init, optionally uploads a local vpncloud binary, and exposes the
    machines as :class:`Node` objects in ``self.nodes``.

    Every created resource is recorded in ``self.resources`` and removed again
    by :meth:`terminate`. That method is registered with :mod:`atexit` after a
    successful setup, and is called immediately if setup fails.
    """

    def __init__(self, vpncloud_version, region, node_count, instance_type, vpncloud_file=None, use_spot=True, max_price=0.1, ami=('amazon', 'al2023-ami-*-hvm-*'), username="ec2-user", subnet=CREATE, keyname=CREATE, privatekey=CREATE, tag="vpncloud", cluster_nodes=False):
        """Provision the environment (see the class docstring).

        :param vpncloud_version: release installed from GitHub when no ``vpncloud_file`` is given
        :param region: AWS region
        :param node_count: number of instances
        :param instance_type: EC2 instance type
        :param vpncloud_file: local vpncloud binary uploaded instead of the release RPM
        :param use_spot: launch one-time spot instances instead of on-demand ones
        :param max_price: maximum spot price
        :param ami: image id, or ``(owner, name_pattern)`` resolved via :func:`find_ami`
        :param username: SSH login user
        :param subnet: existing subnet id, or CREATE to build a new VPC and subnet
        :param keyname: existing key pair name, or CREATE to create a key pair
        :param privatekey: RSA private key for an existing ``keyname`` (PEM text
            or a path to a key file); replaced by the generated key when a key
            pair is created
        :param tag: ``Name`` tag and name prefix for created resources
        :param cluster_nodes: put the instances into a "cluster" placement group
        :raises ValueError: if ``ami`` is a tuple and no matching image exists
        """
        self.region = region
        self.node_count = node_count
        self.instance_type = instance_type
        self.use_spot = use_spot
        # The spot API takes the price as a string.
        self.max_price = str(max_price)
        if isinstance(ami, tuple):
            owner, name = ami
            self.ami = find_ami(region, owner, name)
            # Explicit check instead of assert so it also holds under ``python -O``.
            if not self.ami:
                raise ValueError("No AMI found for owner {!r} and name {!r}".format(owner, name))
        else:
            self.ami = ami
        self.username = username
        self.vpncloud_version = vpncloud_version
        self.vpncloud_file = vpncloud_file
        self.cluster_nodes = cluster_nodes
        self.resources = []
        self.instances = []
        self.connections = []
        self.nodes = []
        self.subnet = subnet
        self.tag = tag
        self.keyname = keyname
        self.privatekey = privatekey
        self.rsa_key = None
        try:
            eprint("Setting up resources...")
            self.setup()
            self.wait_until_ready()
            self.nodes = [Node(instance, connection) for instance, connection in zip(self.instances, self.connections)]
            eprint("Setup done")
            atexit.register(self.terminate)
            eprint()
        except BaseException:
            # Also on KeyboardInterrupt: never leave paid resources behind.
            eprint("Error, shutting down")
            self.terminate()
            raise

    def track_resource(self, res):
        """Record *res* for cleanup in terminate(), log it and tag it ``Name=self.tag``.

        Resources that expose a ``name`` attribute (e.g. key pairs, placement
        groups) or cannot be tagged are not tagged.
        """
        self.resources.append(res)
        eprint("\t{} {}".format(res.__class__.__name__, res.id if hasattr(res, "id") else ""))
        if hasattr(res, "create_tags") and not hasattr(res, "name"):
            res.create_tags(Tags=[{"Key": "Name", "Value": self.tag}])

    def setup_vpc(self):
        """Create a VPC (172.16.0.0/16) with internet gateway, default route and
        a 172.16.1.0/24 subnet, and store the new subnet id in ``self.subnet``."""
        ec2 = boto3.resource('ec2', region_name=self.region)
        ec2client = boto3.client('ec2', region_name=self.region)

        vpc = ec2.create_vpc(CidrBlock='172.16.0.0/16')
        self.track_resource(vpc)
        vpc.wait_until_available()
        ec2client.modify_vpc_attribute(VpcId=vpc.id, EnableDnsSupport={'Value': True})
        ec2client.modify_vpc_attribute(VpcId=vpc.id, EnableDnsHostnames={'Value': True})

        igw = ec2.create_internet_gateway()
        self.track_resource(igw)
        igw.attach_to_vpc(VpcId=vpc.id)

        rtb = vpc.create_route_table()
        self.track_resource(rtb)
        rtb.create_route(DestinationCidrBlock='0.0.0.0/0', GatewayId=igw.id)

        subnet = ec2.create_subnet(CidrBlock='172.16.1.0/24', VpcId=vpc.id)
        self.track_resource(subnet)
        rtb.associate_with_subnet(SubnetId=subnet.id)

        self.subnet = subnet.id

    def setup(self):
        """Prepare networking, security group, key pair and placement, then launch the instances."""
        ec2 = boto3.resource('ec2', region_name=self.region)

        if self.subnet == CREATE:
            self.setup_vpc()
        else:
            eprint("\tUsing subnet {}".format(self.subnet))

        sg = self._create_security_group(ec2, ec2.Subnet(self.subnet))
        self._setup_key_pair(ec2)
        placement = self._setup_placement(ec2)
        userdata = self._build_userdata()

        if self.use_spot:
            self._launch_spot_instances(ec2, sg, placement, userdata)
        else:
            self._launch_on_demand_instances(ec2, sg, placement, userdata)

    def _create_security_group(self, ec2, subnet):
        """Create the security group: SSH from anywhere, all traffic from within *subnet*."""
        sg = ec2.create_security_group(GroupName='SSH-ONLY', Description='only allow SSH traffic', VpcId=subnet.vpc_id)
        self.track_resource(sg)
        # Use the subnet's real CIDR so a caller-supplied subnet also allows
        # internal traffic (a created subnet is 172.16.1.0/24).
        internal = subnet.cidr_block
        sg.authorize_ingress(CidrIp='0.0.0.0/0', IpProtocol='tcp', FromPort=22, ToPort=22)
        sg.authorize_ingress(CidrIp=internal, IpProtocol='icmp', FromPort=-1, ToPort=-1)
        sg.authorize_ingress(CidrIp=internal, IpProtocol='tcp', FromPort=0, ToPort=65535)
        sg.authorize_ingress(CidrIp=internal, IpProtocol='udp', FromPort=0, ToPort=65535)
        return sg

    def _setup_key_pair(self, ec2):
        """Create a key pair if requested, and load the RSA private key used for SSH."""
        if self.keyname == CREATE:
            key_pair = ec2.create_key_pair(KeyName="{}-keypair".format(self.tag))
            self.track_resource(key_pair)
            self.keyname = key_pair.name
            self.privatekey = key_pair.key_material
            self.rsa_key = paramiko.RSAKey.from_private_key(io.StringIO(self.privatekey))
        elif self.privatekey not in (None, CREATE):
            # Existing key pair: accept either a key file path or the PEM text.
            if os.path.isfile(self.privatekey):
                self.rsa_key = paramiko.RSAKey.from_private_key_file(self.privatekey)
            else:
                self.rsa_key = paramiko.RSAKey.from_private_key(io.StringIO(self.privatekey))
        # Otherwise rsa_key stays None and paramiko uses its own key lookup.

    def _setup_placement(self, ec2):
        """Return the Placement spec, creating a cluster placement group if requested."""
        if not self.cluster_nodes:
            return {}
        placement_group = ec2.create_placement_group(GroupName="{}-placement".format(self.tag), Strategy="cluster")
        self.track_resource(placement_group)
        return {'GroupName': placement_group.name}

    def _build_userdata(self):
        """Return the cloud-config user data: install iperf3 and socat and, unless
        a local binary will be uploaded, the vpncloud release RPM."""
        userdata = """#cloud-config
packages:
  - iperf3
  - socat
"""
        if not self.vpncloud_file:
            userdata += """
runcmd:
  - wget https://github.com/dswd/vpncloud/releases/download/v{version}/vpncloud_{version}.x86_64.rpm -O /tmp/vpncloud.rpm
  - yum install -y /tmp/vpncloud.rpm
""".format(version=self.vpncloud_version)
        return userdata

    def _network_interface(self, sg):
        """Return the interface spec shared by both launch paths: public IP, in self.subnet, guarded by *sg*."""
        return {
            'SubnetId': self.subnet,
            'DeviceIndex': 0,
            'AssociatePublicIpAddress': True,
            'Groups': [sg.group_id],
        }

    def _launch_spot_instances(self, ec2, sg, placement, userdata):
        """Request one-time spot instances and poll until every request has an instance.

        :raises Exception: "Waited too long" after MAX_WAIT polling rounds
        """
        ec2client = boto3.client('ec2', region_name=self.region)
        response = ec2client.request_spot_instances(
            SpotPrice=self.max_price,
            Type="one-time",
            InstanceCount=self.node_count,
            LaunchSpecification={
                "ImageId": self.ami,
                "InstanceType": self.instance_type,
                "KeyName": self.keyname,
                # Base64-encoded here, as the spot LaunchSpecification expects.
                "UserData": base64.b64encode(userdata.encode("ascii")).decode('ascii'),
                "BlockDeviceMappings": [
                    {
                        "DeviceName": "/dev/xvda",
                        "Ebs": {
                            "DeleteOnTermination": True,
                            "VolumeType": "gp2",
                        },
                    }
                ],
                "NetworkInterfaces": [self._network_interface(sg)],
                "Placement": placement,
            },
        )
        requests = []
        for req in response['SpotInstanceRequests']:
            request = SpotInstanceRequest(req['SpotInstanceRequestId'])
            self.track_resource(request)
            requests.append(request)
        eprint("Waiting for spot instance requests")
        self.instances = [None] * len(requests)
        waited = 0
        while not all(self.instances):
            if waited >= MAX_WAIT:
                raise Exception("Waited too long")
            time.sleep(1.0)
            waited += 1
            for i, request in enumerate(requests):
                if self.instances[i] is not None:
                    # Already fulfilled: do not re-create, re-tag or re-track it.
                    continue
                data = ec2client.describe_spot_instance_requests(SpotInstanceRequestIds=[request.id])['SpotInstanceRequests'][0]
                if 'InstanceId' in data:
                    self.instances[i] = ec2.Instance(data['InstanceId'])
                    self.track_resource(self.instances[i])

    def _launch_on_demand_instances(self, ec2, sg, placement, userdata):
        """Launch exactly node_count on-demand instances and track them."""
        self.instances = ec2.create_instances(
            ImageId=self.ami,
            InstanceType=self.instance_type,
            MaxCount=self.node_count,
            MinCount=self.node_count,
            NetworkInterfaces=[self._network_interface(sg)],
            Placement=placement,
            # Passed unencoded, unlike the spot LaunchSpecification (presumably
            # the SDK encodes it for RunInstances).
            UserData=userdata,
            KeyName=self.keyname,
        )
        for instance in self.instances:
            self.track_resource(instance)

    def wait_until_ready(self):
        """Block until all instances are running, accept SSH and have finished cloud-init.

        Polls about once per second. The SSH and cloud-init phases share one
        budget of MAX_WAIT rounds. If ``self.vpncloud_file`` is set, the binary
        is then uploaded to ``/usr/bin/vpncloud`` on every instance.

        :raises Exception: "Waited too long" when the budget is exhausted
        """
        waited = 0
        eprint("Waiting for instances to start...")
        for instance in self.instances:
            instance.wait_until_running()
            # Refresh attributes such as public_dns_name, which _connect needs.
            instance.reload()
        eprint("Waiting for SSH to be ready...")
        self.connections = [None] * len(self.instances)
        while waited < MAX_WAIT:
            for i, instance in enumerate(self.instances):
                if self.connections[i]:
                    continue
                try:
                    self.connections[i] = self._connect(instance)
                except Exception:
                    # Not reachable yet; retry next round (Ctrl-C still propagates).
                    pass
            if all(self.connections):
                break
            time.sleep(1.0)
            waited += 1
        eprint("Waiting for instances to finish setup...")
        ready = [False] * len(self.connections)
        while waited < MAX_WAIT:
            for i, con in enumerate(self.connections):
                if ready[i]:
                    continue
                try:
                    # cloud-init's boot-finished marker: user data has been processed.
                    run_cmd(con, 'test -f /var/lib/cloud/instance/boot-finished')
                    ready[i] = True
                except Exception:
                    pass
            if all(ready):
                break
            time.sleep(1.0)
            waited += 1
        if waited >= MAX_WAIT:
            raise Exception("Waited too long")
        if self.vpncloud_file:
            eprint("Uploading vpncloud binary")
            for con in self.connections:
                upload(con, self.vpncloud_file, 'vpncloud')
                run_cmd(con, 'chmod +x vpncloud')
                run_cmd(con, 'sudo mv vpncloud /usr/bin/vpncloud')

    def terminate(self):
        """Close SSH connections, terminate instances and delete every tracked resource.

        Does nothing when no resources are tracked.

        Spot requests are cancelled first, so a still-open request cannot
        launch a new instance during teardown. Instances such requests launched
        without setup seeing them are terminated as well.
        """
        if not self.resources:
            return
        eprint("Closing connections...")
        for con in self.connections:
            if con:
                con.close()
        self.connections = []
        ec2client = boto3.client('ec2', region_name=self.region)
        stray = self._cancel_spot_requests(ec2client)
        eprint("Terminating instances...")
        # Unfulfilled spot requests leave None placeholders in self.instances.
        instances = [instance for instance in self.instances if instance is not None] + stray
        for instance in instances:
            instance.terminate()
        for instance in instances:
            eprint("\t{}".format(instance.id))
            instance.wait_until_terminated()
        self.instances = []
        eprint("Deleting resources...")
        for res in reversed(self.resources):
            eprint("\t{} {}".format(res.__class__.__name__, res.id if hasattr(res, "id") else ""))
            if isinstance(res, SpotInstanceRequest):
                continue  # already cancelled above
            if hasattr(res, "attachments"):
                # Internet gateways must be detached before they can be deleted.
                for a in res.attachments:
                    res.detach_from_vpc(VpcId=a['VpcId'])
            if hasattr(res, "delete"):
                res.delete()
        self.resources = []

    def _cancel_spot_requests(self, ec2client):
        """Cancel all tracked spot requests; return instances they launched that are not in self.instances."""
        request_ids = [res.id for res in self.resources if isinstance(res, SpotInstanceRequest)]
        if not request_ids:
            return []
        eprint("Cancelling spot instance requests...")
        ec2client.cancel_spot_instance_requests(SpotInstanceRequestIds=request_ids)
        known = {instance.id for instance in self.instances if instance is not None}
        ec2 = boto3.resource('ec2', region_name=self.region)
        described = ec2client.describe_spot_instance_requests(SpotInstanceRequestIds=request_ids)
        return [
            ec2.Instance(req['InstanceId'])
            for req in described['SpotInstanceRequests']
            if 'InstanceId' in req and req['InstanceId'] not in known
        ]

    def _connect(self, instance):
        """Open an SSH connection to *instance*'s public DNS name (1 s timeouts).

        Unknown host keys are accepted automatically (AutoAddPolicy), which is
        tolerable only for these short-lived test machines.

        :raises: whatever ``SSHClient.connect`` raises; the client is closed first
        """
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(hostname=instance.public_dns_name, username=self.username, pkey=self.rsa_key, timeout=1.0, banner_timeout=1.0)
        except Exception:
            # Release the transport/socket of the failed attempt; callers retry.
            client.close()
            raise
        return client
|