import base64
import itertools
import os
import secrets
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv4Address
from pathlib import Path
from typing import Iterator, Dict, Tuple, List, Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey


def _spawn_vm():
    """Placeholder for VM provisioning; currently does nothing."""
    pass


def _generate_private_key() -> str:
    """Return a freshly generated X25519 private key as base64 of its raw bytes."""
    raw = X25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return base64.b64encode(raw).decode()


def _derive_public_key(private_key: str) -> str:
    """Return the base64 raw X25519 public key matching a base64 raw private key."""
    raw = X25519PrivateKey.from_private_bytes(
        base64.b64decode(private_key.encode())
    ).public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw,
    )
    return base64.b64encode(raw).decode()


@dataclass
class MeshNodeConfig:
    """WireGuard configuration of a single mesh node.

    Instances are rendered to a wg-quick style config file via ``config``.
    Each instance gets its own freshly generated private key unless one is
    passed in; ``public_key`` is derived from the private key when omitted.
    """

    name: str
    ip: str
    # Peers form a cyclic graph (nodes list each other), so they are kept out
    # of repr/eq: the recursive dataclass repr expands every path through the
    # graph (output grows roughly factorially with the node count), and
    # comparing two distinct meshes would recurse until RecursionError.
    peers: List['MeshNodeConfig'] = field(
        default_factory=list, repr=False, compare=False
    )
    # default_factory so every node gets its own key; a plain default would be
    # evaluated once at class-definition time and shared by all nodes.
    # Hidden from repr because it is a secret.
    private_key: str = field(default_factory=_generate_private_key, repr=False)
    # None means "derive from private_key" (done in __post_init__).
    public_key: Optional[str] = None
    # Pre-shared key; secret, so hidden from repr.
    psk: Optional[str] = field(default=None, repr=False)
    endpoint: Optional[str] = None
    listen_port: int = 51515

    def __post_init__(self) -> None:
        # Derive the public key from whichever private key this instance has,
        # so an explicitly supplied private_key gets its matching public key.
        if self.public_key is None:
            self.public_key = _derive_public_key(self.private_key)

    @property
    def config(self) -> str:
        """Render this node's WireGuard config: one [Interface] plus one [Peer] per peer."""
        lines = [
            '[Interface]',
            f'Address = {self.ip}',
            f'PrivateKey = {self.private_key}',
            f'ListenPort = {self.listen_port}',
            '',
        ]
        for peer in self.peers:
            lines.append('[Peer]')
            # WireGuard config lines use "Key = Value" (not "Key: Value").
            if peer.endpoint is not None:
                lines.append(f'Endpoint = {peer.endpoint}')
            # Both ends of a peering must use the same PresharedKey; this
            # relies on peer.psk being the value both nodes share.
            if peer.psk is not None:
                lines.append(f'PresharedKey = {peer.psk}')
            lines.extend([
                f'PublicKey = {peer.public_key}',
                f'AllowedIPs = {peer.ip}/32',
                'PersistentKeepalive = 25',
                '',
            ])
        return '\n'.join(lines) + '\n'


def _create_node(
        node_name: str, ip_pool: Iterator[IPv4Address], psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
    """Create a peer-less node using the next address drawn from ``ip_pool``.

    Returns:
        A ``(node_name, config)`` pair, convenient for building a dict.

    Raises:
        ValueError: if ``ip_pool`` is exhausted. A bare StopIteration would be
            silently treated as end-of-iteration by callers such as
            ``dict(map(...))``, truncating the mesh without any error.
    """
    address = next(ip_pool, None)
    if address is None:
        raise ValueError(f'IP pool exhausted while creating {node_name!r}')
    return node_name, MeshNodeConfig(
        name=node_name, peers=[], ip=str(address), psk=psk
    )


# TODO this should take a graph / adjacency list as argument and set peers based on that.
# Currently every node is connected to every other node, which yields n*(n-1)
# peer entries (O(n^2)) and per-node configs that grow linearly with n.
def _connect_nodes(nodes: Dict[str, 'MeshNodeConfig']) -> None:
    """Peer every node with every other node (full mesh), in place.

    Each node's ``peers`` list is extended with all other nodes, in the
    dict's iteration order; n nodes produce n*(n-1) peer entries in total.
    """
    for src, dst in itertools.permutations(nodes.keys(), 2):
        nodes[dst].peers.append(nodes[src])


def _create_mesh(node_count=5, ip_subnet: str = '10.0.0.0/24') -> Dict[str, 'MeshNodeConfig']:
    """Build a fully connected mesh of ``node_count`` nodes.

    Nodes are named ``node-1`` .. ``node-<node_count>`` and receive
    consecutive host addresses from ``ip_subnet``. All nodes share a single
    randomly generated 32-byte pre-shared key (base64-encoded).

    Raises:
        ValueError: if ``ip_subnet`` is not a valid IPv4 network, or has fewer
            usable host addresses than ``node_count``.
    """
    wanted = max(node_count, 0)
    # Reserve all addresses up front. Running the pool dry inside the node
    # construction would otherwise surface as StopIteration, which dict()
    # treats as "iterator finished" and silently returns a smaller mesh.
    addresses = list(itertools.islice(IPv4Network(ip_subnet).hosts(), wanted))
    if len(addresses) < wanted:
        raise ValueError(
            f'subnet {ip_subnet} has only {len(addresses)} usable host '
            f'addresses, but {node_count} nodes were requested'
        )
    ip_pool = iter(addresses)

    psk = base64.b64encode(secrets.token_bytes(32)).decode()
    nodes = dict(
        _create_node(f'node-{idx}', ip_pool, psk)
        for idx in range(1, wanted + 1)
    )
    _connect_nodes(nodes)
    return nodes


def main(out_dir: str = '/tmp/node_configs') -> None:
    """Generate a mesh and write one ``<node name>.conf`` per node into ``out_dir``.

    ``out_dir`` is removed (recursively) and recreated on every run. The
    directory is requested with mode 0o700 and each file with mode 0o600,
    because the configs contain private keys and the pre-shared key.
    """
    import shutil

    nodes = _create_mesh()

    out_path = Path(out_dir)
    if out_path.exists():
        shutil.rmtree(out_path)
    out_path.mkdir(mode=0o700, parents=True)

    for node in nodes.values():
        conf_path = out_path / f'{node.name}.conf'
        # os.open lets us set restrictive permissions at creation time instead
        # of creating a world-readable file and chmod-ing it afterwards.
        fd = os.open(conf_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as fh:
            fh.write(node.config)
        # Print a short summary rather than the node objects, which hold
        # secrets and reference each other.
        print(f'{node.name} ({node.ip}) -> {conf_path}')


if __name__ == '__main__':
    main()