# wg-mesh-test/main.py
# 2021-08-31 23:17:54 +02:00
#
# 122 lines
# 3.2 KiB
# Python

import base64
import itertools
import os
import secrets
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv4Address
from pathlib import Path
from typing import Iterator, Dict, Tuple, List, Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
def _spawn_vm():
pass
@dataclass
class MeshNodeConfig:
    """Settings for one mesh node, renderable as a WireGuard config file.

    Attributes are documented inline below.  When ``private_key`` is omitted
    a fresh X25519 key is generated for this instance; when ``public_key`` is
    omitted it is derived from ``private_key``.  Both keys are stored as
    base64 text.
    """

    # Node name (``main`` uses it as the config file stem).
    name: str
    # Interface address of this node, without a prefix length.
    ip: str
    # Nodes this one should list as [Peer] sections.  A default_factory gives
    # every instance its own list (the previous ``()`` default could not be
    # appended to).  Excluded from repr because peers reference each other
    # cyclically, which makes the generated repr explode on a full mesh.
    peers: List['MeshNodeConfig'] = field(default_factory=list, repr=False)
    # Base64 X25519 private key; None means "generate one in __post_init__".
    private_key: Optional[str] = None
    # Base64 X25519 public key; None means "derive it from private_key".
    public_key: Optional[str] = None
    # Optional pre-shared key emitted when this node appears as a peer.
    psk: Optional[str] = None
    # Optional endpoint emitted when this node appears as a peer.
    endpoint: Optional[str] = None
    # UDP port written to the [Interface] section.
    listen_port: int = 51515

    def __post_init__(self) -> None:
        """Fill in key material that the caller did not supply.

        Doing this per instance (instead of in class-level defaults, which
        are evaluated once at class creation) ensures that nodes never share
        a key pair and that the public key always matches the private key.
        """
        if self.private_key is None:
            self.private_key = self._generate_private_key()
        if self.public_key is None:
            self.public_key = self._derive_public_key(self.private_key)

    @staticmethod
    def _generate_private_key() -> str:
        """Return a newly generated X25519 private key as base64 text."""
        raw = X25519PrivateKey.generate().private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return base64.b64encode(raw).decode()

    @staticmethod
    def _derive_public_key(private_key: str) -> str:
        """Return the base64 public key matching a base64 ``private_key``."""
        key = X25519PrivateKey.from_private_bytes(
            base64.b64decode(private_key.encode())
        )
        raw = key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        return base64.b64encode(raw).decode()

    @property
    def config(self) -> str:
        """Render this node and its peers as WireGuard configuration text.

        Returns:
            An ``[Interface]`` section followed by one ``[Peer]`` section per
            entry in ``peers``; every section ends with a blank line.
        """
        lines = [
            '[Interface]',
            f'Address = {self.ip}',
            f'PrivateKey = {self.private_key}',
            f'ListenPort = {self.listen_port}',
            '',
        ]
        for peer in self.peers:
            lines.append('[Peer]')
            # Every setting uses "Key = Value"; the earlier "Key: value"
            # spelling for these two optional entries was not valid syntax.
            if peer.endpoint is not None:
                lines.append(f'Endpoint = {peer.endpoint}')
            if peer.psk is not None:
                lines.append(f'PresharedKey = {peer.psk}')
            lines += [
                f'PublicKey = {peer.public_key}',
                f'AllowedIPs = {peer.ip}/32',
                'PersistentKeepalive = 25',
                '',
            ]
        return '\n'.join(lines) + '\n'
def _create_node(
    node_name: str,
    ip_pool: Iterator[IPv4Address],
    psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
    """Build one peer-less node, taking its address from ``ip_pool``.

    Args:
        node_name: Name given to the node; also returned as the first item.
        ip_pool: Iterator of addresses; exactly one is consumed per call.
        psk: Optional pre-shared key stored on the node.

    Returns:
        A ``(node_name, node)`` pair, convenient for feeding ``dict()``.
    """
    address = next(ip_pool)
    node = MeshNodeConfig(
        name=node_name,
        ip=str(address),
        peers=[],
        psk=psk,
    )
    return node_name, node
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
for src, dst in itertools.permutations(nodes.keys(), 2):
nodes[dst].peers.append(nodes[src])
def _create_mesh(node_count=5, ip_subnet: str = '10.0.0.0/24') -> Dict[str, MeshNodeConfig]:
ip_pool = IPv4Network(ip_subnet).hosts()
node_names = (
f'node-{idx}' for idx in range(1, node_count + 1)
)
psk = base64.b64encode(secrets.token_bytes(32)).decode()
nodes = dict(
map(
_create_node,
node_names,
itertools.repeat(ip_pool),
itertools.repeat(psk)
)
)
_connect_nodes(nodes)
return nodes
def main():
    """Generate a mesh and write one ``<node name>.conf`` file per node.

    The output directory ``/tmp/node_configs`` is emptied and recreated on
    every run.  Because each file contains a private key, the directory and
    the files are created accessible to the owner only.
    """
    nodes = _create_mesh()
    print(nodes)

    tmp_dir = '/tmp/node_configs'
    out_dir = Path(tmp_dir)
    if out_dir.exists():
        # Drop configs left over from a previous run.
        for stale in out_dir.iterdir():
            stale.unlink()
        out_dir.rmdir()
    os.makedirs(tmp_dir, mode=0o700)

    for node in nodes.values():
        conf_path = out_dir / f'{node.name}.conf'
        # Set the mode at creation time so the private key is never readable
        # by other users, not even briefly.
        fd = os.open(conf_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as fh:
            fh.write(node.config)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()