# vpncloud/contrib/aws/common.py
#
# 385 lines
# 14 KiB
# Python

import boto3
import atexit
import paramiko
import io
import time
import threading
import re
import json
import base64
import sys
import os
from datetime import date
# Limit that the polling loops in this file compare their ``waited`` counter
# against before giving up with "Waited too long".
MAX_WAIT = 300
# Sentinel used as the default for the ``subnet``, ``keyname`` and
# ``privatekey`` arguments of EC2Environment; a subnet or key name equal to
# it makes the environment create that resource itself.
CREATE = "***CREATE***"
def eprint(*args, **kwargs):
    """Print to standard error.

    Positional and keyword arguments are forwarded unchanged to ``print``;
    only the target stream is fixed to ``sys.stderr``.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
def run_cmd(connection, cmd):
    """Execute ``cmd`` through ``connection`` and return its decoded output.

    ``connection`` must offer ``exec_command(cmd)`` returning a
    ``(stdin, stdout, stderr)`` triple, as the SSH clients created elsewhere
    in this file do.

    Returns:
        A ``(stdout, stderr)`` tuple of UTF-8 decoded strings.

    Raises:
        Exception: with args ``("Command failed", exit_code, stdout, stderr)``
            when the command exits with a non-zero status.
    """
    _, out_stream, err_stream = connection.exec_command(cmd)
    # Both streams are drained before the exit status is requested.
    out_text = out_stream.read().decode('utf-8')
    err_text = err_stream.read().decode('utf-8')
    status = out_stream.channel.recv_exit_status()
    if not status:
        return out_text, err_text
    raise Exception("Command failed", status, out_text, err_text)
def upload(connection, local, remote):
    """Copy the local file ``local`` to the path ``remote`` via SFTP.

    ``connection`` must offer ``open_sftp()``; the returned client is used
    for a single ``put`` and is always closed afterwards, including when the
    transfer raises, so a failed upload does not leak the SFTP session.

    Raises:
        Whatever ``open_sftp`` or ``put`` raises; errors are not swallowed.
    """
    sftp = connection.open_sftp()
    try:
        sftp.put(local, remote)
    finally:
        sftp.close()
class SpotInstanceRequest:
    """Minimal handle for a spot instance request, identified by its id.

    It carries nothing but the ``id`` attribute, which is also what its
    string form shows.
    """

    def __init__(self, id):
        # Kept under the attribute name ``id`` because other code in this
        # file reads ``res.id`` from the objects it tracks.
        self.id = id

    def __str__(self):
        request_id = self.id
        return str(request_id)
class Node:
    """An instance together with an open connection used to run commands on it.

    Every helper below runs a shell command on the node through
    ``self.run_cmd`` and propagates whatever that call raises.
    """

    def __init__(self, instance, connection):
        """Store the instance, its connection and its two IP addresses.

        ``instance`` must expose ``private_ip_address`` and
        ``public_ip_address``; both are copied at construction time.
        """
        self.instance = instance
        self.connection = connection
        self.private_ip = instance.private_ip_address
        self.public_ip = instance.public_ip_address

    def run_cmd(self, cmd):
        """Run ``cmd`` on this node via the module-level ``run_cmd``."""
        return run_cmd(self.connection, cmd)

    def start_vpncloud(self, ip=None, crypto=None, password="test", device_type="tun", listen="3210", mode="normal", peers=(), claims=()):
        """Start ``vpncloud`` as a daemon on the node (through ``sudo``).

        Args:
            ip: value for ``--ip``; omitted when falsy.
            crypto: value for ``--algo``; omitted when falsy.
            password: value for ``--password`` (wrapped in single quotes).
            device_type: value for ``-t``.
            listen: value for ``-l``.
            mode: value for ``-m``.
            peers: iterable of values, each passed with its own ``-c``.
            claims: iterable of values, each passed with its own ``--claim``.
        """
        # Immutable empty defaults avoid sharing one list object across calls.
        args = [
            "--daemon",
            "--no-port-forwarding",
            "-t {}".format(device_type),
            "-m {}".format(mode),
            "-l {}".format(listen),
            "--password '{}'".format(password),
        ]
        if ip:
            args.append("--ip {}".format(ip))
        if crypto:
            args.append("--algo {}".format(crypto))
        args.extend("-c {}".format(peer) for peer in peers)
        args.extend("--claim {}".format(claim) for claim in claims)
        self.run_cmd("sudo vpncloud {}".format(" ".join(args)))

    def stop_vpncloud(self, wait=True):
        """Kill all ``vpncloud`` processes; sleep three seconds if ``wait``."""
        self.run_cmd("sudo killall vpncloud")
        if wait:
            time.sleep(3.0)

    def ping(self, dst, size=100, count=10, interval=0.001):
        """Ping ``dst`` from this node and parse the summary output.

        Args:
            dst: destination passed to ``ping``.
            size: value for ``-s``.
            count: value for ``-c``.
            interval: value for ``-i``.

        Returns:
            A dict with ``rtt_min``, ``rtt_max``, ``rtt_avg`` and
            ``pkt_loss`` as floats, taken verbatim from the numbers that
            ``ping`` printed.

        Raises:
            Exception: with args ``("Failed to parse ping output", output)``
                when the expected summary lines are not found.
        """
        (out, _) = self.run_cmd('sudo ping {dst} -c {count} -i {interval} -s {size} -U -q'.format(dst=dst, size=size, count=count, interval=interval))
        # First three of the four slash-separated numbers: min/avg/max.
        rtt = re.search(r'(\d*\.\d*)/(\d*\.\d*)/(\d*\.\d*)/(\d*\.\d*)', out)
        # The loss percentage may be fractional (e.g. "33.3333%"); the whole
        # number must be captured, not just the digits right before the "%".
        loss = re.search(r'(\d+(?:\.\d+)?)% packet loss', out)
        if rtt is None or loss is None:
            raise Exception("Failed to parse ping output", out)
        return {
            "rtt_min": float(rtt.group(1)),
            "rtt_max": float(rtt.group(3)),
            "rtt_avg": float(rtt.group(2)),
            "pkt_loss": float(loss.group(1)),
        }

    def start_iperf_server(self):
        """Start ``iperf3`` in server/daemon mode, then pause briefly."""
        self.run_cmd('iperf3 -s -D')
        time.sleep(0.1)

    def stop_iperf_server(self):
        """Kill all ``iperf3`` processes on the node."""
        self.run_cmd('killall iperf3')

    def run_iperf(self, dst, duration):
        """Run an ``iperf3`` client against ``dst`` and summarise the result.

        Args:
            dst: value for ``-c``.
            duration: value for ``-t``.

        Returns:
            A dict with ``throughput`` (receiver ``bits_per_second`` of the
            first stream), ``cpu_sender`` (``host_total``) and
            ``cpu_receiver`` (``remote_total``) read from the JSON report.
        """
        (out, _) = self.run_cmd('iperf3 -c {dst} -t {duration} --json'.format(dst=dst, duration=duration))
        end = json.loads(out)['end']
        cpu = end['cpu_utilization_percent']
        return {
            "throughput": end['streams'][0]['receiver']['bits_per_second'],
            "cpu_sender": cpu['host_total'],
            "cpu_receiver": cpu['remote_total'],
        }
def find_ami(region, owner, name_pattern, arch='x86_64'):
    """Look up an image id by owner, name pattern and architecture.

    Args:
        region: region name used to create the EC2 client.
        owner: single value passed as ``Owners`` to ``describe_images``.
        name_pattern: value for the ``name`` filter.
        arch: value for the ``architecture`` filter.

    Returns:
        The ``ImageId`` of the matching image with the greatest
        ``CreationDate`` value, or ``None`` when no image matches.
    """
    ec2client = boto3.client('ec2', region_name=region)
    response = ec2client.describe_images(Owners=[owner], Filters=[
        {'Name': 'name', 'Values': [name_pattern]},
        # Filter on the requested architecture rather than a fixed one.
        {'Name': 'architecture', 'Values': [arch]},
    ])
    newest = max(response['Images'], key=lambda image: image['CreationDate'], default=None)
    if newest is None:
        return None
    return newest['ImageId']
class EC2Environment:
    """A temporary group of EC2 instances reachable over SSH.

    Constructing the object creates the needed resources, launches the
    instances, waits until they accept SSH and have finished booting, and
    wraps each one in a ``Node`` (available as ``self.nodes``). Every created
    resource is recorded so that ``terminate`` can remove it again;
    ``terminate`` is registered with ``atexit`` after a successful setup and
    is called immediately when setup fails.
    """

    def __init__(self, vpncloud_version, region, node_count, instance_type, vpncloud_file=None, use_spot=True, max_price=0.1, ami=('amazon', 'al2023-ami-*-hvm-*'), username="ec2-user", subnet=CREATE, keyname=CREATE, privatekey=CREATE, tag="vpncloud", cluster_nodes=False):
        """Create the environment and block until all nodes are ready.

        Args:
            vpncloud_version: version inserted into the package download URL
                (only used when ``vpncloud_file`` is not given).
            region: region name for all boto3 clients/resources.
            node_count: number of instances to launch.
            instance_type: instance type to launch.
            vpncloud_file: optional local binary uploaded to every node
                instead of installing the released package.
            use_spot: request spot instances instead of regular ones.
            max_price: spot price limit (converted to a string).
            ami: either an image id, or an ``(owner, name_pattern)`` tuple
                resolved through ``find_ami``.
            username: SSH user name.
            subnet: subnet id, or ``CREATE`` to create a VPC and subnet.
            keyname: key pair name, or ``CREATE`` to create a key pair.
            privatekey: private key text matching ``keyname``; replaced by
                the generated key material when the key pair is created.
            tag: value of the ``Name`` tag and prefix for resource names.
            cluster_nodes: place the instances in a "cluster" placement
                group.

        Raises:
            Exception: when no image matches the ``ami`` tuple, or when any
                setup step fails (after cleaning up what was created).
        """
        self.region = region
        self.node_count = node_count
        self.instance_type = instance_type
        self.use_spot = use_spot
        self.max_price = str(max_price)
        if isinstance(ami, tuple):
            owner, name = ami
            self.ami = find_ami(region, owner, name)
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not self.ami:
                raise Exception("No AMI found", owner, name)
        else:
            self.ami = ami
        self.username = username
        self.vpncloud_version = vpncloud_version
        self.vpncloud_file = vpncloud_file
        self.cluster_nodes = cluster_nodes
        self.resources = []
        self.instances = []
        self.connections = []
        self.nodes = []
        self.subnet = subnet
        self.tag = tag
        self.keyname = keyname
        self.privatekey = privatekey
        self.rsa_key = None
        try:
            eprint("Setting up resources...")
            self.setup()
            self.wait_until_ready()
            for i in range(self.node_count):
                self.nodes.append(Node(self.instances[i], self.connections[i]))
            eprint("Setup done")
            atexit.register(self.terminate)
            eprint()
        except BaseException:
            # Also covers KeyboardInterrupt: resources that already exist
            # must be removed before the error is passed on.
            eprint("Error, shutting down")
            self.terminate()
            raise

    def track_resource(self, res):
        """Record ``res`` for later cleanup and tag it when possible.

        Resources that offer ``create_tags`` and have no ``name`` attribute
        receive a ``Name`` tag set to ``self.tag``.
        """
        self.resources.append(res)
        eprint("\t{} {}".format(res.__class__.__name__, res.id if hasattr(res, "id") else ""))
        if hasattr(res, "create_tags") and not hasattr(res, "name"):
            res.create_tags(Tags=[{"Key": "Name", "Value": self.tag}])

    def setup_vpc(self):
        """Create a VPC with internet gateway, route table and one subnet.

        The id of the new subnet is stored in ``self.subnet``; every created
        resource is tracked for cleanup.
        """
        ec2 = boto3.resource('ec2', region_name=self.region)
        ec2client = boto3.client('ec2', region_name=self.region)
        vpc = ec2.create_vpc(CidrBlock='172.16.0.0/16')
        self.track_resource(vpc)
        vpc.wait_until_available()
        ec2client.modify_vpc_attribute(VpcId=vpc.id, EnableDnsSupport={'Value': True})
        ec2client.modify_vpc_attribute(VpcId=vpc.id, EnableDnsHostnames={'Value': True})
        igw = ec2.create_internet_gateway()
        self.track_resource(igw)
        igw.attach_to_vpc(VpcId=vpc.id)
        rtb = vpc.create_route_table()
        self.track_resource(rtb)
        rtb.create_route(DestinationCidrBlock='0.0.0.0/0', GatewayId=igw.id)
        subnet = ec2.create_subnet(CidrBlock='172.16.1.0/24', VpcId=vpc.id)
        self.track_resource(subnet)
        rtb.associate_with_subnet(SubnetId=subnet.id)
        self.subnet = subnet.id

    def setup(self):
        """Create network, security group, key and instances.

        On return ``self.instances`` holds the launched instances and
        ``self.rsa_key`` the key used for SSH.

        Raises:
            Exception: "Waited too long" when spot requests are not
                fulfilled in time.
        """
        ec2 = boto3.resource('ec2', region_name=self.region)
        ec2client = boto3.client('ec2', region_name=self.region)

        if self.subnet == CREATE:
            self.setup_vpc()
        else:
            eprint("\tUsing subnet {}".format(self.subnet))
        subnet = ec2.Subnet(self.subnet)
        vpc = subnet.vpc
        # Node-to-node traffic is allowed from the subnet's own address
        # range, which also works for a caller-supplied subnet.
        cluster_cidr = subnet.cidr_block

        sg = ec2.create_security_group(GroupName='SSH-ONLY', Description='only allow SSH traffic', VpcId=vpc.id)
        self.track_resource(sg)
        sg.authorize_ingress(CidrIp='0.0.0.0/0', IpProtocol='tcp', FromPort=22, ToPort=22)
        sg.authorize_ingress(CidrIp=cluster_cidr, IpProtocol='icmp', FromPort=-1, ToPort=-1)
        sg.authorize_ingress(CidrIp=cluster_cidr, IpProtocol='tcp', FromPort=0, ToPort=65535)
        sg.authorize_ingress(CidrIp=cluster_cidr, IpProtocol='udp', FromPort=0, ToPort=65535)

        if self.keyname == CREATE:
            key_pair = ec2.create_key_pair(KeyName="{}-keypair".format(self.tag))
            self.track_resource(key_pair)
            self.keyname = key_pair.name
            self.privatekey = key_pair.key_material
        self.rsa_key = paramiko.RSAKey.from_private_key(io.StringIO(self.privatekey))

        placement = {}
        if self.cluster_nodes:
            placement_group = ec2.create_placement_group(GroupName="{}-placement".format(self.tag), Strategy="cluster")
            self.track_resource(placement_group)
            placement = {'GroupName': placement_group.name}

        userdata = (
            "#cloud-config\n"
            "packages:\n"
            "- iperf3\n"
            "- socat\n"
        )
        if not self.vpncloud_file:
            userdata += (
                "\n"
                "runcmd:\n"
                "- wget https://github.com/dswd/vpncloud/releases/download/v{version}/vpncloud_{version}.x86_64.rpm -O /tmp/vpncloud.rpm\n"
                "- yum install -y /tmp/vpncloud.rpm\n"
            ).format(version=self.vpncloud_version)

        if self.use_spot:
            self._launch_spot_instances(ec2, ec2client, sg, placement, userdata)
        else:
            self._launch_on_demand_instances(ec2, sg, placement, userdata)

    def _network_interfaces(self, sg):
        """Return the single-interface specification used by both launch paths."""
        return [
            {
                'SubnetId': self.subnet,
                'DeviceIndex': 0,
                'AssociatePublicIpAddress': True,
                'Groups': [sg.group_id]
            }
        ]

    def _launch_spot_instances(self, ec2, ec2client, sg, placement, userdata):
        """Request spot instances and poll until each request has an instance."""
        response = ec2client.request_spot_instances(
            SpotPrice=self.max_price,
            Type="one-time",
            InstanceCount=self.node_count,
            LaunchSpecification={
                "ImageId": self.ami,
                "InstanceType": self.instance_type,
                # self.keyname is valid for both a created and a supplied key.
                "KeyName": self.keyname,
                "UserData": base64.b64encode(userdata.encode("ascii")).decode('ascii'),
                "BlockDeviceMappings": [
                    {
                        "DeviceName": "/dev/xvda",
                        "Ebs": {
                            "DeleteOnTermination": True,
                            "VolumeType": "gp2",
                        }
                    }
                ],
                "NetworkInterfaces": self._network_interfaces(sg),
                "Placement": placement
            }
        )
        requests = []
        for req in response['SpotInstanceRequests']:
            request = SpotInstanceRequest(req['SpotInstanceRequestId'])
            self.track_resource(request)
            requests.append(request)
        eprint("Waiting for spot instance requests")
        self.instances = [None] * len(requests)
        # Bounded polling: one pass per second, at most MAX_WAIT passes.
        for _ in range(MAX_WAIT):
            time.sleep(1.0)
            for i, req in enumerate(requests):
                if self.instances[i] is not None:
                    # Already fulfilled; do not query or track it again.
                    continue
                response = ec2client.describe_spot_instance_requests(SpotInstanceRequestIds=[req.id])
                data = response['SpotInstanceRequests'][0]
                if 'InstanceId' in data:
                    self.instances[i] = ec2.Instance(data['InstanceId'])
                    self.track_resource(self.instances[i])
            if all(instance is not None for instance in self.instances):
                break
        else:
            raise Exception("Waited too long")

    def _launch_on_demand_instances(self, ec2, sg, placement, userdata):
        """Launch regular instances and track each of them for cleanup."""
        self.instances = ec2.create_instances(
            ImageId=self.ami,
            InstanceType=self.instance_type,
            MaxCount=self.node_count,
            MinCount=self.node_count,
            NetworkInterfaces=self._network_interfaces(sg),
            Placement=placement,
            UserData=userdata,
            # self.keyname is valid for both a created and a supplied key.
            KeyName=self.keyname
        )
        for instance in self.instances:
            self.track_resource(instance)

    def wait_until_ready(self):
        """Wait until every instance runs, accepts SSH and has finished boot.

        Fills ``self.connections`` with one SSH client per instance and, when
        ``self.vpncloud_file`` is set, uploads it to ``/usr/bin/vpncloud`` on
        every node.

        Raises:
            Exception: "Waited too long" when the shared retry budget of
                ``MAX_WAIT`` rounds is used up.
        """
        waited = 0
        eprint("Waiting for instances to start...")
        for instance in self.instances:
            instance.wait_until_running()
            instance.reload()
        eprint("Waiting for SSH to be ready...")
        self.connections = [None] * len(self.instances)
        while waited < MAX_WAIT:
            for i, instance in enumerate(self.instances):
                if self.connections[i]:
                    continue
                try:
                    self.connections[i] = self._connect(instance)
                except Exception:
                    # Not reachable yet; retried in the next round.
                    # KeyboardInterrupt is deliberately not swallowed.
                    pass
            if all(con is not None for con in self.connections):
                break
            time.sleep(1.0)
            waited += 1
        eprint("Waiting for instances to finish setup...")
        ready = [False] * len(self.connections)
        # The counter continues from the SSH phase: both phases share one budget.
        while waited < MAX_WAIT:
            for i, con in enumerate(self.connections):
                if ready[i]:
                    continue
                try:
                    run_cmd(con, 'test -f /var/lib/cloud/instance/boot-finished')
                    ready[i] = True
                except Exception:
                    # Marker file missing (or command failed); retry later.
                    pass
            if all(ready):
                break
            time.sleep(1.0)
            waited += 1
        if waited >= MAX_WAIT:
            raise Exception("Waited too long")
        if self.vpncloud_file:
            eprint("Uploading vpncloud binary")
            for con in self.connections:
                upload(con, self.vpncloud_file, 'vpncloud')
                run_cmd(con, 'chmod +x vpncloud')
                run_cmd(con, 'sudo mv vpncloud /usr/bin/vpncloud')

    def terminate(self):
        """Close connections, terminate instances and delete tracked resources.

        Does nothing when no resources are tracked, so calling it again
        after a completed cleanup is harmless.
        """
        if not self.resources:
            return
        eprint("Closing connections...")
        for con in self.connections:
            if con:
                con.close()
        self.connections = []
        eprint("Terminating instances...")
        # Unfulfilled spot requests leave None placeholders; skip them so
        # cleanup is not aborted half-way.
        instances = [instance for instance in self.instances if instance is not None]
        for instance in instances:
            instance.terminate()
        for instance in instances:
            eprint("\t{}".format(instance.id))
            instance.wait_until_terminated()
        self.instances = []
        eprint("Deleting resources...")
        ec2client = boto3.client('ec2', region_name=self.region)
        # Reverse creation order, so dependants go before what they depend on.
        for res in reversed(self.resources):
            eprint("\t{} {}".format(res.__class__.__name__, res.id if hasattr(res, "id") else ""))
            if isinstance(res, SpotInstanceRequest):
                ec2client.cancel_spot_instance_requests(SpotInstanceRequestIds=[res.id])
            if hasattr(res, "attachments"):
                for a in res.attachments:
                    res.detach_from_vpc(VpcId=a['VpcId'])
            if hasattr(res, "delete"):
                res.delete()
        self.resources = []

    def _connect(self, instance):
        """Open an SSH connection to ``instance`` and return the client.

        Unknown host keys are accepted automatically. The client is closed
        again when connecting fails, so retries do not accumulate
        half-open clients.
        """
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(hostname=instance.public_dns_name, username=self.username, pkey=self.rsa_key, timeout=1.0, banner_timeout=1.0)
        except BaseException:
            client.close()
            raise
        return client