libvirt integration

This commit is contained in:
fix 2021-09-05 23:12:56 +02:00
parent 42c71a201d
commit 1ee0d2f740
6 changed files with 551 additions and 96 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
.idea
__pycache__

80
cloud_config.py Normal file
View File

@ -0,0 +1,80 @@
import base64
import shutil
import subprocess
from pathlib import Path
from typing import Optional, Iterable, Tuple
def gen_cloud_config(name: str,
cloud_init_path: Path,
wireguard_config: str,
ssh_authorized_keys: Optional[Iterable[str]] = None
) -> Tuple[Path, Path]:
if ssh_authorized_keys is not None:
ssh_keys = 'ssh_authorized_keys\n' + \
'\n'.join((f'- {key}' for key in ssh_authorized_keys))
else:
ssh_keys = ''
user_data = f"""
#cloud-config
chpasswd: {{ expire: False }}
password: root
{ssh_keys}
ssh_pwauth: False
timezone: Europe/Berlin
users:
- default
package_update: true
packages:
- wireguard
write_files:
- encoding: b64
content: {base64.b64encode(wireguard_config.encode())}
owner: root:root
path: /etc/wireguard/wg0.conf
permissions: 0600
"""
meta_data = f"""
local-hostname: {name}
"""
user_data_path = cloud_init_path / 'user_data'
meta_data_path = cloud_init_path / 'meta_data'
with open(user_data_path, 'w+') as handle:
handle.write(user_data)
with open(meta_data_path, 'w+') as handle:
handle.write(meta_data)
return user_data_path, meta_data_path
def _gen_cloudinit_iso_image(path: Path, user_data_path: Path, meta_data_path: Path, volume_label: str = 'cloud_init'):
genisoimage_executable = shutil.which('genisoimage')
if genisoimage_executable is None:
raise FileNotFoundError('could not locate genisoimage executable!')
command = (
genisoimage_executable,
'-output', str(path),
'-V', volume_label,
'-r',
'-J',
user_data_path,
meta_data_path
)
try:
subprocess.run(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True
)
except subprocess.CalledProcessError as ex:
print(f'genisoimage error, stderr was: {ex.stderr}, stdout was: {ex.stdout}')
raise ex

144
main.py
View File

@ -1,108 +1,34 @@
import base64
import itertools
import os
import secrets
from dataclasses import dataclass
from ipaddress import IPv4Network, IPv4Address
from pathlib import Path
from typing import Iterator, Dict, Tuple, List, Optional
from time import sleep
from typing import List, Optional
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
import libvirt
from libvirt import virDomain, virStoragePool
from virt import spawn_vm
from wireguard_mesh import create_mesh
def _spawn_vm():
pass
def cleanup(vms: List[virDomain], storage_pool: Optional[virStoragePool] = None):
for vm in vms:
try:
vm.destroy()
except libvirt.libvirtError as ex:
print(f'error shutting down {vm.name()}: {ex.get_error_message()}')
# vm.undefine()
if storage_pool is not None:
@dataclass
class MeshNodeConfig:
name: str
ip: str
peers: List['MeshNodeConfig'] = ()
private_key: str = base64.b64encode(
X25519PrivateKey.generate().private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
).decode()
public_key: str = base64.b64encode(
X25519PrivateKey.from_private_bytes(
base64.b64decode(
private_key.encode()
)
).public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode()
psk: Optional[str] = None
endpoint: Optional[str] = None
listen_port: int = 51515
@property
def config(self) -> str:
config = '[Interface]\n' + \
f'Address = {self.ip}\n' + \
f'PrivateKey = {self.private_key}\n' + \
f'ListenPort = {self.listen_port}\n\n'
for peer in self.peers:
config += '[Peer]\n' + \
(f'Endpoint: {peer.endpoint}\n' if peer.endpoint is not None else '') + \
(f'PresharedKey: {peer.psk}\n' if peer.psk is not None else '') + \
f'PublicKey = {peer.public_key}\n' + \
f'AllowedIPs = {peer.ip}/32\n' + \
f'PersistentKeepalive = 25\n\n'
return config
def _create_node(
node_name: str,
ip_pool: Iterator[IPv4Address],
psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
return node_name, MeshNodeConfig(
name=node_name,
peers=[],
ip=str(next(ip_pool)),
psk=psk
)
# TODO this should take a graph / adjacency list as argument and set peers based on that.
# currently all nodes are connected to all other nodes, which scales with O(n!) and
# takes significant amounts of time and compute for even small n < 10.
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
for src, dst in itertools.permutations(nodes.keys(), 2):
nodes[dst].peers.append(nodes[src])
def _create_mesh(node_count=5, ip_subnet: str = '10.0.0.0/24') -> Dict[str, MeshNodeConfig]:
ip_pool = IPv4Network(ip_subnet).hosts()
node_names = (
f'node-{idx}' for idx in range(1, node_count + 1)
)
psk = base64.b64encode(secrets.token_bytes(32)).decode()
nodes = dict(
map(
_create_node,
node_names,
itertools.repeat(ip_pool),
itertools.repeat(psk)
)
)
_connect_nodes(nodes)
return nodes
for volume in storage_pool.listAllVolumes():
print(f'deleting volume: {volume.path()}')
volume.delete()
storage_pool.destroy()
# storage_pool.undefine()
def main():
nodes = _create_mesh()
nodes = create_mesh()
print(nodes)
tmp_dir = '/tmp/node_configs'
@ -119,6 +45,34 @@ def main():
with open(p / f'{node.name}.conf', 'w+') as fh:
fh.write(node.config)
libvirt_conn_uri = 'qemu:///system'
conn = libvirt.open(libvirt_conn_uri)
# TODO map to all
storage_pool_path = Path('/tmp/test')
vms: List[virDomain] = [
spawn_vm(conn, node.name, storage_pool_path, node.config)
for node in nodes.values()
]
print('vms created, resuming cpus...')
for vm in vms:
vm.resume()
try:
while True:
sleep(1)
except KeyboardInterrupt:
print('caught SIGTERM, finishing...')
storage_pool: Optional[virStoragePool] = None
storage_pool_name = 'wireguard_test'
if storage_pool_name in (p.name() for p in conn.listAllStoragePools()):
storage_pool = conn.storagePoolLookupByName(storage_pool_name)
cleanup(vms, storage_pool)
if __name__ == '__main__':
main()

View File

@ -1 +1,2 @@
cryptography~=3.4.8
cryptography~=3.4.8
libvirt-python~=7.7.0

291
virt.py Normal file
View File

@ -0,0 +1,291 @@
import uuid
from pathlib import Path
from typing import Optional, Iterable
import libvirt
from libvirt import virConnect, virDomain, virStoragePool
from cloud_config import gen_cloud_config, _gen_cloudinit_iso_image
def _gen_libvirt_volume_config(name: str, path: Path, backing_image: Path, capacity: int = 1000,
image_format: str = 'qcow2') -> str:
"""
creates the XML config for a volume that will hold the data of a vm.
:param name: volume name
:param path: path to save volume to
:param backing_image: path to a backing image to base this one on
:param capacity: capacity in MiB
:param image_format: qcow2, iso, raw or alike.
A list of valid values can be found here: https://libvirt.org/storage.html#StorageBackendFS
:return: XML config as string
"""
xml_config = f"""
<volume type='file'>
<name>{name}.{image_format}</name>
<capacity unit='MiB'>{capacity}</capacity>
<allocation unit='MiB'>0</allocation>
<target>
<path>{path}</path>
<format type='{image_format}'/>
</target>
<backingStore>
<path>{backing_image}</path>
<format type='qcow2'/>
</backingStore>
</volume>
"""
return xml_config
def _gen_libvirt_storage_pool_config(name: str, path: Path) -> str:
"""
create the libvirt XML config for a storage pool with directory backend.
:param name: name of the storage pool
:param path: the path to save images in the pool to
:return: XML config as string
"""
xml_config = f"""
<pool type='dir'>
<name>{name}</name>
<target>
<path>{path}</path>
</target>
</pool>
"""
return xml_config
def _gen_libvirt_config(name: str, hdd_path: Path, cdrom_path: Path, cpu: int = 1, mem: int = 512) -> str:
"""
return libvirt XML config for a virtual machine
:param name: name of the machine
:param hdd_path: path to the image to use as primary disk
:param cdrom_path: path to the cloud-init image
:param cpu: number of cpus
:param mem: amount of ram in MiB
:return: XML config as string
"""
vm_uuid = uuid.uuid4()
xml_config = f"""
<domain type="kvm">
<name>{name}</name>
<uuid>{vm_uuid}</uuid>
<metadata>
<libosinfo:libosinfo xmlns:libosinfo="http://libosinfo.org/xmlns/libvirt/domain/1.0">
<libosinfo:os id="http://ubuntu.com/ubuntu/20.04"/>
</libosinfo:libosinfo>
</metadata>
<memory unit="MiB">{mem}</memory>
<currentMemory unit="MiB">{mem}</currentMemory>
<vcpu placement="static">{cpu}</vcpu>
<os>
<type arch="x86_64" machine="pc-q35-5.2">hvm</type>
</os>
<features>
<acpi/>
<apic/>
<vmport state="off"/>
</features>
<cpu mode="host-model" check="partial"/>
<clock offset="utc">
<timer name="rtc" tickpolicy="catchup"/>
<timer name="pit" tickpolicy="delay"/>
<timer name="hpet" present="no"/>
</clock>
<on_poweroff>destroy</on_poweroff>
<on_reboot>restart</on_reboot>
<on_crash>destroy</on_crash>
<pm>
<suspend-to-mem enabled="no"/>
<suspend-to-disk enabled="no"/>
</pm>
<devices>
<emulator>/usr/bin/qemu-system-x86_64</emulator>
<console type='pty'>
<target type='serial'/>
</console>
<video>
<model type="virtio" heads="1" primary="yes">
<acceleration accel3d="yes"/>
</model>
<address type="pci" domain="0x0000" bus="0x00" slot="0x01" function="0x0"/>
</video>
<channel type="spicevmc">
<target type="virtio" name="com.redhat.spice.0"/>
<address type="virtio-serial" controller="0" bus="0" port="2"/>
</channel>
<channel type="unix">
<target type="virtio" name="org.qemu.guest_agent.0"/>
<address type="virtio-serial" controller="0" bus="0" port="1"/>
</channel>
<graphics type="spice" autoport="yes">
<listen type="address"/>
<image compression="off"/>
</graphics>
<disk type="file" device="disk">
<driver name="qemu" type="qcow2"/>
<source file="{hdd_path}"/>
<target dev="vda" bus="virtio"/>
<address type="pci" domain="0x0000" bus="0x03" slot="0x00" function="0x0"/>
<boot order="1"/>
</disk>
<disk type="file" device="cdrom">
<driver name="qemu" type="raw"/>
<source file="{cdrom_path}"/>
<target dev="sda" bus="sata"/>
<readonly/>
<address type="drive" controller="0" bus="0" target="0" unit="0"/>
<boot order="2"/>
</disk>
<controller type="usb" index="0" model="ich9-ehci1">
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x7"/>
</controller>
<controller type="usb" index="0" model="ich9-uhci1">
<master startport="0"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x0" multifunction="on"/>
</controller>
<controller type="usb" index="0" model="ich9-uhci2">
<master startport="2"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x1"/>
</controller>
<controller type="usb" index="0" model="ich9-uhci3">
<master startport="4"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x1d" function="0x2"/>
</controller>
<controller type="sata" index="0">
<address type="pci" domain="0x0000" bus="0x00" slot="0x1f" function="0x2"/>
</controller>
<controller type="pci" index="0" model="pcie-root"/>
<controller type="pci" index="1" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="1" port="0x10"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x0" multifunction="on"/>
</controller>
<controller type="pci" index="2" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="2" port="0x11"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x1"/>
</controller>
<controller type="pci" index="3" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="3" port="0x12"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x2"/>
</controller>
<controller type="pci" index="4" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="4" port="0x13"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x3"/>
</controller>
<controller type="pci" index="5" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="5" port="0x14"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x4"/>
</controller>
<controller type="pci" index="6" model="pcie-root-port">
<model name="pcie-root-port"/>
<target chassis="6" port="0x15"/>
<address type="pci" domain="0x0000" bus="0x00" slot="0x02" function="0x5"/>
</controller>
<controller type="virtio-serial" index="0">
<address type="pci" domain="0x0000" bus="0x02" slot="0x00" function="0x0"/>
</controller>
<interface type="network">
<source network="default"/>
<model type="virtio"/>
<address type="pci" domain="0x0000" bus="0x01" slot="0x00" function="0x0"/>
</interface>
<input type="mouse" bus="ps2"/>
<input type="keyboard" bus="ps2"/>
<audio id="1" type="spice"/>
<memballoon model="virtio">
<address type="pci" domain="0x0000" bus="0x04" slot="0x00" function="0x0"/>
</memballoon>
<rng model="virtio">
<backend model="random">/dev/urandom</backend>
<address type="pci" domain="0x0000" bus="0x05" slot="0x00" function="0x0"/>
</rng>
</devices>
</domain>
"""
return xml_config
def spawn_vm(connection: virConnect,
name: str,
storage_pool_path: Path,
wireguard_config: str,
ssh_authorized_keys: Optional[Iterable[str]] = None,
storage_pool_name: str = 'wireguard_test'
) -> virDomain:
"""
spawm a VM provisioned via cloud-init that is a mesh node
:param connection: libvirt connection object
:param name: name of the VM
:param storage_pool_path: path storage pool objects are saved in
:param wireguard_config: wireguard config file as string
:param ssh_authorized_keys: ssh public keys as list to add to the image
:param storage_pool_name: name of the storage pool (will be created if nonexistent)
:return: a libvirt virDomain object representing a virtual machine
"""
# if storage pool does not exist, create it. Otherwise, fetch it.
storage_pools = (p.name() for p in connection.listAllStoragePools())
if storage_pool_name not in storage_pools:
storage_pool_path.mkdir(exist_ok=True)
storage_pool_xml_config = _gen_libvirt_storage_pool_config(storage_pool_name, storage_pool_path)
storage_pool: virStoragePool = connection.storagePoolCreateXML(storage_pool_xml_config)
else:
storage_pool = connection.storagePoolLookupByName(storage_pool_name)
image_format = 'qcow2'
volume_path = storage_pool_path / f'{name}.{image_format}'
# if image exists, fetch and wipe it. Otherwise, create it.
if str(volume_path) in (v.path() for v in storage_pool.listAllVolumes()):
volume = storage_pool.storageVolLookupByName(volume_path.name)
volume.wipe()
else:
backing_file_path = Path('/tmp/focal-server-cloudimg-amd64.qcow2')
volume_xml_config = _gen_libvirt_volume_config(name, volume_path, backing_file_path, image_format=image_format)
# volume_create_flags = libvirt.VIR_STORAGE_VOL_CREATE_PREALLOC_METADATA
storage_pool.createXML(volume_xml_config)
# if the cloud-init images exist, delete them.
cloud_init_volume_name = f'cloudinit-{name}.iso'
try:
cloud_init_volume = storage_pool.storageVolLookupByName(cloud_init_volume_name)
cloud_init_volume.delete()
except libvirt.libvirtError:
pass
# recreate cloud-init images
cloud_init_iso_path = storage_pool_path / cloud_init_volume_name
user_data_path, meta_data_path = gen_cloud_config(name, Path('/tmp/'), wireguard_config, ssh_authorized_keys)
_gen_cloudinit_iso_image(cloud_init_iso_path, user_data_path, meta_data_path)
# generate VM XML config
vm_xml_config = _gen_libvirt_config(
name,
volume_path,
cloud_init_iso_path,
cpu=1,
mem=512
)
vm_create_flags = libvirt.VIR_DOMAIN_START_PAUSED
# create the virtual machine
# might use defineXML for persistence. VMs disappear after shutdown
vm = connection.createXML(vm_xml_config, flags=vm_create_flags)
return vm

128
wireguard_mesh.py Normal file
View File

@ -0,0 +1,128 @@
import base64
import itertools
import secrets
from dataclasses import dataclass
from ipaddress import IPv4Network, IPv4Address
from typing import List, Optional, Dict, Iterator, Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
@dataclass
class MeshNodeConfig:
"""
Holds the config for a wireguard mesh node and relevant information.
Can also render the actual wireguard config into a string that can then be written out into a file.
"""
name: str
ip: str
peers: List['MeshNodeConfig'] = ()
private_key: str = base64.b64encode(
X25519PrivateKey.generate().private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
).decode()
public_key: str = base64.b64encode(
X25519PrivateKey.from_private_bytes(
base64.b64decode(
private_key.encode()
)
).public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode()
psk: Optional[str] = None
endpoint: Optional[str] = None
listen_port: int = 51515
@property
def config(self) -> str:
"""
render a wireguard config file based on the variables set in this class.
:return: wireguard config as string
"""
config = '[Interface]\n' + \
f'Address = {self.ip}\n' + \
f'PrivateKey = {self.private_key}\n' + \
f'ListenPort = {self.listen_port}\n\n'
for peer in self.peers:
config += '[Peer]\n' + \
(f'Endpoint: {peer.endpoint}\n' if peer.endpoint is not None else '') + \
(f'PresharedKey: {peer.psk}\n' if peer.psk is not None else '') + \
f'PublicKey = {peer.public_key}\n' + \
f'AllowedIPs = {peer.ip}/32\n' + \
f'PersistentKeepalive = 25\n\n'
return config
def create_mesh(node_count=3,
ip_pool: Iterator[IPv4Address] = IPv4Network('10.0.0.0/24').hosts(),
) -> Dict[str, MeshNodeConfig]:
"""
Create a number of wireguard mesh configs,
that can then be rolled out on machines to use as a wireguard mesh network.
:param ip_pool: Iterator over IPv4Address objects to draw IP Addresses from
:param node_count: number of configs to generate
:return: a dict with names of instances as keys and their MeshNodeConfig objects as corresponding values
"""
node_names = (
f'node-{idx}' for idx in range(1, node_count + 1)
)
psk = base64.b64encode(secrets.token_bytes(32)).decode()
nodes = dict(
map(
_create_mesh_node,
node_names,
itertools.repeat(ip_pool),
itertools.repeat(psk)
)
)
_connect_nodes(nodes)
return nodes
# TODO this should take a graph / adjacency list as argument and set peers based on that.
# currently all nodes are connected to all other nodes, which scales with O(n!) and
# takes significant amounts of time and compute for even small n < 10.
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
"""
This function implements the actual entanglement of nodes and thus decides for implementing
"who is able to talk to whom" rules within the wireguard mesh network.
:param nodes: dict with names of instances as keys and their MeshNodeConfig objects as corresponding values
:return: nothing
"""
for src, dst in itertools.permutations(nodes.keys(), 2):
nodes[dst].peers.append(nodes[src])
def _create_mesh_node(
node_name: str,
ip_pool: Iterator[IPv4Address],
psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
"""
Create a single mesh node config
:param node_name: name of the node
:param ip_pool: iterator over IPv4Address objects to draw an IPv4Address for this node from
:param psk: preshared key as string. Optional.
:return: a tuple with name of the new node as first and the MeshNodeConfig as second element
"""
return node_name, MeshNodeConfig(
name=node_name,
peers=[],
ip=str(next(ip_pool)),
psk=psk
)