libvirt integration
This commit is contained in:
parent
42c71a201d
commit
1ee0d2f740
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
.idea
|
||||
__pycache__
|
||||
|
80
cloud_config.py
Normal file
80
cloud_config.py
Normal file
@ -0,0 +1,80 @@
|
||||
import base64
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional, Iterable, Tuple
|
||||
|
||||
|
||||
def gen_cloud_config(name: str,
|
||||
cloud_init_path: Path,
|
||||
wireguard_config: str,
|
||||
ssh_authorized_keys: Optional[Iterable[str]] = None
|
||||
) -> Tuple[Path, Path]:
|
||||
if ssh_authorized_keys is not None:
|
||||
ssh_keys = 'ssh_authorized_keys\n' + \
|
||||
'\n'.join((f'- {key}' for key in ssh_authorized_keys))
|
||||
else:
|
||||
ssh_keys = ''
|
||||
|
||||
user_data = f"""
|
||||
#cloud-config
|
||||
chpasswd: {{ expire: False }}
|
||||
password: root
|
||||
{ssh_keys}
|
||||
ssh_pwauth: False
|
||||
timezone: Europe/Berlin
|
||||
users:
|
||||
- default
|
||||
|
||||
package_update: true
|
||||
packages:
|
||||
- wireguard
|
||||
|
||||
write_files:
|
||||
- encoding: b64
|
||||
content: {base64.b64encode(wireguard_config.encode())}
|
||||
owner: root:root
|
||||
path: /etc/wireguard/wg0.conf
|
||||
permissions: 0600
|
||||
"""
|
||||
|
||||
meta_data = f"""
|
||||
local-hostname: {name}
|
||||
"""
|
||||
|
||||
user_data_path = cloud_init_path / 'user_data'
|
||||
meta_data_path = cloud_init_path / 'meta_data'
|
||||
|
||||
with open(user_data_path, 'w+') as handle:
|
||||
handle.write(user_data)
|
||||
with open(meta_data_path, 'w+') as handle:
|
||||
handle.write(meta_data)
|
||||
|
||||
return user_data_path, meta_data_path
|
||||
|
||||
|
||||
def _gen_cloudinit_iso_image(path: Path, user_data_path: Path, meta_data_path: Path, volume_label: str = 'cloud_init'):
|
||||
genisoimage_executable = shutil.which('genisoimage')
|
||||
if genisoimage_executable is None:
|
||||
raise FileNotFoundError('could not locate genisoimage executable!')
|
||||
|
||||
command = (
|
||||
genisoimage_executable,
|
||||
'-output', str(path),
|
||||
'-V', volume_label,
|
||||
'-r',
|
||||
'-J',
|
||||
user_data_path,
|
||||
meta_data_path
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True
|
||||
)
|
||||
except subprocess.CalledProcessError as ex:
|
||||
print(f'genisoimage error, stderr was: {ex.stderr}, stdout was: {ex.stdout}')
|
||||
raise ex
|
144
main.py
144
main.py
@ -1,108 +1,34 @@
|
||||
import base64
|
||||
import itertools
|
||||
import os
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Network, IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Dict, Tuple, List, Optional
|
||||
from time import sleep
|
||||
from typing import List, Optional
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
import libvirt
|
||||
from libvirt import virDomain, virStoragePool
|
||||
|
||||
from virt import spawn_vm
|
||||
from wireguard_mesh import create_mesh
|
||||
|
||||
|
||||
def _spawn_vm():
|
||||
pass
|
||||
def cleanup(vms: List[virDomain], storage_pool: Optional[virStoragePool] = None):
|
||||
for vm in vms:
|
||||
try:
|
||||
vm.destroy()
|
||||
except libvirt.libvirtError as ex:
|
||||
print(f'error shutting down {vm.name()}: {ex.get_error_message()}')
|
||||
# vm.undefine()
|
||||
|
||||
if storage_pool is not None:
|
||||
|
||||
@dataclass
|
||||
class MeshNodeConfig:
|
||||
name: str
|
||||
ip: str
|
||||
peers: List['MeshNodeConfig'] = ()
|
||||
private_key: str = base64.b64encode(
|
||||
X25519PrivateKey.generate().private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
).decode()
|
||||
public_key: str = base64.b64encode(
|
||||
X25519PrivateKey.from_private_bytes(
|
||||
base64.b64decode(
|
||||
private_key.encode()
|
||||
)
|
||||
).public_key().public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
).decode()
|
||||
psk: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
listen_port: int = 51515
|
||||
|
||||
@property
|
||||
def config(self) -> str:
|
||||
config = '[Interface]\n' + \
|
||||
f'Address = {self.ip}\n' + \
|
||||
f'PrivateKey = {self.private_key}\n' + \
|
||||
f'ListenPort = {self.listen_port}\n\n'
|
||||
|
||||
for peer in self.peers:
|
||||
config += '[Peer]\n' + \
|
||||
(f'Endpoint: {peer.endpoint}\n' if peer.endpoint is not None else '') + \
|
||||
(f'PresharedKey: {peer.psk}\n' if peer.psk is not None else '') + \
|
||||
f'PublicKey = {peer.public_key}\n' + \
|
||||
f'AllowedIPs = {peer.ip}/32\n' + \
|
||||
f'PersistentKeepalive = 25\n\n'
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _create_node(
|
||||
node_name: str,
|
||||
ip_pool: Iterator[IPv4Address],
|
||||
psk: Optional[str] = None
|
||||
) -> Tuple[str, MeshNodeConfig]:
|
||||
return node_name, MeshNodeConfig(
|
||||
name=node_name,
|
||||
peers=[],
|
||||
ip=str(next(ip_pool)),
|
||||
psk=psk
|
||||
)
|
||||
|
||||
|
||||
# TODO this should take a graph / adjacency list as argument and set peers based on that.
|
||||
# currently all nodes are connected to all other nodes, which scales with O(n!) and
|
||||
# takes significant amounts of time and compute for even small n < 10.
|
||||
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
|
||||
for src, dst in itertools.permutations(nodes.keys(), 2):
|
||||
nodes[dst].peers.append(nodes[src])
|
||||
|
||||
|
||||
def _create_mesh(node_count=5, ip_subnet: str = '10.0.0.0/24') -> Dict[str, MeshNodeConfig]:
|
||||
ip_pool = IPv4Network(ip_subnet).hosts()
|
||||
node_names = (
|
||||
f'node-{idx}' for idx in range(1, node_count + 1)
|
||||
)
|
||||
psk = base64.b64encode(secrets.token_bytes(32)).decode()
|
||||
|
||||
nodes = dict(
|
||||
map(
|
||||
_create_node,
|
||||
node_names,
|
||||
itertools.repeat(ip_pool),
|
||||
itertools.repeat(psk)
|
||||
)
|
||||
)
|
||||
|
||||
_connect_nodes(nodes)
|
||||
|
||||
return nodes
|
||||
for volume in storage_pool.listAllVolumes():
|
||||
print(f'deleting volume: {volume.path()}')
|
||||
volume.delete()
|
||||
storage_pool.destroy()
|
||||
# storage_pool.undefine()
|
||||
|
||||
|
||||
def main():
|
||||
nodes = _create_mesh()
|
||||
nodes = create_mesh()
|
||||
print(nodes)
|
||||
|
||||
tmp_dir = '/tmp/node_configs'
|
||||
@ -119,6 +45,34 @@ def main():
|
||||
with open(p / f'{node.name}.conf', 'w+') as fh:
|
||||
fh.write(node.config)
|
||||
|
||||
libvirt_conn_uri = 'qemu:///system'
|
||||
conn = libvirt.open(libvirt_conn_uri)
|
||||
|
||||
# TODO map to all
|
||||
storage_pool_path = Path('/tmp/test')
|
||||
vms: List[virDomain] = [
|
||||
spawn_vm(conn, node.name, storage_pool_path, node.config)
|
||||
for node in nodes.values()
|
||||
]
|
||||
|
||||
print('vms created, resuming cpus...')
|
||||
|
||||
for vm in vms:
|
||||
vm.resume()
|
||||
|
||||
try:
|
||||
while True:
|
||||
sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print('caught SIGTERM, finishing...')
|
||||
|
||||
storage_pool: Optional[virStoragePool] = None
|
||||
storage_pool_name = 'wireguard_test'
|
||||
if storage_pool_name in (p.name() for p in conn.listAllStoragePools()):
|
||||
storage_pool = conn.storagePoolLookupByName(storage_pool_name)
|
||||
|
||||
cleanup(vms, storage_pool)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -1 +1,2 @@
|
||||
cryptography~=3.4.8
|
||||
cryptography~=3.4.8
|
||||
libvirt-python~=7.7.0
|
||||
|
291
virt.py
Normal file
291
virt.py
Normal file
@ -0,0 +1,291 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, Iterable
|
||||
|
||||
import libvirt
|
||||
from libvirt import virConnect, virDomain, virStoragePool
|
||||
|
||||
from cloud_config import gen_cloud_config, _gen_cloudinit_iso_image
|
||||
|
||||
|
||||
def _gen_libvirt_volume_config(name: str, path: Path, backing_image: Path, capacity: int = 1000,
|
||||
image_format: str = 'qcow2') -> str:
|
||||
"""
|
||||
creates the XML config for a volume that will hold the data of a vm.
|
||||
:param name: volume name
|
||||
:param path: path to save volume to
|
||||
:param backing_image: path to a backing image to base this one on
|
||||
:param capacity: capacity in MiB
|
||||
:param image_format: qcow2, iso, raw or alike.
|
||||
A list of valid values can be found here: https://libvirt.org/storage.html#StorageBackendFS
|
||||
:return: XML config as string
|
||||
"""
|
||||
|
||||
xml_config = f"""
|
||||
<volume type='file'>
|
||||
<name>{name}.{image_format}</name>
|
||||
<capacity unit='MiB'>{capacity}</capacity>
|
||||
<allocation unit='MiB'>0</allocation>
|
||||
<target>
|
||||
<path>{path}</path>
|
||||
<format type='{image_format}'/>
|
||||
</target>
|
||||
<backingStore>
|
||||
<path>{backing_image}</path>
|
||||
<format type='qcow2'/>
|
||||
</backingStore>
|
||||
</volume>
|
||||
"""
|
||||
|
||||
return xml_config
|
||||
|
||||
|
||||
def _gen_libvirt_storage_pool_config(name: str, path: Path) -> str:
|
||||
"""
|
||||
create the libvirt XML config for a storage pool with directory backend.
|
||||
:param name: name of the storage pool
|
||||
:param path: the path to save images in the pool to
|
||||
:return: XML config as string
|
||||
"""
|
||||
|
||||
xml_config = f"""
|
||||
<pool type='dir'>
|
||||
<name>{name}</name>
|
||||
<target>
|
||||
<path>{path}</path>
|
||||
</target>
|
||||
</pool>
|
||||
"""
|
||||
|
||||
return xml_config
|
||||
|
||||
|
||||
def _gen_libvirt_config(name: str, hdd_path: Path, cdrom_path: Path, cpu: int = 1, mem: int = 512) -> str:
|
||||
"""
|
||||
return libvirt XML config for a virtual machine
|
||||
:param name: name of the machine
|
||||
:param hdd_path: path to the image to use as primary disk
|
||||
:param cdrom_path: path to the cloud-init image
|
||||
:param cpu: number of cpus
|
||||
:param mem: amount of ram in MiB
|
||||
:return: XML config as string
|
||||
"""
|
||||
|
||||
vm_uuid = uuid.uuid4()
|
||||
|
||||
xml_config = f"""
|
||||
<domain type="kvm">
|
||||
<name>{name}</name>
|
||||
<uuid>{vm_uuid}</uuid>
|
||||
<metadata>
|
||||
<libosinfo:libosinfo xmlns:libosinfo="http://libosinfo.org/xmlns/libvirt/domain/1.0">
|
||||
<libosinfo:os id="http://ubuntu.com/ubuntu/20.04"/>
|
||||
</libosinfo:libosinfo>
|
||||
</metadata>
|
||||
<memory unit="MiB">{mem}</memory>
|
||||
<currentMemory unit="MiB">{mem}</currentMemory>
|
||||
<vcpu placement="static">{cpu}</vcpu>
|
||||
<os>
|
||||
<type arch="x86_64" machine="pc-q35-5.2">hvm</type>
|
||||
</os>
|
||||
<features>
|
||||
<acpi/>
|
||||
<apic/>
|
||||
<vmport state="off"/>
|
||||
</features>
|
||||
<cpu mode="host-model" check="partial"/>
|
||||
<clock offset="utc">
|
||||
<timer name="rtc" tickpolicy="catchup"/>
|
||||
<timer name="pit" tickpolicy="delay"/>
|
||||
<timer name="hpet" present="no"/>
|
||||
</clock>
|
||||
<on_poweroff>destroy</on_poweroff>
|
||||
<on_reboot>restart</on_reboot>
|
||||
<on_crash>destroy</on_crash>
|
||||
<pm>
|
||||
<suspend-to-mem enabled="no"/>
|
||||
<suspend-to-disk enabled="no"/>
|
||||
</pm>
|
||||
<devices>
|
||||
<emulator>/usr/bin/qemu-system-x86_64</emulator>
|
||||
<console type='pty'>
|
||||
<target type='serial'/>
|
||||
</console>
|
||||
|
||||
<video>
|
||||
<model type="virtio" heads="1" primary="yes">
|
||||
<acceleration accel3d="yes"/>
|
||||
</model>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x01" function="0x0"/>
|
||||
</video>
|
||||
<channel type="spicevmc">
|
||||
<target type="virtio" name="com.redhat.spice.0"/>
|
||||
<address type="virtio-serial" controller="0" bus="0" port="2"/>
|
||||
</channel>
|
||||
<channel type="unix">
|
||||
<target type="virtio" name="org.qemu.guest_agent.0"/>
|
||||
<address type="virtio-serial" controller="0" bus="0" port="1"/>
|
||||
</channel>
|
||||
<graphics type="spice" autoport="yes">
|
||||
<listen type="address"/>
|
||||
<image compression="off"/>
|
||||
</graphics>
|
||||
|
||||
<disk type="file" device="disk">
|
||||
<driver name="qemu" type="qcow2"/>
|
||||
<source file="{hdd_path}"/>
|
||||
<target dev="vda" bus="virtio"/>
|
||||
<address type="pci" domain="0x0000" bus="0x03" slot="0x00" function="0x0"/>
|
||||
<boot order="1"/>
|
||||
</disk>
|
||||
|
||||
<disk type="file" device="cdrom">
|
||||
<driver name="qemu" type="raw"/>
|
||||
<source file="{cdrom_path}"/>
|
||||
<target dev="sda" bus="sata"/>
|
||||
<readonly/>
|
||||
<address type="drive" controller="0" bus="0" target="0" unit="0"/>
|
||||
<boot order="2"/>
|
||||
</disk>
|
||||
|
||||
<controller type="usb" index="0" model="ich9-ehci1">
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x7"/>
|
||||
</controller>
|
||||
<controller type="usb" index="0" model="ich9-uhci1">
|
||||
<master startport="0"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x0" multifunction="on"/>
|
||||
</controller>
|
||||
<controller type="usb" index="0" model="ich9-uhci2">
|
||||
<master startport="2"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x1"/>
|
||||
</controller>
|
||||
<controller type="usb" index="0" model="ich9-uhci3">
|
||||
<master startport="4"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x2"/>
|
||||
</controller>
|
||||
<controller type="sata" index="0">
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x1f" function="0x2"/>
|
||||
</controller>
|
||||
<controller type="pci" index="0" model="pcie-root"/>
|
||||
<controller type="pci" index="1" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="1" port="0x10"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x0" multifunction="on"/>
|
||||
</controller>
|
||||
<controller type="pci" index="2" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="2" port="0x11"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x1"/>
|
||||
</controller>
|
||||
<controller type="pci" index="3" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="3" port="0x12"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x2"/>
|
||||
</controller>
|
||||
<controller type="pci" index="4" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="4" port="0x13"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x3"/>
|
||||
</controller>
|
||||
<controller type="pci" index="5" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="5" port="0x14"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x4"/>
|
||||
</controller>
|
||||
<controller type="pci" index="6" model="pcie-root-port">
|
||||
<model name="pcie-root-port"/>
|
||||
<target chassis="6" port="0x15"/>
|
||||
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x5"/>
|
||||
</controller>
|
||||
<controller type="virtio-serial" index="0">
|
||||
<address type="pci" domain="0x0000" bus="0x02" slot="0x00" function="0x0"/>
|
||||
</controller>
|
||||
<interface type="network">
|
||||
<source network="default"/>
|
||||
<model type="virtio"/>
|
||||
<address type="pci" domain="0x0000" bus="0x01" slot="0x00" function="0x0"/>
|
||||
</interface>
|
||||
<input type="mouse" bus="ps2"/>
|
||||
<input type="keyboard" bus="ps2"/>
|
||||
<audio id="1" type="spice"/>
|
||||
<memballoon model="virtio">
|
||||
<address type="pci" domain="0x0000" bus="0x04" slot="0x00" function="0x0"/>
|
||||
</memballoon>
|
||||
<rng model="virtio">
|
||||
<backend model="random">/dev/urandom</backend>
|
||||
<address type="pci" domain="0x0000" bus="0x05" slot="0x00" function="0x0"/>
|
||||
</rng>
|
||||
</devices>
|
||||
</domain>
|
||||
"""
|
||||
|
||||
return xml_config
|
||||
|
||||
|
||||
def spawn_vm(connection: virConnect,
|
||||
name: str,
|
||||
storage_pool_path: Path,
|
||||
wireguard_config: str,
|
||||
ssh_authorized_keys: Optional[Iterable[str]] = None,
|
||||
storage_pool_name: str = 'wireguard_test'
|
||||
) -> virDomain:
|
||||
"""
|
||||
spawm a VM provisioned via cloud-init that is a mesh node
|
||||
:param connection: libvirt connection object
|
||||
:param name: name of the VM
|
||||
:param storage_pool_path: path storage pool objects are saved in
|
||||
:param wireguard_config: wireguard config file as string
|
||||
:param ssh_authorized_keys: ssh public keys as list to add to the image
|
||||
:param storage_pool_name: name of the storage pool (will be created if nonexistent)
|
||||
:return: a libvirt virDomain object representing a virtual machine
|
||||
"""
|
||||
|
||||
# if storage pool does not exist, create it. Otherwise, fetch it.
|
||||
storage_pools = (p.name() for p in connection.listAllStoragePools())
|
||||
if storage_pool_name not in storage_pools:
|
||||
storage_pool_path.mkdir(exist_ok=True)
|
||||
storage_pool_xml_config = _gen_libvirt_storage_pool_config(storage_pool_name, storage_pool_path)
|
||||
storage_pool: virStoragePool = connection.storagePoolCreateXML(storage_pool_xml_config)
|
||||
else:
|
||||
storage_pool = connection.storagePoolLookupByName(storage_pool_name)
|
||||
|
||||
image_format = 'qcow2'
|
||||
volume_path = storage_pool_path / f'{name}.{image_format}'
|
||||
# if image exists, fetch and wipe it. Otherwise, create it.
|
||||
if str(volume_path) in (v.path() for v in storage_pool.listAllVolumes()):
|
||||
volume = storage_pool.storageVolLookupByName(volume_path.name)
|
||||
volume.wipe()
|
||||
else:
|
||||
backing_file_path = Path('/tmp/focal-server-cloudimg-amd64.qcow2')
|
||||
volume_xml_config = _gen_libvirt_volume_config(name, volume_path, backing_file_path, image_format=image_format)
|
||||
# volume_create_flags = libvirt.VIR_STORAGE_VOL_CREATE_PREALLOC_METADATA
|
||||
storage_pool.createXML(volume_xml_config)
|
||||
|
||||
# if the cloud-init images exist, delete them.
|
||||
cloud_init_volume_name = f'cloudinit-{name}.iso'
|
||||
try:
|
||||
cloud_init_volume = storage_pool.storageVolLookupByName(cloud_init_volume_name)
|
||||
cloud_init_volume.delete()
|
||||
except libvirt.libvirtError:
|
||||
pass
|
||||
|
||||
# recreate cloud-init images
|
||||
cloud_init_iso_path = storage_pool_path / cloud_init_volume_name
|
||||
user_data_path, meta_data_path = gen_cloud_config(name, Path('/tmp/'), wireguard_config, ssh_authorized_keys)
|
||||
_gen_cloudinit_iso_image(cloud_init_iso_path, user_data_path, meta_data_path)
|
||||
|
||||
# generate VM XML config
|
||||
vm_xml_config = _gen_libvirt_config(
|
||||
name,
|
||||
volume_path,
|
||||
cloud_init_iso_path,
|
||||
cpu=1,
|
||||
mem=512
|
||||
)
|
||||
|
||||
vm_create_flags = libvirt.VIR_DOMAIN_START_PAUSED
|
||||
# create the virtual machine
|
||||
# might use defineXML for persistence. VMs disappear after shutdown
|
||||
vm = connection.createXML(vm_xml_config, flags=vm_create_flags)
|
||||
|
||||
return vm
|
128
wireguard_mesh.py
Normal file
128
wireguard_mesh.py
Normal file
@ -0,0 +1,128 @@
|
||||
import base64
|
||||
import itertools
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Network, IPv4Address
|
||||
from typing import List, Optional, Dict, Iterator, Tuple
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeshNodeConfig:
|
||||
"""
|
||||
Holds the config for a wireguard mesh node and relevant information.
|
||||
Can also render the actual wireguard config into a string that can then be written out into a file.
|
||||
"""
|
||||
|
||||
name: str
|
||||
ip: str
|
||||
peers: List['MeshNodeConfig'] = ()
|
||||
private_key: str = base64.b64encode(
|
||||
X25519PrivateKey.generate().private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
).decode()
|
||||
public_key: str = base64.b64encode(
|
||||
X25519PrivateKey.from_private_bytes(
|
||||
base64.b64decode(
|
||||
private_key.encode()
|
||||
)
|
||||
).public_key().public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
).decode()
|
||||
psk: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
listen_port: int = 51515
|
||||
|
||||
@property
|
||||
def config(self) -> str:
|
||||
"""
|
||||
render a wireguard config file based on the variables set in this class.
|
||||
:return: wireguard config as string
|
||||
"""
|
||||
|
||||
config = '[Interface]\n' + \
|
||||
f'Address = {self.ip}\n' + \
|
||||
f'PrivateKey = {self.private_key}\n' + \
|
||||
f'ListenPort = {self.listen_port}\n\n'
|
||||
|
||||
for peer in self.peers:
|
||||
config += '[Peer]\n' + \
|
||||
(f'Endpoint: {peer.endpoint}\n' if peer.endpoint is not None else '') + \
|
||||
(f'PresharedKey: {peer.psk}\n' if peer.psk is not None else '') + \
|
||||
f'PublicKey = {peer.public_key}\n' + \
|
||||
f'AllowedIPs = {peer.ip}/32\n' + \
|
||||
f'PersistentKeepalive = 25\n\n'
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_mesh(node_count=3,
|
||||
ip_pool: Iterator[IPv4Address] = IPv4Network('10.0.0.0/24').hosts(),
|
||||
) -> Dict[str, MeshNodeConfig]:
|
||||
"""
|
||||
Create a number of wireguard mesh configs,
|
||||
that can then be rolled out on machines to use as a wireguard mesh network.
|
||||
|
||||
:param ip_pool: Iterator over IPv4Address objects to draw IP Addresses from
|
||||
:param node_count: number of configs to generate
|
||||
:return: a dict with names of instances as keys and their MeshNodeConfig objects as corresponding values
|
||||
"""
|
||||
|
||||
node_names = (
|
||||
f'node-{idx}' for idx in range(1, node_count + 1)
|
||||
)
|
||||
psk = base64.b64encode(secrets.token_bytes(32)).decode()
|
||||
|
||||
nodes = dict(
|
||||
map(
|
||||
_create_mesh_node,
|
||||
node_names,
|
||||
itertools.repeat(ip_pool),
|
||||
itertools.repeat(psk)
|
||||
)
|
||||
)
|
||||
|
||||
_connect_nodes(nodes)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
# TODO this should take a graph / adjacency list as argument and set peers based on that.
|
||||
# currently all nodes are connected to all other nodes, which scales with O(n!) and
|
||||
# takes significant amounts of time and compute for even small n < 10.
|
||||
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
|
||||
"""
|
||||
This function implements the actual entanglement of nodes and thus decides for implementing
|
||||
"who is able to talk to whom" rules within the wireguard mesh network.
|
||||
:param nodes: dict with names of instances as keys and their MeshNodeConfig objects as corresponding values
|
||||
:return: nothing
|
||||
"""
|
||||
for src, dst in itertools.permutations(nodes.keys(), 2):
|
||||
nodes[dst].peers.append(nodes[src])
|
||||
|
||||
|
||||
def _create_mesh_node(
|
||||
node_name: str,
|
||||
ip_pool: Iterator[IPv4Address],
|
||||
psk: Optional[str] = None
|
||||
) -> Tuple[str, MeshNodeConfig]:
|
||||
"""
|
||||
Create a single mesh node config
|
||||
:param node_name: name of the node
|
||||
:param ip_pool: iterator over IPv4Address objects to draw an IPv4Address for this node from
|
||||
:param psk: preshared key as string. Optional.
|
||||
:return: a tuple with name of the new node as first and the MeshNodeConfig as second element
|
||||
"""
|
||||
return node_name, MeshNodeConfig(
|
||||
name=node_name,
|
||||
peers=[],
|
||||
ip=str(next(ip_pool)),
|
||||
psk=psk
|
||||
)
|
Loading…
Reference in New Issue
Block a user