"""Generate WireGuard-style configuration files for a full mesh of nodes."""
import base64
import itertools
import os
import secrets
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv4Address
from pathlib import Path
from typing import Iterator, Dict, Tuple, List, Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey


def _spawn_vm():
    """Placeholder; currently does nothing."""
    pass


def _generate_private_key() -> str:
    """Return a freshly generated X25519 private key, base64-encoded."""
    raw = X25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return base64.b64encode(raw).decode()


def _derive_public_key(private_key: str) -> str:
    """Return the base64 public key matching a base64 X25519 private key."""
    raw = X25519PrivateKey.from_private_bytes(
        base64.b64decode(private_key.encode())
    ).public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw,
    )
    return base64.b64encode(raw).decode()


@dataclass
class MeshNodeConfig:
    """Settings of one mesh node plus the peers it should know about.

    Attributes:
        name: Node identifier; ``main`` uses it as the config file stem.
        ip: Address of the node inside the mesh subnet.
        peers: Nodes rendered as ``[Peer]`` sections of this node's config.
        private_key: Base64 X25519 private key; generated per instance
            when not supplied.
        public_key: Base64 X25519 public key; derived from ``private_key``
            when left empty.
        psk: Optional pre-shared key emitted in the ``[Peer]`` section that
            other nodes render for this node.
        endpoint: Optional endpoint emitted in the ``[Peer]`` section that
            other nodes render for this node.
        listen_port: Value written to ``ListenPort``.
    """

    name: str
    ip: str
    # A fresh list per instance: _connect_nodes appends to it.
    peers: List['MeshNodeConfig'] = field(default_factory=list)
    # default_factory runs per instance; a plain default expression would be
    # evaluated once in the class body and shared by every node.
    private_key: str = field(default_factory=_generate_private_key)
    public_key: str = ''
    psk: Optional[str] = None
    endpoint: Optional[str] = None
    listen_port: int = 51515

    def __post_init__(self) -> None:
        """Derive the public key from the private key unless one was given."""
        if not self.public_key:
            self.public_key = _derive_public_key(self.private_key)

    @property
    def config(self) -> str:
        """Render the node as ``[Interface]``/``[Peer]`` config text."""
        lines = [
            '[Interface]',
            f'Address = {self.ip}',
            f'PrivateKey = {self.private_key}',
            f'ListenPort = {self.listen_port}',
            '',
        ]
        for peer in self.peers:
            lines.append('[Peer]')
            # Optional keys use the same "Key = Value" form as all others.
            if peer.endpoint is not None:
                lines.append(f'Endpoint = {peer.endpoint}')
            if peer.psk is not None:
                lines.append(f'PresharedKey = {peer.psk}')
            lines += [
                f'PublicKey = {peer.public_key}',
                f'AllowedIPs = {peer.ip}/32',
                'PersistentKeepalive = 25',
                '',
            ]
        # Every section ends with a blank line, hence the trailing newline.
        return '\n'.join(lines) + '\n'


def _create_node(
        node_name: str,
        ip_pool: Iterator[IPv4Address],
        psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
    """Create a node that takes the next free address from ``ip_pool``.

    Args:
        node_name: Name given to the new node.
        ip_pool: Iterator of unused addresses; one item is consumed.
        psk: Optional pre-shared key stored on the node.

    Returns:
        A ``(node_name, node)`` pair, convenient for building a dict.

    Raises:
        ValueError: If ``ip_pool`` has no address left.
    """
    try:
        address = next(ip_pool)
    except StopIteration:
        # A leaking StopIteration would be mistaken for normal exhaustion by
        # an enclosing map()/dict() and silently truncate the mesh.
        raise ValueError(
            f'no free IP address left for node {node_name!r}'
        ) from None
    return node_name, MeshNodeConfig(
        name=node_name,
        peers=[],
        ip=str(address),
        psk=psk
    )


def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
    """Append every node to the ``peers`` list of every other node."""
    for src, dst in itertools.permutations(nodes.values(), 2):
        dst.peers.append(src)


def _create_mesh(
        node_count: int = 5,
        ip_subnet: str = '10.0.0.0/24'
) -> Dict[str, MeshNodeConfig]:
    """Build a fully connected mesh of ``node_count`` nodes.

    Nodes are named ``node-1`` .. ``node-N``, take consecutive host
    addresses from ``ip_subnet`` and all carry the same random pre-shared
    key.

    Args:
        node_count: Number of nodes to create.
        ip_subnet: IPv4 network, in CIDR notation, to allocate from.

    Returns:
        Mapping of node name to its configuration.

    Raises:
        ValueError: If ``ip_subnet`` has fewer hosts than ``node_count``.
    """
    # hosts() may return a list rather than an iterator (e.g. for a /32),
    # so normalise it before calling next() on it.
    ip_pool = iter(IPv4Network(ip_subnet).hosts())
    psk = base64.b64encode(secrets.token_bytes(32)).decode()
    nodes = dict(
        _create_node(f'node-{idx}', ip_pool, psk)
        for idx in range(1, node_count + 1)
    )
    _connect_nodes(nodes)
    return nodes


def main(output_dir: str = '/tmp/node_configs') -> None:
    """Create the default mesh and write one ``<name>.conf`` per node.

    Args:
        output_dir: Directory to (re)create; any files already directly
            inside it are deleted first.
    """
    nodes = _create_mesh()
    # NOTE: the dataclass repr printed here includes each node's private key.
    print(nodes)

    target = Path(output_dir)
    if target.exists():
        for entry in target.iterdir():
            entry.unlink()
        target.rmdir()
    os.makedirs(target)

    for node in nodes.values():
        conf_path = target / f'{node.name}.conf'
        # The file holds a private key: create it owner-only before writing.
        conf_path.touch(mode=0o600)
        conf_path.write_text(node.config)


if __name__ == '__main__':
    main()