From 14dd2c92971e5bfc16d45fe46a1be36601081766 Mon Sep 17 00:00:00 2001
From: Fran Bull
Date: Mon, 13 Jan 2025 13:29:41 -0800
Subject: [PATCH] wip

---
 tsconsensus/http.go             | 127 +++++++++
 tsconsensus/monitor.go          | 140 ++++++++++
 tsconsensus/tsconsensus.go      | 363 +++++++++++++++++++++++++
 tsconsensus/tsconsensus_test.go | 459 ++++++++++++++++++++++++++++++++
 4 files changed, 1089 insertions(+)
 create mode 100644 tsconsensus/http.go
 create mode 100644 tsconsensus/monitor.go
 create mode 100644 tsconsensus/tsconsensus.go
 create mode 100644 tsconsensus/tsconsensus_test.go

diff --git a/tsconsensus/http.go b/tsconsensus/http.go
new file mode 100644
index 000000000..4e19f1aac
--- /dev/null
+++ b/tsconsensus/http.go
@@ -0,0 +1,127 @@
+package tsconsensus
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+)
+
+type joinRequest struct {
+	RemoteHost string `json:"remoteAddr"`
+	RemoteID   string `json:"remoteID"`
+}
+
+type commandClient struct {
+	port       uint16
+	httpClient *http.Client
+}
+
+func (rac *commandClient) Url(host string, path string) string {
+	return fmt.Sprintf("http://%s:%d%s", host, rac.port, path)
+}
+
+func (rac *commandClient) Join(host string, jr joinRequest) error {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	rBs, err := json.Marshal(jr)
+	if err != nil {
+		return err
+	}
+	url := rac.Url(host, "/join")
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rBs))
+	if err != nil {
+		return err
+	}
+	resp, err := rac.httpClient.Do(req)
+	if err != nil {
+		return err
+	}
+	respBs, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+	if resp.StatusCode != 200 {
+		return fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
+	}
+	return nil
+}
+
+func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	url := rac.Url(host, "/executeCommand")
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs))
+	if err != nil {
+		return CommandResult{}, err
+	}
+	resp, err := rac.httpClient.Do(req)
+	if err != nil {
+		return CommandResult{}, err
+	}
+	respBs, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return CommandResult{}, err
+	}
+	if resp.StatusCode != 200 {
+		return CommandResult{}, fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
+	}
+	var cr CommandResult
+	if err = json.Unmarshal(respBs, &cr); err != nil {
+		return CommandResult{}, err
+	}
+	return cr, nil
+}
+
+func (c *Consensus) makeCommandMux() *http.ServeMux {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			http.Error(w, "Bad Request", http.StatusBadRequest)
+			return
+		}
+		decoder := json.NewDecoder(r.Body)
+		var jr joinRequest
+		err := decoder.Decode(&jr)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		if jr.RemoteHost == "" {
+			http.Error(w, "Required: remoteAddr", http.StatusBadRequest)
+			return
+		}
+		if jr.RemoteID == "" {
+			http.Error(w, "Required: remoteID", http.StatusBadRequest)
+			return
+		}
+		err = c.handleJoin(jr)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+	})
+	mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + decoder := json.NewDecoder(r.Body) + var cmd Command + err := decoder.Decode(&cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + result, err := c.executeCommandLocally(cmd) + if err := json.NewEncoder(w).Encode(result); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + return mux +} diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go new file mode 100644 index 000000000..9e90fef38 --- /dev/null +++ b/tsconsensus/monitor.go @@ -0,0 +1,140 @@ +package tsconsensus + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "slices" + "strings" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" +) + +type status struct { + Status *ipnstate.Status + RaftState string +} + +type monitor struct { + ts *tsnet.Server + con *Consensus +} + +func (m *monitor) getStatus(ctx context.Context) (status, error) { + lc, err := m.ts.LocalClient() + if err != nil { + return status{}, err + } + tStatus, err := lc.Status(ctx) + if err != nil { + return status{}, err + } + return status{Status: tStatus, RaftState: m.con.Raft.State().String()}, nil +} + +func serveMonitor(c *Consensus, ts *tsnet.Server, listenAddr string) (*http.Server, error) { + ln, err := ts.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + m := &monitor{con: c, ts: ts} + mux := http.NewServeMux() + mux.HandleFunc("/full", m.handleFullStatus) + mux.HandleFunc("/", m.handleSummaryStatus) + mux.HandleFunc("/netmap", m.handleNetmap) + mux.HandleFunc("/dial", m.handleDial) + srv := &http.Server{Handler: mux} + go func() { + defer ln.Close() + err := srv.Serve(ln) + log.Printf("MonitorHTTP stopped serving with error: %v", err) + }() + return srv, nil +} + +func (m *monitor) handleFullStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + if err := json.NewEncoder(w).Encode(s); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + lines := []string{} + for _, p := range s.Status.Peer { + if p.Online { + lines = append(lines, fmt.Sprintf("%s\t\t%d\t%d\t%t", strings.Split(p.DNSName, ".")[0], p.RxBytes, p.TxBytes, p.Active)) + } + } + slices.Sort(lines) + lines = append([]string{fmt.Sprintf("RaftState: %s", s.RaftState)}, lines...) 
+	txt := strings.Join(lines, "\n") + "\n"
+	w.Write([]byte(txt))
+}
+
+func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) {
+	var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap
+	mask |= ipn.NotifyNoPrivateKeys
+	lc, err := m.ts.LocalClient()
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	watcher, err := lc.WatchIPNBus(r.Context(), mask)
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	defer watcher.Close()
+
+	n, err := watcher.Next()
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	j, _ := json.MarshalIndent(n.NetMap, "", "\t")
+	w.Write([]byte(j))
+	return
+}
+
+func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) {
+	fmt.Println("FRAN handle ping")
+	var dialParams struct {
+		Addr string
+	}
+	bs, err := io.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	err = json.Unmarshal(bs, &dialParams)
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	fmt.Println("dialing", dialParams.Addr)
+	c, err := m.ts.Dial(r.Context(), "tcp", dialParams.Addr)
+	if err != nil {
+		http.Error(w, err.Error(), 500)
+		return
+	}
+	fmt.Println("ping success", c)
+	defer c.Close()
+	w.Write([]byte("ok\n"))
+	return
+}
diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go
new file mode 100644
index 000000000..9826bac29
--- /dev/null
+++ b/tsconsensus/tsconsensus.go
@@ -0,0 +1,363 @@
+package tsconsensus
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"net/http"
+	"slices"
+	"time"
+
+	"github.com/hashicorp/raft"
+	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/tsnet"
+)
+
+/*
+Package tsconsensus implements a consensus algorithm for a group of tsnet.Servers.
+
+The Raft consensus algorithm relies on you implementing a state machine that will give the same
+result to a given command as long as the same logs have been applied in the same order.
+
+tsconsensus uses the hashicorp/raft library to implement leader elections and log application.
+
+tsconsensus provides:
+ * cluster peer discovery based on tailscale tags
+ * executing a command on the leader
+ * communication between cluster peers over tailscale using tsnet
+
+Users implement a state machine that satisfies the raft.FSM interface, with the business logic they desire.
+When changes to state are needed any node may
+ * create a Command instance with serialized Args.
+ * call ExecuteCommand with the Command instance:
+   this will propagate the command to the leader,
+   and then from the leader to every node via raft.
+ * the state machine then can implement raft.Apply, and dispatch commands via the Command.Name,
+   returning a CommandResult with an Err or a serialized Result.
+*/
+
+func addr(host string, port uint16) string {
+	return fmt.Sprintf("%s:%d", host, port)
+}
+
+func raftAddr(host string, cfg Config) string {
+	return addr(host, cfg.RaftPort)
+}
+
+// A SelfRaftNode is the info we need to talk to hashicorp/raft about our node.
+// We specify the ID and Addr on Consensus Start, and then use it later for raft
+// operations such as BootstrapCluster and AddVoter.
+type SelfRaftNode struct {
+	ID   string
+	Host string
+}
+
+// A Config holds configurable values such as ports and timeouts.
+// Use DefaultConfig to get a useful Config.
+type Config struct {
+	CommandPort uint16
+	RaftPort    uint16
+	MonitorPort uint16
+	Raft        *raft.Config
+	MaxConnPool int
+	ConnTimeout time.Duration
+}
+
+// DefaultConfig returns a Config populated with default values ready for use.
+func DefaultConfig() Config {
+	return Config{
+		CommandPort: 6271,
+		RaftPort:    6270,
+		MonitorPort: 8081,
+		Raft:        raft.DefaultConfig(),
+		MaxConnPool: 5,
+		ConnTimeout: 5 * time.Second,
+	}
+}
+
+// StreamLayer implements an interface asked for by raft.NetworkTransport.
+// It does the raft interprocess communication via tailscale.
+type StreamLayer struct {
+	net.Listener
+	s *tsnet.Server
+}
+
+// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
+func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	return sl.s.Dial(ctx, "tcp", string(address))
+}
+
+// Start returns a pointer to a running Consensus instance.
+func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string, cfg Config) (*Consensus, error) {
+	v4, _ := ts.TailscaleIPs()
+	cc := commandClient{
+		port:       cfg.CommandPort,
+		httpClient: ts.HTTPClient(),
+	}
+	self := SelfRaftNode{
+		ID:   v4.String(),
+		Host: v4.String(),
+	}
+	c := Consensus{
+		CommandClient: &cc,
+		Self:          self,
+		Config:        cfg,
+	}
+
+	lc, err := ts.LocalClient()
+	if err != nil {
+		return nil, err
+	}
+	tStatus, err := lc.Status(ctx)
+	if err != nil {
+		return nil, err
+	}
+	var targets []*ipnstate.PeerStatus
+	if targetTag != "" && tStatus.Self.Tags != nil && slices.Contains(tStatus.Self.Tags.AsSlice(), targetTag) {
+		for _, v := range tStatus.Peer {
+			if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), targetTag) {
+				targets = append(targets, v)
+			}
+		}
+	} else {
+		return nil, errors.New("targetTag empty, or this node is not tagged with it")
+	}
+
+	r, err := startRaft(ts, &fsm, c.Self, cfg)
+	if err != nil {
+		return nil, err
+	}
+	c.Raft = r
+	srv, err := c.serveCmdHttp(ts)
+	if err != nil {
+		return nil, err
+	}
+	c.cmdHttpServer = srv
+	if err := c.bootstrap(targets); err != nil {
+		return nil, err
+	}
+	srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
+	if err != nil {
+		return nil, err
+	}
+	c.monitorHttpServer = srv
+	return &c, nil
+}
+
+func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (*raft.Raft, error) {
+	config := cfg.Raft
+	config.LocalID = raft.ServerID(self.ID)
+
+	// no persistence (for now?)
+	logStore := raft.NewInmemStore()
+	stableStore := raft.NewInmemStore()
+	snapshots := raft.NewInmemSnapshotStore()
+
+	// opens the listener on the raft port, raft will close it when it thinks it's appropriate
+	ln, err := ts.Listen("tcp", raftAddr(self.Host, cfg))
+	if err != nil {
+		return nil, err
+	}
+
+	transport := raft.NewNetworkTransport(StreamLayer{
+		s:        ts,
+		Listener: ln,
+	},
+		cfg.MaxConnPool,
+		cfg.ConnTimeout,
+		nil) // TODO pass in proper logging
+
+	// after NewRaft it's possible some other raft node that has us in their configuration will get
+	// in contact, so by the time we do anything else we may already be a functioning member
+	// of a consensus
+	return raft.NewRaft(config, *fsm, logStore, stableStore, snapshots, transport)
+}
+
+// A Consensus is the consensus algorithm for a tsnet.Server.
+// It wraps a raft.Raft instance and performs the peer discovery
+// and command execution on the leader.
+type Consensus struct {
+	Raft              *raft.Raft
+	CommandClient     *commandClient
+	Self              SelfRaftNode
+	Config            Config
+	cmdHttpServer     *http.Server
+	monitorHttpServer *http.Server
+}
+
+// bootstrap tries to join a raft cluster, or start one.
+//
+// We need to do the very first raft cluster configuration, but after that raft manages it.
+// bootstrap is called at start up, and we are not currently aware of what the cluster config might be;
+// our node may already be in it. Try to join the raft cluster of all the other nodes we know about, and
+// if unsuccessful, assume we are the first and start our own.
+//
+// It's possible for bootstrap to return an error, or start an errant breakaway cluster.
+//
+// We have a list of expected cluster members already from control (the members of the tailnet with the tag)
+// so we could do the initial configuration with all servers specified.
+// Choose to start with just this machine in the raft configuration instead, as:
+// - We want to handle machines joining after start anyway.
+// - Not all tagged nodes tailscale believes are active are necessarily actually responsive right now,
+//   so let each node opt in when able.
+func (c *Consensus) bootstrap(targets []*ipnstate.PeerStatus) error {
+	log.Printf("Trying to find cluster: num targets to try: %d", len(targets))
+	for _, p := range targets {
+		if !p.Online {
+			log.Printf("Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0])
+		} else {
+			log.Printf("Trying to find cluster: trying %s", p.TailscaleIPs[0])
+			err := c.CommandClient.Join(p.TailscaleIPs[0].String(), joinRequest{
+				RemoteHost: c.Self.Host,
+				RemoteID:   c.Self.ID,
+			})
+			if err != nil {
+				log.Printf("Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err)
+			} else {
+				log.Printf("Trying to find cluster: joined %s", p.TailscaleIPs[0])
+				return nil
+			}
+		}
+	}
+
+	log.Printf("Trying to find cluster: unsuccessful, starting as leader: %s", c.Self.Host)
+	f := c.Raft.BootstrapCluster(
+		raft.Configuration{
+			Servers: []raft.Server{
+				{
+					ID:      raft.ServerID(c.Self.ID),
+					Address: raft.ServerAddress(c.raftAddr(c.Self.Host)),
+				},
+			},
+		})
+	return f.Error()
+}
+
+// ExecuteCommand propagates a Command to be executed on the leader, which
+// then uses raft to Apply it to the followers.
+func (c *Consensus) ExecuteCommand(cmd Command) (CommandResult, error) {
+	b, err := json.Marshal(cmd)
+	if err != nil {
+		return CommandResult{}, err
+	}
+	result, err := c.executeCommandLocally(cmd)
+	var leErr lookElsewhereError
+	for errors.As(err, &leErr) {
+		result, err = c.CommandClient.ExecuteCommand(leErr.where, b)
+	}
+	return result, err
+}
+
+// Stop attempts to gracefully shut down various components.
+func (c *Consensus) Stop(ctx context.Context) error {
+	fut := c.Raft.Shutdown()
+	err := fut.Error()
+	if err != nil {
+		log.Printf("Stop: Error in Raft Shutdown: %v", err)
+	}
+	err = c.cmdHttpServer.Shutdown(ctx)
+	if err != nil {
+		log.Printf("Stop: Error in command HTTP Shutdown: %v", err)
+	}
+	err = c.monitorHttpServer.Shutdown(ctx)
+	if err != nil {
+		log.Printf("Stop: Error in monitor HTTP Shutdown: %v", err)
+	}
+	return nil
+}
+
+// A Command is a representation of a state machine action.
+// The Name can be used to dispatch the command when received.
+// The Args are serialized for transport.
+type Command struct {
+	Name string
+	Args []byte
+}
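+
+// What follows is an illustrative sketch added for exposition; it is not used by the
+// package itself. It shows how an application FSM's Apply might decode the raft log
+// entry back into a Command, dispatch on Name, and answer with a CommandResult
+// (Snapshot and Restore are omitted). The exampleFSM type and the "increment"
+// command name are hypothetical.
+type exampleFSM struct {
+	count int
+}
+
+// Apply dispatches a Command by Name. raft applies log entries one at a time,
+// so this simple counter needs no locking.
+func (f *exampleFSM) Apply(l *raft.Log) interface{} {
+	var cmd Command
+	if err := json.Unmarshal(l.Data, &cmd); err != nil {
+		return CommandResult{Err: err}
+	}
+	switch cmd.Name {
+	case "increment":
+		f.count++
+		out, err := json.Marshal(f.count)
+		return CommandResult{Result: out, Err: err}
+	default:
+		return CommandResult{Err: fmt.Errorf("unknown command %q", cmd.Name)}
+	}
+}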
+
+// A CommandResult is a representation of the result of a state
+// machine action.
+// Err is any error that occurred on the node that tried to execute the command,
+// including any error from the underlying operation and deserialization problems etc.
+// Result is serialized for transport.
+type CommandResult struct {
+	Err    error
+	Result []byte
+}
+
+type lookElsewhereError struct {
+	where string
+}
+
+func (e lookElsewhereError) Error() string {
+	return fmt.Sprintf("not the leader, try: %s", e.where)
+}
+
+var ErrLeaderUnknown = errors.New("leader unknown")
+
+func (c *Consensus) serveCmdHttp(ts *tsnet.Server) (*http.Server, error) {
+	ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host))
+	if err != nil {
+		return nil, err
+	}
+	mux := c.makeCommandMux()
+	srv := &http.Server{Handler: mux}
+	go func() {
+		defer ln.Close()
+		err := srv.Serve(ln)
+		log.Printf("CmdHttp stopped serving with err: %v", err)
+	}()
+	return srv, nil
+}
+
+func (c *Consensus) getLeader() (string, error) {
+	raftLeaderAddr, _ := c.Raft.LeaderWithID()
+	leaderAddr := (string)(raftLeaderAddr)
+	if leaderAddr == "" {
+		// Raft doesn't know who the leader is.
+		return "", ErrLeaderUnknown
+	}
+	// Raft gives us the address with the raft port, we don't always want that.
+	host, _, err := net.SplitHostPort(leaderAddr)
+	return host, err
+}
+
+func (c *Consensus) executeCommandLocally(cmd Command) (CommandResult, error) {
+	b, err := json.Marshal(cmd)
+	if err != nil {
+		return CommandResult{}, err
+	}
+	f := c.Raft.Apply(b, 10*time.Second)
+	err = f.Error()
+	result := f.Response()
+	if errors.Is(err, raft.ErrNotLeader) {
+		leader, err := c.getLeader()
+		if err != nil {
+			// we know we're not leader but we were unable to give the address of the leader
+			return CommandResult{}, err
+		}
+		return CommandResult{}, lookElsewhereError{where: leader}
+	}
+	if result == nil {
+		result = CommandResult{}
+	}
+	return result.(CommandResult), err
+}
+
+func (c *Consensus) handleJoin(jr joinRequest) error {
+	remoteAddr := c.raftAddr(jr.RemoteHost)
+	f := c.Raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(remoteAddr), 0, 0)
+	if f.Error() != nil {
+		return f.Error()
+	}
+	return nil
+}
+
+func (c *Consensus) raftAddr(host string) string {
+	return raftAddr(host, c.Config)
+}
+
+func (c *Consensus) commandAddr(host string) string {
+	return addr(host, c.Config.CommandPort)
+}
diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go
new file mode 100644
index 000000000..d7a669f4f
--- /dev/null
+++ b/tsconsensus/tsconsensus_test.go
@@ -0,0 +1,459 @@
+package tsconsensus
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"net/http/httptest"
+	"net/netip"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/raft"
+	"tailscale.com/client/tailscale"
+	"tailscale.com/ipn/store/mem"
+	"tailscale.com/net/netns"
+	"tailscale.com/tailcfg"
+	"tailscale.com/tsnet"
+	"tailscale.com/tstest/integration"
+	"tailscale.com/tstest/integration/testcontrol"
+	"tailscale.com/tstest/nettest"
+	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
+)
+
+type fsm struct {
+	events []map[string]interface{}
+	count  int
+}
+type fsmSnapshot struct{}
+
+func (f *fsm) Apply(l *raft.Log) interface{} {
+	f.count++
+	f.events = append(f.events, map[string]interface{}{
+		"type": "Apply",
+		"l":    l,
+	})
+	return CommandResult{
+		Result: []byte{byte(f.count)},
+	}
+}
+
+func (f *fsm) Snapshot() (raft.FSMSnapshot, error) {
+	panic("Snapshot unexpectedly used")
+	return nil, nil
+}
+
+func (f *fsm) Restore(rc io.ReadCloser) error {
+	panic("Restore unexpectedly used")
+	return nil
+}
+
+func (f *fsmSnapshot) Persist(sink raft.SnapshotSink) error {
+	panic("Persist unexpectedly used")
+	return nil
+}
+
+func (f *fsmSnapshot) Release() {
+	panic("Release unexpectedly used")
+}
+
+var verboseDERP = false
+var verboseNodes = false
+
+// TODO copied from sniproxy_test
+func startControl(t *testing.T) (control *testcontrol.Server, controlURL string) {
+	// Corp#4520: don't use netns for tests.
+	netns.SetEnabled(false)
+	t.Cleanup(func() {
+		netns.SetEnabled(true)
+	})
+
+	derpLogf := logger.Discard
+	if verboseDERP {
+		derpLogf = t.Logf
+	}
+	derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
+	control = &testcontrol.Server{
+		DERPMap: derpMap,
+		DNSConfig: &tailcfg.DNSConfig{
+			Proxied: true,
+		},
+		MagicDNSDomain: "tail-scale.ts.net",
+	}
+	control.HTTPTestServer = httptest.NewUnstartedServer(control)
+	control.HTTPTestServer.Start()
+	t.Cleanup(control.HTTPTestServer.Close)
+	controlURL = control.HTTPTestServer.URL
+	t.Logf("testcontrol listening on %s", controlURL)
+	return control, controlURL
+}
+
+// TODO copied from sniproxy_test
+func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) {
+	t.Helper()
+
+	tmp := filepath.Join(t.TempDir(), hostname)
+	os.MkdirAll(tmp, 0755)
+	s := &tsnet.Server{
+		Dir:        tmp,
+		ControlURL: controlURL,
+		Hostname:   hostname,
+		Store:      new(mem.Store),
+		Ephemeral:  true,
+	}
+	if verboseNodes {
+		s.Logf = log.Printf
+	}
+	t.Cleanup(func() { s.Close() })
+
+	status, err := s.Up(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return s, status.Self.PublicKey, status.TailscaleIPs[0]
+}
+
+func pingNode(t *testing.T, control *testcontrol.Server, nodeKey key.NodePublic) {
+	t.Helper()
+	gotPing := make(chan bool, 1)
+	waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		gotPing <- true
+	}))
+	defer waitPing.Close()
+
+	for try := 0; try < 5; try++ {
+		pr := &tailcfg.PingRequest{URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try), Log: true}
+		if !control.AddPingRequest(nodeKey, pr) {
+			t.Fatalf("failed to AddPingRequest")
+		}
+		pingTimeout := time.NewTimer(2 * time.Second)
+		defer pingTimeout.Stop()
+		select {
+		case <-gotPing:
+			// ok! the machinery that refreshes the netmap has been nudged
+			return
+		case <-pingTimeout.C:
+			t.Logf("waiting for ping timed out: %d", try)
+		}
+	}
+}
+
+func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) {
+	t.Helper()
+	for _, key := range nodeKeys {
+		n := control.Node(key)
+		n.Tags = append(n.Tags, tag)
+		b := true
+		n.Online = &b
+		control.UpdateNode(n)
+	}
+
+	// all this ping stuff is only to prod the netmap to get updated with the tag we just added to the node
+	// ie to actually get the netmap issued to clients that represents the current state of the nodes
+	// there _must_ be a better way to do this, but I looked all day and this was the first thing I found that worked.
+	for _, key := range nodeKeys {
+		pingNode(t, control, key)
+	}
+}
+
+// TODO test start with all the config settings
+func TestStart(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	control, controlURL := startControl(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	one, k, _ := startNode(t, ctx, controlURL, "one")
+
+	clusterTag := "tag:whatever"
+	// nodes must be tagged with the cluster tag, to find each other
+	tagNodes(t, control, []key.NodePublic{k}, clusterTag)
+
+	sm := &fsm{}
+	r, err := Start(ctx, one, (*fsm)(sm), clusterTag, DefaultConfig())
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer r.Stop(ctx)
+}
+
+func waitFor(t *testing.T, msg string, condition func() bool, nTries int, waitBetweenTries time.Duration) {
+	for try := 0; try < nTries; try++ {
+		done := condition()
+		if done {
+			t.Logf("waitFor success: %s: after %d tries", msg, try)
+			return
+		}
+		time.Sleep(waitBetweenTries)
+	}
+	t.Fatalf("waitFor timed out: %s, after %d tries", msg, nTries)
+}
+
+type participant struct {
+	c   *Consensus
+	sm  *fsm
+	ts  *tsnet.Server
+	key key.NodePublic
+}
+
+// starts and tags the *tsnet.Server nodes with the control, waits for the nodes to make successful
+// LocalClient Status calls that show the first node as Online.
+func startNodesAndWaitForPeerStatus(t *testing.T, ctx context.Context, clusterTag string, nNodes int) ([]*participant, *testcontrol.Server, string) {
+	ps := make([]*participant, nNodes)
+	keysToTag := make([]key.NodePublic, nNodes)
+	localClients := make([]*tailscale.LocalClient, nNodes)
+	control, controlURL := startControl(t)
+	for i := 0; i < nNodes; i++ {
+		ts, key, _ := startNode(t, ctx, controlURL, fmt.Sprintf("node: %d", i))
+		ps[i] = &participant{ts: ts, key: key}
+		keysToTag[i] = key
+		lc, err := ts.LocalClient()
+		if err != nil {
+			t.Fatalf("%d: error getting local client: %v", i, err)
+		}
+		localClients[i] = lc
+	}
+	tagNodes(t, control, keysToTag, clusterTag)
+	fxCameOnline := func() bool {
+		// all the _other_ nodes see the first as online
+		for i := 1; i < nNodes; i++ {
+			status, err := localClients[i].Status(ctx)
+			if err != nil {
+				t.Fatalf("%d: error getting status: %v", i, err)
+			}
+			if !status.Peer[ps[0].key].Online {
+				return false
+			}
+		}
+		return true
+	}
+	waitFor(t, "other nodes see node 1 online in ts status", fxCameOnline, 10, 2*time.Second)
+	return ps, control, controlURL
+}
+
+// populates participants with their consensus fields, waits for all nodes to show all nodes
+// as part of the same consensus cluster. Starts the first participant first and waits for it to
+// become leader before adding other nodes.
+func createConsensusCluster(t *testing.T, ctx context.Context, clusterTag string, participants []*participant, cfg Config) {
+	participants[0].sm = &fsm{}
+	first, err := Start(ctx, participants[0].ts, (*fsm)(participants[0].sm), clusterTag, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	fxFirstIsLeader := func() bool {
+		return first.Raft.State() == raft.Leader
+	}
+	waitFor(t, "node 0 is leader", fxFirstIsLeader, 10, 2*time.Second)
+	participants[0].c = first
+
+	for i := 1; i < len(participants); i++ {
+		participants[i].sm = &fsm{}
+		c, err := Start(ctx, participants[i].ts, (*fsm)(participants[i].sm), clusterTag, cfg)
+		if err != nil {
+			t.Fatal(err)
+		}
+		participants[i].c = c
+	}
+
+	fxRaftConfigContainsAll := func() bool {
+		for i := 0; i < len(participants); i++ {
+			fut := participants[i].c.Raft.GetConfiguration()
+			err = fut.Error()
+			if err != nil {
+				t.Fatalf("%d: Getting Configuration errored: %v", i, err)
+			}
+			if len(fut.Configuration().Servers) != len(participants) {
+				return false
+			}
+		}
+		return true
+	}
+	waitFor(t, "all raft machines have all servers in their config", fxRaftConfigContainsAll, 10, time.Second*2)
+}
+
+func TestApply(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 2)
+	cfg := DefaultConfig()
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+
+	fut := ps[0].c.Raft.Apply([]byte("woo"), 2*time.Second)
+	err := fut.Error()
+	if err != nil {
+		t.Fatalf("Raft Apply Error: %v", err)
+	}
+
+	fxBothMachinesHaveTheApply := func() bool {
+		return len(ps[0].sm.events) == 1 && len(ps[1].sm.events) == 1
+	}
+	waitFor(t, "the apply event made it into both state machines", fxBothMachinesHaveTheApply, 10, time.Second*1)
+}
+
+// calls ExecuteCommand on each participant and checks that all participants get all commands
+func assertCommandsWorkOnAnyNode(t *testing.T, participants []*participant) {
+	for i, p := range participants {
+		res, err := p.c.ExecuteCommand(Command{Args: []byte{byte(i)}})
+		if err != nil {
+			t.Fatalf("%d: Error ExecuteCommand: %v", i, err)
+		}
+		if res.Err != nil {
+			t.Fatalf("%d: Result Error ExecuteCommand: %v", i, res.Err)
+		}
+		retVal := int(res.Result[0])
+		// the test implementation of the fsm returns the count of events that have been received
+		if retVal != i+1 {
+			t.Fatalf("Result, want %d, got %d", i+1, retVal)
+		}
+
+		expectedEventsLength := i + 1
+		fxEventsInAll := func() bool {
+			for _, pOther := range participants {
+				if len(pOther.sm.events) != expectedEventsLength {
+					return false
+				}
+			}
+			return true
+		}
+		waitFor(t, "event makes it to all", fxEventsInAll, 10, time.Second*1)
+	}
+}
+
+func TestConfig(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
+	cfg := DefaultConfig()
+	// test all is well with non default ports
+	cfg.CommandPort = 12347
+	cfg.RaftPort = 11882
+	mp := uint16(8798)
+	cfg.MonitorPort = mp
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+	assertCommandsWorkOnAnyNode(t, ps)
+
+	url := fmt.Sprintf("http://%s:%d/", ps[0].c.Self.Host, mp)
+	httpClientOnTailnet := ps[1].ts.HTTPClient()
+	rsp, err := httpClientOnTailnet.Get(url)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rsp.StatusCode != 200 {
+		t.Fatalf("monitor status want %d, got %d", 200, rsp.StatusCode)
+	}
+	body, err := io.ReadAll(rsp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Not a great assertion because it relies on the format of the response.
+	line1 := strings.Split(string(body), "\n")[0]
+	if line1[:10] != "RaftState:" {
+		t.Fatalf("getting monitor status, first line, want something that starts with 'RaftState:', got '%s'", line1)
+	}
+}
+
+func TestFollowerFailover(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
+	cfg := DefaultConfig()
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+
+	smThree := ps[2].sm
+
+	fut := ps[0].c.Raft.Apply([]byte("a"), 2*time.Second)
+	futTwo := ps[0].c.Raft.Apply([]byte("b"), 2*time.Second)
+	err := fut.Error()
+	if err != nil {
+		t.Fatalf("Apply Raft error %v", err)
+	}
+	err = futTwo.Error()
+	if err != nil {
+		t.Fatalf("Apply Raft error %v", err)
+	}
+
+	fxAllMachinesHaveTheApplies := func() bool {
+		return len(ps[0].sm.events) == 2 && len(ps[1].sm.events) == 2 && len(smThree.events) == 2
+	}
+	waitFor(t, "the apply events made it into all state machines", fxAllMachinesHaveTheApplies, 10, time.Second*1)
+
+	// a follower loses contact with the cluster
+	ps[2].c.Stop(ctx)
+
+	// applies still make it to one and two
+	futThree := ps[0].c.Raft.Apply([]byte("c"), 2*time.Second)
+	futFour := ps[0].c.Raft.Apply([]byte("d"), 2*time.Second)
+	err = futThree.Error()
+	if err != nil {
+		t.Fatalf("Apply Raft error %v", err)
+	}
+	err = futFour.Error()
+	if err != nil {
+		t.Fatalf("Apply Raft error %v", err)
+	}
+	fxAliveMachinesHaveTheApplies := func() bool {
+		return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(smThree.events) == 2
+	}
+	waitFor(t, "the apply events made it into eligible state machines", fxAliveMachinesHaveTheApplies, 10, time.Second*1)
+
+	// follower comes back
+	smThreeAgain := &fsm{}
+	rThreeAgain, err := Start(ctx, ps[2].ts, (*fsm)(smThreeAgain), clusterTag, DefaultConfig())
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer rThreeAgain.Stop(ctx)
+	fxThreeGetsCaughtUp := func() bool {
+		return len(smThreeAgain.events) == 4
+	}
+	waitFor(t, "the apply events made it into the third node when it appeared with an empty state machine", fxThreeGetsCaughtUp, 20, time.Second*2)
+	if len(smThree.events) != 2 {
+		t.Fatalf("Expected smThree to remain on 2 events: got %d", len(smThree.events))
+	}
+}
+
+func TestRejoin(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
+	cfg := DefaultConfig()
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+	for _, p := range ps {
+		defer p.c.Stop(ctx)
+	}
+
+	// 1st node gets a redundant second join request from the second node
+	ps[0].c.handleJoin(joinRequest{
+		RemoteHost: ps[1].c.Self.Host,
+		RemoteID:   ps[1].c.Self.ID,
+	})
+
+	tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node: joiner")
+	tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag)
+	smJoiner := &fsm{}
+	cJoiner, err := Start(ctx, tsJoiner, (*fsm)(smJoiner), clusterTag, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ps = append(ps, &participant{
+		sm:  smJoiner,
+		c:   cJoiner,
+		ts:  tsJoiner,
+		key: keyJoiner,
+	})
+
+	assertCommandsWorkOnAnyNode(t, ps)
+}
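
Example usage (illustrative sketch, not part of the patch): the snippet below wires the package into a standalone program the way the package comment describes — implement raft.FSM, call Start with a tsnet.Server and the cluster tag, then submit commands from any node with ExecuteCommand. The import path tailscale.com/tsconsensus, the tag tag:consensus, the hostname, and the counterFSM type are assumptions for the sketch; Snapshot and Restore are stubbed because the consensus keeps state in memory only.

package main

import (
	"context"
	"encoding/json"
	"io"
	"log"

	"github.com/hashicorp/raft"
	"tailscale.com/tsconsensus"
	"tailscale.com/tsnet"
)

// counterFSM is a hypothetical application state machine: it counts applied commands.
type counterFSM struct{ count int }

func (f *counterFSM) Apply(l *raft.Log) interface{} {
	var cmd tsconsensus.Command
	if err := json.Unmarshal(l.Data, &cmd); err != nil {
		return tsconsensus.CommandResult{Err: err}
	}
	f.count++ // a real FSM would dispatch on cmd.Name
	b, err := json.Marshal(f.count)
	return tsconsensus.CommandResult{Result: b, Err: err}
}

// Snapshot and Restore are stubbed for the sketch; there is no persistence yet.
func (f *counterFSM) Snapshot() (raft.FSMSnapshot, error) { return nil, nil }
func (f *counterFSM) Restore(rc io.ReadCloser) error      { return nil }

func main() {
	ctx := context.Background()
	ts := &tsnet.Server{Hostname: "consensus-demo"}
	if _, err := ts.Up(ctx); err != nil {
		log.Fatal(err)
	}
	defer ts.Close()

	// Every cluster member must carry the same tag so Start can discover its peers.
	c, err := tsconsensus.Start(ctx, ts, &counterFSM{}, "tag:consensus", tsconsensus.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}
	defer c.Stop(ctx)

	res, err := c.ExecuteCommand(tsconsensus.Command{Name: "increment"})
	if err != nil {
		log.Fatal(err)
	}
	if res.Err != nil {
		log.Fatal(res.Err)
	}
	log.Printf("command applied, result: %s", res.Result)
}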