wg-mesh-test/main.py
2021-08-31 23:24:31 +02:00

125 lines
3.4 KiB
Python

import base64
import itertools
import os
import secrets
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv4Address
from pathlib import Path
from typing import Iterator, Dict, Tuple, List, Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey


def _spawn_vm():
    """Placeholder stub: currently does nothing and returns ``None``."""


def _generate_private_key() -> str:
    """Return a freshly generated X25519 private key, base64-encoded."""
    raw = X25519PrivateKey.generate().private_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return base64.b64encode(raw).decode()


def _derive_public_key(private_key: str) -> str:
    """Return the base64-encoded X25519 public key matching *private_key*."""
    raw = X25519PrivateKey.from_private_bytes(
        base64.b64decode(private_key.encode())
    ).public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw,
    )
    return base64.b64encode(raw).decode()


@dataclass
class MeshNodeConfig:
    """WireGuard configuration of a single mesh node.

    Every instance gets its own key pair: ``private_key`` is generated per
    instance when omitted, and ``public_key`` is derived from the instance's
    ``private_key`` when omitted. (Computing these as plain class-level
    defaults would evaluate them once and share one key pair between all
    nodes.)
    """

    name: str
    ip: str
    # Fresh list per instance. Excluded from repr because peers reference each
    # other; a repr that follows those cycles walks every simple path through
    # the mesh and grows factorially with the node count.
    peers: List['MeshNodeConfig'] = field(default_factory=list, repr=False)
    private_key: str = field(default_factory=_generate_private_key)
    # Derived from ``private_key`` in ``__post_init__`` when not supplied.
    public_key: Optional[str] = None
    psk: Optional[str] = None
    endpoint: Optional[str] = None
    listen_port: int = 51515

    def __post_init__(self) -> None:
        """Derive the public key from the private key unless one was given."""
        if self.public_key is None:
            self.public_key = _derive_public_key(self.private_key)

    @property
    def config(self) -> str:
        """Render this node's configuration as ``Key = Value`` INI text.

        The result holds one ``[Interface]`` section for this node followed by
        one ``[Peer]`` section per entry in ``peers``. ``Endpoint`` and
        ``PresharedKey`` are emitted only for peers that define them. Every
        section is terminated by a blank line.
        """
        lines = [
            '[Interface]',
            f'Address = {self.ip}',
            f'PrivateKey = {self.private_key}',
            f'ListenPort = {self.listen_port}',
            '',
        ]
        for peer in self.peers:
            lines.append('[Peer]')
            if peer.endpoint is not None:
                lines.append(f'Endpoint = {peer.endpoint}')
            if peer.psk is not None:
                lines.append(f'PresharedKey = {peer.psk}')
            lines.extend([
                f'PublicKey = {peer.public_key}',
                f'AllowedIPs = {peer.ip}/32',
                'PersistentKeepalive = 25',
                '',
            ])
        # The trailing '' entry plus this newline reproduces the blank line
        # that closes each section.
        return '\n'.join(lines) + '\n'


def _create_node(
        node_name: str,
        ip_pool: Iterator[IPv4Address],
        psk: Optional[str] = None
) -> Tuple[str, MeshNodeConfig]:
    """Create a node that uses the next free address from *ip_pool*.

    Args:
        node_name: Name given to the node; also returned as the first tuple item.
        ip_pool: Iterator of addresses; exactly one address is consumed.
        psk: Optional pre-shared key stored on the node.

    Returns:
        A ``(node_name, MeshNodeConfig)`` pair, suitable for building a dict.

    Raises:
        ValueError: If *ip_pool* has no addresses left.
    """
    # A StopIteration escaping from here would look like normal exhaustion to
    # an enclosing map()/generator and silently truncate the result, so it is
    # turned into an explicit error.
    try:
        address = next(ip_pool)
    except StopIteration:
        raise ValueError(
            f'IP pool exhausted: no address left for node {node_name!r}'
        ) from None
    return node_name, MeshNodeConfig(
        name=node_name,
        peers=[],
        ip=str(address),
        psk=psk,
    )


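# TODO: this should take a graph / adjacency list as an argument and set peers based on that.
# Currently all nodes are connected to all other nodes. Creating the links is only
# n * (n - 1) appends, but anything that recursively walks the resulting cyclic peer
# graph (such as a repr that includes peers) visits every simple path, which scales
# with O(n!) and takes significant amounts of time and compute for even small n < 10.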
def _connect_nodes(nodes: Dict[str, MeshNodeConfig]) -> None:
    """Wire *nodes* into a full mesh, in place.

    Every node is appended to the ``peers`` list of every other node, in
    dict order; a node is never added as its own peer.
    """
    members = list(nodes.values())
    for src_pos, source in enumerate(members):
        for dst_pos, target in enumerate(members):
            if dst_pos != src_pos:
                target.peers.append(source)


def _create_mesh(node_count=5, ip_subnet: str = '10.0.0.0/24') -> Dict[str, MeshNodeConfig]:
    """Create *node_count* fully connected nodes with addresses from *ip_subnet*.

    Nodes are named ``node-1`` .. ``node-<node_count>``, take consecutive host
    addresses from the subnet, and all share one randomly generated 32-byte
    pre-shared key (base64-encoded).

    Args:
        node_count: Number of nodes to create; values <= 0 yield an empty mesh.
        ip_subnet: IPv4 network in CIDR notation to allocate addresses from.

    Returns:
        Mapping of node name to its configuration, with peers already linked.

    Raises:
        ValueError: If *ip_subnet* is not a valid network or does not contain
            enough host addresses for *node_count* nodes.
    """
    wanted = max(node_count, 0)
    # Reserve the addresses up front so a too-small subnet is reported as an
    # error instead of silently producing fewer nodes than requested.
    addresses = list(itertools.islice(IPv4Network(ip_subnet).hosts(), wanted))
    if len(addresses) < wanted:
        raise ValueError(
            f'subnet {ip_subnet} provides only {len(addresses)} host '
            f'address(es), but {wanted} node(s) were requested'
        )
    ip_pool = iter(addresses)
    psk = base64.b64encode(secrets.token_bytes(32)).decode()

    nodes: Dict[str, MeshNodeConfig] = {}
    for idx in range(1, wanted + 1):
        name, node = _create_node(f'node-{idx}', ip_pool, psk)
        nodes[name] = node
    _connect_nodes(nodes)
    return nodes


def main():
    """Generate a default mesh and write one config file per node.

    Files are written to ``/tmp/node_configs/<node name>.conf``. Any existing
    directory at that path is removed first so stale configs never survive.
    """
    import shutil  # local import: only needed for the cleanup below

    nodes = _create_mesh()

    out_dir = Path('/tmp/node_configs')
    if out_dir.exists():
        # rmtree also copes with nested directories, which a plain
        # unlink-each-entry-then-rmdir sequence does not.
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, mode=0o700)

    for node in nodes.values():
        conf_path = out_dir / f'{node.name}.conf'
        # The config contains the node's private key (and possibly a PSK), so
        # the file is created owner-read/write only from the start.
        fd = os.open(conf_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as fh:
            fh.write(node.config)
        # Print a short summary rather than the node objects themselves, which
        # would dump key material to stdout.
        print(f'{node.name} ({node.ip}): {conf_path}')


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()