diff --git a/internal/app/machined/pkg/controllers/network/link_merge_test.go b/internal/app/machined/pkg/controllers/network/link_merge_test.go index c87e8c716..06e1ab6c1 100644 --- a/internal/app/machined/pkg/controllers/network/link_merge_test.go +++ b/internal/app/machined/pkg/controllers/network/link_merge_test.go @@ -274,6 +274,79 @@ func (suite *LinkMergeSuite) TestMergeFlapping() { })) } +func (suite *LinkMergeSuite) TestMergeWireguard() { + static := network.NewLinkSpec(network.ConfigNamespaceName, "configuration/wglan0") + *static.TypedSpec() = network.LinkSpecSpec{ + Name: "wglan0", + Wireguard: network.WireguardSpec{ + ListenPort: 1234, + Peers: []network.WireguardPeer{ + { + PublicKey: "bGsc2rOpl6JHd/Pm4fYrIkEABL0ZxW7IlaSyh77IMhw=", + Endpoint: "127.0.0.1:9999", + }, + }, + }, + ConfigLayer: network.ConfigMachineConfiguration, + } + + wglanOperator := network.NewLinkSpec(network.ConfigNamespaceName, "wglan/wglan0") + *wglanOperator.TypedSpec() = network.LinkSpecSpec{ + Name: "wglan0", + Wireguard: network.WireguardSpec{ + PrivateKey: "IG9MqCII7z54Ysof1fQ9a7WcMNG+qNJRMyRCQz3JTUY=", + ListenPort: 3456, + Peers: []network.WireguardPeer{ + { + PublicKey: "RXdQkMTD1Jcxd/Wizr9k8syw8ANs57l5jTormDVHAVs=", + Endpoint: "127.0.0.1:1234", + }, + }, + }, + ConfigLayer: network.ConfigOperator, + } + + for _, res := range []resource.Resource{static, wglanOperator} { + suite.Require().NoError(suite.state.Create(suite.ctx, res), "%v", res.Spec()) + } + + suite.Assert().NoError(retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( + func() error { + return suite.assertLinks([]string{ + "wglan0", + }, func(r *network.LinkSpec) error { + suite.Assert().Equal("IG9MqCII7z54Ysof1fQ9a7WcMNG+qNJRMyRCQz3JTUY=", r.TypedSpec().Wireguard.PrivateKey) + suite.Assert().Equal(1234, r.TypedSpec().Wireguard.ListenPort) + suite.Assert().Len(r.TypedSpec().Wireguard.Peers, 2) + + suite.Assert().Equal( + network.WireguardPeer{ + PublicKey: "RXdQkMTD1Jcxd/Wizr9k8syw8ANs57l5jTormDVHAVs=", + Endpoint: "127.0.0.1:1234", + }, + r.TypedSpec().Wireguard.Peers[0], + ) + + suite.Assert().Equal( + network.WireguardPeer{ + PublicKey: "bGsc2rOpl6JHd/Pm4fYrIkEABL0ZxW7IlaSyh77IMhw=", + Endpoint: "127.0.0.1:9999", + }, + r.TypedSpec().Wireguard.Peers[1], + ) + + return nil + }) + })) + + suite.Require().NoError(suite.state.Destroy(suite.ctx, wglanOperator.Metadata())) + + suite.Assert().NoError(retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( + func() error { + return suite.assertNoLinks("wglan0") + })) +} + func (suite *LinkMergeSuite) TearDownTest() { suite.T().Log("tear down") diff --git a/internal/app/machined/pkg/controllers/network/link_spec.go b/internal/app/machined/pkg/controllers/network/link_spec.go index 6b5a85efb..ac7ac4519 100644 --- a/internal/app/machined/pkg/controllers/network/link_spec.go +++ b/internal/app/machined/pkg/controllers/network/link_spec.go @@ -379,6 +379,7 @@ func (ctrl *LinkSpecController) syncLink(ctx context.Context, r controller.Runti link.TypedSpec().Wireguard.Sort() + // order here is important: we allow listenPort to be zero in the configuration if !existingSpec.Equal(&link.TypedSpec().Wireguard) { config, err := link.TypedSpec().Wireguard.Encode(&existingSpec) if err != nil { @@ -389,7 +390,7 @@ func (ctrl *LinkSpecController) syncLink(ctx context.Context, r controller.Runti return fmt.Errorf("error configuring wireguard device %q: %w", link.TypedSpec().Name, err) } - logger.Info("reconfigured wireguard link") + logger.Info("reconfigured wireguard link", zap.Int("peers", len(link.TypedSpec().Wireguard.Peers))) // notify link status controller, as wireguard updates can't be watched via netlink API if err = r.Modify(ctx, network.NewLinkRefresh(network.NamespaceName, network.LinkKindWireguard), func(r resource.Resource) error { diff --git a/pkg/resources/network/link.go b/pkg/resources/network/link.go index 1a0de8770..e0b479bf6 100644 --- a/pkg/resources/network/link.go +++ b/pkg/resources/network/link.go @@ -312,7 +312,8 @@ func (spec *WireguardSpec) Equal(other *WireguardSpec) bool { return false } - if spec.ListenPort != other.ListenPort { + // listenPort of '0' means use any available port, so we shouldn't consider this to be a "change" + if spec.ListenPort != other.ListenPort && other.ListenPort != 0 { return false } @@ -501,3 +502,35 @@ func (spec *WireguardSpec) Decode(dev *wgtypes.Device) { } } } + +// Merge with other Wireguard spec overwriting non-zero values. +func (spec *WireguardSpec) Merge(other WireguardSpec) { + if other.ListenPort != 0 { + spec.ListenPort = other.ListenPort + } + + if other.FirewallMark != 0 { + spec.FirewallMark = other.FirewallMark + } + + if other.PrivateKey != "" { + spec.PrivateKey = other.PrivateKey + } + + // avoid adding same peer twice, no real peer information merging for now + for _, peer := range other.Peers { + exists := false + + for _, p := range spec.Peers { + if p.PublicKey == peer.PublicKey { + exists = true + + break + } + } + + if !exists { + spec.Peers = append(spec.Peers, peer) + } + } +} diff --git a/pkg/resources/network/link_spec.go b/pkg/resources/network/link_spec.go index 331fb13a8..78487dd2a 100644 --- a/pkg/resources/network/link_spec.go +++ b/pkg/resources/network/link_spec.go @@ -83,7 +83,7 @@ func (spec *LinkSpecSpec) Merge(other *LinkSpecSpec) error { } if other.Type != 0 { - spec.Type = 0 + spec.Type = other.Type } if other.ParentName != "" { @@ -102,8 +102,14 @@ func (spec *LinkSpecSpec) Merge(other *LinkSpecSpec) error { spec.BondMaster = other.BondMaster } + // Wireguard config should be able to apply non-zero values in earlier config layers which may be zero values in later layers. + // Thus, we handle each Wireguard configuration value discretely. if !other.Wireguard.IsZero() { - spec.Wireguard = other.Wireguard + if spec.Wireguard.IsZero() { + spec.Wireguard = other.Wireguard + } else { + spec.Wireguard.Merge(other.Wireguard) + } } spec.ConfigLayer = other.ConfigLayer diff --git a/pkg/resources/network/link_test.go b/pkg/resources/network/link_test.go index 26f413716..dd1cdcba2 100644 --- a/pkg/resources/network/link_test.go +++ b/pkg/resources/network/link_test.go @@ -166,6 +166,15 @@ func TestWireguardSpecDecode(t *testing.T) { assert.Equal(t, expected, spec) assert.True(t, expected.Equal(&spec)) + // zeroed out listen port is still acceptable on the right side + spec.ListenPort = 0 + assert.True(t, expected.Equal(&spec)) + + // ... but not on the left side + expected.ListenPort = 0 + spec.ListenPort = 30000 + assert.False(t, expected.Equal(&spec)) + var zeroSpec network.WireguardSpec assert.False(t, zeroSpec.Equal(&spec)) @@ -318,3 +327,157 @@ func TestWireguardSpecEncode(t *testing.T) { }, }, delta) } + +func TestWireguardSpecMerge(t *testing.T) { + priv, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + pub1, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + pub2, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + for _, tt := range []struct { + name string + spec network.WireguardSpec + other network.WireguardSpec + + expected network.WireguardSpec + }{ + { + name: "zero", + }, + { + name: "speczero", + other: network.WireguardSpec{ + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + + expected: network.WireguardSpec{ + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + }, + { + name: "otherzero", + spec: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + }, + }, + }, + + expected: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + }, + }, + }, + }, + { + name: "mixed", + spec: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + }, + }, + }, + other: network.WireguardSpec{ + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + + expected: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + }, + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + }, + { + name: "peerconflict", + spec: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + PersistentKeepaliveInterval: time.Second, + }, + }, + }, + other: network.WireguardSpec{ + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + Endpoint: "127.0.0.1:466", + }, + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + + expected: network.WireguardSpec{ + PrivateKey: priv.String(), + FirewallMark: 34, + ListenPort: 456, + Peers: []network.WireguardPeer{ + { + PublicKey: pub1.String(), + PersistentKeepaliveInterval: time.Second, + }, + { + PublicKey: pub2.String(), + Endpoint: "127.0.0.1:3445", + }, + }, + }, + }, + } { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + spec := tt.spec + spec.Merge(tt.other) + + assert.Equal(t, tt.expected, spec) + }) + } +}