diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go index 51c1a36d2..e9f335bf3 100644 --- a/tsconsensus/tsconsensus.go +++ b/tsconsensus/tsconsensus.go @@ -89,6 +89,13 @@ type StreamLayer struct { // Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial. func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + allowed, err := allowedPeer(string(address), sl.tag, sl.s) + if err != nil { + return nil, err + } + if !allowed { + return nil, errors.New("peer is not allowed") + } ctx, _ := context.WithTimeout(context.Background(), timeout) return sl.s.Dial(ctx, "tcp", string(address)) } diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index 83b6c62ae..1011a2648 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -12,7 +12,6 @@ import ( "net/netip" "os" "path/filepath" - "slices" "strings" "testing" "time" @@ -140,19 +139,41 @@ func waitForNodesToBeTaggedInStatus(t *testing.T, ctx context.Context, ts *tsnet } else { tags = status.Peer[k].Tags } - if tags == nil || !slices.Contains(tags.AsSlice(), tag) { - return false + if tag == "" { + if tags != nil && tags.Len() != 0 { + return false + } + } else { + if tags == nil { + return false + } + sliceTags := tags.AsSlice() + if len(sliceTags) != 1 || sliceTags[0] != tag { + return false + } } } return true - }, 5, 1*time.Second) + }, 10, 2*time.Second) } func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) { t.Helper() for _, key := range nodeKeys { n := control.Node(key) - n.Tags = append(n.Tags, tag) + if tag == "" { + if len(n.Tags) != 1 { + t.Fatalf("expected tags to have one tag") + } + n.Tags = nil + } else { + if len(n.Tags) != 0 { + // if we want this to work with multiple tags we'll have to change the logic + // for checking if a tag got removed yet. + t.Fatalf("expected tags to be empty") + } + n.Tags = append(n.Tags, tag) + } b := true n.Online = &b control.UpdateNode(n) @@ -520,3 +541,65 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { t.Fatalf("tagged node trying to send should not time out, got: %v", err) } } + +func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { + t.Skip("flaky test, need to figure out how to actually cause a Dial if we want to test this") + nettest.SkipIfNoNetwork(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + clusterTag := "tag:whatever" + ps, control, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := DefaultConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + assertCommandsWorkOnAnyNode(t, ps) + + tagNodes(t, control, []key.NodePublic{ps[2].key}, "") + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{ps[2].key}, "") + + // now when we try to communicate there's an open conn we can talk over still, but + // we won't dial a fresh one + // get Raft to redial by removing and readding + // TODO although this doesn't actually cause redialing apparently, at least not for the command rpc stuff. + fut := ps[0].c.Raft.RemoveServer(raft.ServerID(ps[2].c.Self.ID), 0, 5*time.Second) + err := fut.Error() + if err != nil { + t.Fatal(err) + } + + fut = ps[0].c.Raft.AddVoter(raft.ServerID(ps[2].c.Self.ID), raft.ServerAddress(raftAddr(ps[2].c.Self.Host, cfg)), 0, 5*time.Second) + err = fut.Error() + if err != nil { + t.Fatal(err) + } + + // ps[2] doesn't get updates any more + res, err := ps[0].c.ExecuteCommand(Command{Args: []byte{byte(1)}}) + if err != nil { + t.Fatalf("Error ExecuteCommand: %v", err) + } + if res.Err != nil { + t.Fatalf("Result Error ExecuteCommand: %v", res.Err) + } + + fxOneEventSent := func() bool { + fmt.Println(len(ps[0].sm.events), len(ps[1].sm.events), len(ps[2].sm.events)) + return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(ps[2].sm.events) == 3 + } + waitFor(t, "after untagging first and second node get events, but third does not", fxOneEventSent, 10, time.Second*1) + + res, err = ps[1].c.ExecuteCommand(Command{Args: []byte{byte(1)}}) + if err != nil { + t.Fatalf("Error ExecuteCommand: %v", err) + } + if res.Err != nil { + t.Fatalf("Result Error ExecuteCommand: %v", res.Err) + } + + fxTwoEventsSent := func() bool { + return len(ps[0].sm.events) == 5 && len(ps[1].sm.events) == 5 && len(ps[2].sm.events) == 3 + } + waitFor(t, "after untagging first and second node get events, but third does not", fxTwoEventsSent, 10, time.Second*1) +}