mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 10:01:05 +01:00 
			
		
		
		
	Merge branch 'main' into fix-route-notify
This commit is contained in:
		
						commit
						26024fedc7
					
				@ -2,6 +2,7 @@
 | 
				
			|||||||
// ignoring it let us speed up the integration test
 | 
					// ignoring it let us speed up the integration test
 | 
				
			||||||
// development
 | 
					// development
 | 
				
			||||||
integration_test.go
 | 
					integration_test.go
 | 
				
			||||||
 | 
					integration_test/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Dockerfile*
 | 
					Dockerfile*
 | 
				
			||||||
docker-compose*
 | 
					docker-compose*
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										119
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										119
									
								
								api.go
									
									
									
									
									
								
							@ -286,21 +286,6 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	h.db.Save(&m)
 | 
						h.db.Save(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	update := make(chan []byte, 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pollData := make(chan []byte, 1)
 | 
					 | 
				
			||||||
	defer close(pollData)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	cancelKeepAlive := make(chan []byte, 1)
 | 
					 | 
				
			||||||
	defer close(cancelKeepAlive)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Storing update channel")
 | 
					 | 
				
			||||||
	h.clientsPolling.Store(m.ID, update)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	data, err := h.getMapResponse(mKey, req, m)
 | 
						data, err := h.getMapResponse(mKey, req, m)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().
 | 
							log.Error().
 | 
				
			||||||
@ -351,6 +336,23 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Only create update channel if it has not been created
 | 
				
			||||||
 | 
						var update chan []byte
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Creating or loading update channel")
 | 
				
			||||||
 | 
						if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
 | 
				
			||||||
 | 
							update = result.(chan []byte)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pollData := make(chan []byte, 1)
 | 
				
			||||||
 | 
						defer close(pollData)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cancelKeepAlive := make(chan []byte, 1)
 | 
				
			||||||
 | 
						defer close(cancelKeepAlive)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Info().
 | 
						log.Info().
 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
		Str("machine", m.Name).
 | 
							Str("machine", m.Name).
 | 
				
			||||||
@ -365,87 +367,15 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 | 
				
			|||||||
		Str("handler", "PollNetMap").
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
		Str("machine", m.Name).
 | 
							Str("machine", m.Name).
 | 
				
			||||||
		Msg("Notifying peers")
 | 
							Msg("Notifying peers")
 | 
				
			||||||
	peers, _ := h.getPeers(m)
 | 
							// TODO: Why does this block?
 | 
				
			||||||
	for _, p := range *peers {
 | 
						go h.notifyChangesToPeers(&m)
 | 
				
			||||||
		pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			log.Info().
 | 
					 | 
				
			||||||
				Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Str("peer", m.Name).
 | 
					 | 
				
			||||||
				Str("address", p.Addresses[0].String()).
 | 
					 | 
				
			||||||
				Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
 | 
					 | 
				
			||||||
			pUp.(chan []byte) <- []byte{}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			log.Info().
 | 
					 | 
				
			||||||
				Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Str("peer", m.Name).
 | 
					 | 
				
			||||||
				Msgf("Peer %s does not appear to be polling", p.Name)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
 | 
						h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
					 | 
				
			||||||
		select {
 | 
					 | 
				
			||||||
		case data := <-pollData:
 | 
					 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
		Str("machine", m.Name).
 | 
							Str("machine", m.Name).
 | 
				
			||||||
				Int("bytes", len(data)).
 | 
							Msg("Finished stream, closing PollNetMap session")
 | 
				
			||||||
				Msg("Sending data")
 | 
					 | 
				
			||||||
			_, err := w.Write(data)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
					Str("machine", m.Name).
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Msg("Cannot write data")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			now := time.Now().UTC()
 | 
					 | 
				
			||||||
			m.LastSeen = &now
 | 
					 | 
				
			||||||
			h.db.Save(&m)
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		case <-update:
 | 
					 | 
				
			||||||
			log.Debug().
 | 
					 | 
				
			||||||
				Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Msg("Received a request for update")
 | 
					 | 
				
			||||||
			data, err := h.getMapResponse(mKey, req, m)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
					Str("machine", m.Name).
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Msg("Could not get the map update")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			_, err = w.Write(*data)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
					Str("machine", m.Name).
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Msg("Could not write the map response")
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		case <-c.Request.Context().Done():
 | 
					 | 
				
			||||||
			log.Info().
 | 
					 | 
				
			||||||
				Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Msg("The client has closed the connection")
 | 
					 | 
				
			||||||
			now := time.Now().UTC()
 | 
					 | 
				
			||||||
			m.LastSeen = &now
 | 
					 | 
				
			||||||
			h.db.Save(&m)
 | 
					 | 
				
			||||||
			cancelKeepAlive <- []byte{}
 | 
					 | 
				
			||||||
			h.clientsPolling.Delete(m.ID)
 | 
					 | 
				
			||||||
			close(update)
 | 
					 | 
				
			||||||
			return false
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
 | 
					func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
 | 
				
			||||||
@ -514,10 +444,15 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
 | 
				
			|||||||
		DERPMap:      h.cfg.DerpMap,
 | 
							DERPMap:      h.cfg.DerpMap,
 | 
				
			||||||
		UserProfiles: []tailcfg.UserProfile{profile},
 | 
							UserProfiles: []tailcfg.UserProfile{profile},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("func", "getMapResponse").
 | 
				
			||||||
 | 
							Str("machine", req.Hostinfo.Hostname).
 | 
				
			||||||
 | 
							Msgf("Generated map response: %s", tailMapResponseToString(resp))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var respBody []byte
 | 
						var respBody []byte
 | 
				
			||||||
	if req.Compress == "zstd" {
 | 
						if req.Compress == "zstd" {
 | 
				
			||||||
		src, _ := json.Marshal(resp)
 | 
							src, _ := json.Marshal(resp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		encoder, _ := zstd.NewWriter(nil)
 | 
							encoder, _ := zstd.NewWriter(nil)
 | 
				
			||||||
		srcCompressed := encoder.EncodeAll(src, nil)
 | 
							srcCompressed := encoder.EncodeAll(src, nil)
 | 
				
			||||||
		respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
 | 
							respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										6
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								app.go
									
									
									
									
									
								
							@ -107,9 +107,9 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
	http.Redirect(w, req, target, http.StatusFound)
 | 
						http.Redirect(w, req, target, http.StatusFound)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ExpireEphemeralNodes deletes ephemeral machine records that have not been
 | 
					// expireEphemeralNodes deletes ephemeral machine records that have not been
 | 
				
			||||||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout
 | 
					// seen for longer than h.cfg.EphemeralNodeInactivityTimeout
 | 
				
			||||||
func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) {
 | 
					func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
 | 
				
			||||||
	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
 | 
						ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
 | 
				
			||||||
	for range ticker.C {
 | 
						for range ticker.C {
 | 
				
			||||||
		h.expireEphemeralNodesWorker()
 | 
							h.expireEphemeralNodesWorker()
 | 
				
			||||||
@ -135,6 +135,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
 | 
				
			|||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					log.Error().Err(err).Str("machine", m.Name).Msg("🤮 Cannot delete ephemeral machine from the database")
 | 
										log.Error().Err(err).Str("machine", m.Name).Msg("🤮 Cannot delete ephemeral machine from the database")
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									h.notifyChangesToPeers(&m)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -165,6 +166,7 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go h.watchForKVUpdates(5000)
 | 
						go h.watchForKVUpdates(5000)
 | 
				
			||||||
 | 
						go h.expireEphemeralNodes(5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.cfg.TLSLetsEncryptHostname != "" {
 | 
						if h.cfg.TLSLetsEncryptHostname != "" {
 | 
				
			||||||
		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
							if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
				
			||||||
 | 
				
			|||||||
@ -21,7 +21,7 @@ var serveCmd = &cobra.Command{
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error initializing: %s", err)
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		go h.ExpireEphemeralNodes(5000)
 | 
					
 | 
				
			||||||
		err = h.Serve()
 | 
							err = h.Serve()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error initializing: %s", err)
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@ -22,6 +22,7 @@ require (
 | 
				
			|||||||
	github.com/rs/zerolog v1.23.0 // indirect
 | 
						github.com/rs/zerolog v1.23.0 // indirect
 | 
				
			||||||
	github.com/spf13/cobra v1.1.3
 | 
						github.com/spf13/cobra v1.1.3
 | 
				
			||||||
	github.com/spf13/viper v1.8.1
 | 
						github.com/spf13/viper v1.8.1
 | 
				
			||||||
 | 
						github.com/stretchr/testify v1.7.0 // indirect
 | 
				
			||||||
	github.com/tailscale/hujson v0.0.0-20200924210142-dde312d0d6a2
 | 
						github.com/tailscale/hujson v0.0.0-20200924210142-dde312d0d6a2
 | 
				
			||||||
	github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
 | 
						github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e
 | 
						golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.sum
									
									
									
									
									
								
							@ -817,6 +817,7 @@ github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5J
 | 
				
			|||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
					github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
				
			||||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
					github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 | 
				
			||||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 | 
					github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 | 
				
			||||||
 | 
					github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As=
 | 
				
			||||||
github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 | 
					github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 | 
				
			||||||
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 | 
					github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 | 
				
			||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 | 
					github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 | 
				
			||||||
 | 
				
			|||||||
@ -9,17 +9,24 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/ory/dockertest/v3"
 | 
						"github.com/ory/dockertest/v3"
 | 
				
			||||||
	"github.com/ory/dockertest/v3/docker"
 | 
						"github.com/ory/dockertest/v3/docker"
 | 
				
			||||||
	"inet.af/netaddr"
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/suite"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gopkg.in/check.v1"
 | 
						"inet.af/netaddr"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var _ = check.Suite(&IntegrationSuite{})
 | 
					type IntegrationTestSuite struct {
 | 
				
			||||||
 | 
						suite.Suite
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type IntegrationSuite struct{}
 | 
					func TestIntegrationTestSuite(t *testing.T) {
 | 
				
			||||||
 | 
						suite.Run(t, new(IntegrationTestSuite))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var integrationTmpDir string
 | 
					var integrationTmpDir string
 | 
				
			||||||
var ih Headscale
 | 
					var ih Headscale
 | 
				
			||||||
@ -27,7 +34,7 @@ var ih Headscale
 | 
				
			|||||||
var pool dockertest.Pool
 | 
					var pool dockertest.Pool
 | 
				
			||||||
var network dockertest.Network
 | 
					var network dockertest.Network
 | 
				
			||||||
var headscale dockertest.Resource
 | 
					var headscale dockertest.Resource
 | 
				
			||||||
var tailscaleCount int = 10
 | 
					var tailscaleCount int = 5
 | 
				
			||||||
var tailscales map[string]dockertest.Resource
 | 
					var tailscales map[string]dockertest.Resource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
 | 
					func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
 | 
				
			||||||
@ -63,7 +70,7 @@ func dockerRestartPolicy(config *docker.HostConfig) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
					func (s *IntegrationTestSuite) SetupSuite() {
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
	h = Headscale{
 | 
						h = Headscale{
 | 
				
			||||||
		dbType:   "sqlite3",
 | 
							dbType:   "sqlite3",
 | 
				
			||||||
@ -104,7 +111,6 @@ func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
				
			|||||||
			fmt.Sprintf("%s/derp.yaml:/etc/headscale/derp.yaml", currentPath),
 | 
								fmt.Sprintf("%s/derp.yaml:/etc/headscale/derp.yaml", currentPath),
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Networks: []*dockertest.Network{&network},
 | 
							Networks: []*dockertest.Network{&network},
 | 
				
			||||||
		// Cmd: []string{"sleep", "3600"},
 | 
					 | 
				
			||||||
		Cmd:      []string{"headscale", "serve"},
 | 
							Cmd:      []string{"headscale", "serve"},
 | 
				
			||||||
		PortBindings: map[docker.Port][]docker.PortBinding{
 | 
							PortBindings: map[docker.Port][]docker.PortBinding{
 | 
				
			||||||
			"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
 | 
								"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
 | 
				
			||||||
@ -127,8 +133,6 @@ func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
				
			|||||||
		tailscaleOptions := &dockertest.RunOptions{
 | 
							tailscaleOptions := &dockertest.RunOptions{
 | 
				
			||||||
			Name:     hostname,
 | 
								Name:     hostname,
 | 
				
			||||||
			Networks: []*dockertest.Network{&network},
 | 
								Networks: []*dockertest.Network{&network},
 | 
				
			||||||
			// Make the container run until killed
 | 
					 | 
				
			||||||
			// Cmd: []string{"sleep", "3600"},
 | 
					 | 
				
			||||||
			Cmd:      []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
 | 
								Cmd:      []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
 | 
				
			||||||
			Env:      []string{},
 | 
								Env:      []string{},
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -141,6 +145,7 @@ func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
				
			|||||||
		fmt.Printf("Created %s container\n", hostname)
 | 
							fmt.Printf("Created %s container\n", hostname)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: Replace this logic with something that can be detected on Github Actions
 | 
				
			||||||
	fmt.Println("Waiting for headscale to be ready")
 | 
						fmt.Println("Waiting for headscale to be ready")
 | 
				
			||||||
	hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
 | 
						hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -164,14 +169,14 @@ func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
				
			|||||||
		&headscale,
 | 
							&headscale,
 | 
				
			||||||
		[]string{"headscale", "namespaces", "create", "test"},
 | 
							[]string{"headscale", "namespaces", "create", "test"},
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Println("Creating pre auth key")
 | 
						fmt.Println("Creating pre auth key")
 | 
				
			||||||
	authKey, err := executeCommand(
 | 
						authKey, err := executeCommand(
 | 
				
			||||||
		&headscale,
 | 
							&headscale,
 | 
				
			||||||
		[]string{"headscale", "-n", "test", "preauthkeys", "create", "--reusable", "--expiration", "24h"},
 | 
							[]string{"headscale", "-n", "test", "preauthkeys", "create", "--reusable", "--expiration", "24h"},
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	headscaleEndpoint := fmt.Sprintf("http://headscale:%s", headscale.GetPort("8080/tcp"))
 | 
						headscaleEndpoint := fmt.Sprintf("http://headscale:%s", headscale.GetPort("8080/tcp"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -186,12 +191,16 @@ func (s *IntegrationSuite) SetUpSuite(c *check.C) {
 | 
				
			|||||||
			command,
 | 
								command,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		fmt.Println("tailscale result: ", result)
 | 
							fmt.Println("tailscale result: ", result)
 | 
				
			||||||
		c.Assert(err, check.IsNil)
 | 
							assert.Nil(s.T(), err)
 | 
				
			||||||
		fmt.Printf("%s joined\n", hostname)
 | 
							fmt.Printf("%s joined\n", hostname)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The nodes need a bit of time to get their updated maps from headscale
 | 
				
			||||||
 | 
						// TODO: See if we can have a more deterministic wait here.
 | 
				
			||||||
 | 
						time.Sleep(20 * time.Second)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationSuite) TearDownSuite(c *check.C) {
 | 
					func (s *IntegrationTestSuite) TearDownSuite() {
 | 
				
			||||||
	if err := pool.Purge(&headscale); err != nil {
 | 
						if err := pool.Purge(&headscale); err != nil {
 | 
				
			||||||
		log.Printf("Could not purge resource: %s\n", err)
 | 
							log.Printf("Could not purge resource: %s\n", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -207,21 +216,102 @@ func (s *IntegrationSuite) TearDownSuite(c *check.C) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationSuite) TestListNodes(c *check.C) {
 | 
					func (s *IntegrationTestSuite) TestListNodes() {
 | 
				
			||||||
	fmt.Println("Listing nodes")
 | 
						fmt.Println("Listing nodes")
 | 
				
			||||||
	result, err := executeCommand(
 | 
						result, err := executeCommand(
 | 
				
			||||||
		&headscale,
 | 
							&headscale,
 | 
				
			||||||
		[]string{"headscale", "-n", "test", "nodes", "list"},
 | 
							[]string{"headscale", "-n", "test", "nodes", "list"},
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("List nodes: \n%s\n", result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Chck that the correct count of host is present in node list
 | 
				
			||||||
 | 
						lines := strings.Split(result, "\n")
 | 
				
			||||||
 | 
						assert.Equal(s.T(), len(tailscales), len(lines)-2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for hostname, _ := range tailscales {
 | 
						for hostname, _ := range tailscales {
 | 
				
			||||||
		c.Assert(strings.Contains(result, hostname), check.Equals, true)
 | 
							assert.Contains(s.T(), result, hostname)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationSuite) TestGetIpAddresses(c *check.C) {
 | 
					func (s *IntegrationTestSuite) TestGetIpAddresses() {
 | 
				
			||||||
	ipPrefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
 | 
						ipPrefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
 | 
				
			||||||
 | 
						ips, err := getIPs()
 | 
				
			||||||
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for hostname, _ := range tailscales {
 | 
				
			||||||
 | 
							s.T().Run(hostname, func(t *testing.T) {
 | 
				
			||||||
 | 
								ip := ips[hostname]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								fmt.Printf("IP for %s: %s\n", hostname, ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// c.Assert(ip.Valid(), check.IsTrue)
 | 
				
			||||||
 | 
								assert.True(t, ip.Is4())
 | 
				
			||||||
 | 
								assert.True(t, ipPrefix.Contains(ip))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								ips[hostname] = ip
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *IntegrationTestSuite) TestStatus() {
 | 
				
			||||||
 | 
						ips, err := getIPs()
 | 
				
			||||||
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for hostname, tailscale := range tailscales {
 | 
				
			||||||
 | 
							s.T().Run(hostname, func(t *testing.T) {
 | 
				
			||||||
 | 
								command := []string{"tailscale", "status"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								fmt.Printf("Getting status for %s\n", hostname)
 | 
				
			||||||
 | 
								result, err := executeCommand(
 | 
				
			||||||
 | 
									&tailscale,
 | 
				
			||||||
 | 
									command,
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
								assert.Nil(t, err)
 | 
				
			||||||
 | 
								// fmt.Printf("Status for %s: %s", hostname, result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Check if we have as many nodes in status
 | 
				
			||||||
 | 
								// as we have IPs/tailscales
 | 
				
			||||||
 | 
								lines := strings.Split(result, "\n")
 | 
				
			||||||
 | 
								assert.Equal(t, len(ips), len(lines)-1)
 | 
				
			||||||
 | 
								assert.Equal(t, len(tailscales), len(lines)-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Check that all hosts is present in all hosts status
 | 
				
			||||||
 | 
								for ipHostname, ip := range ips {
 | 
				
			||||||
 | 
									assert.Contains(t, result, ip.String())
 | 
				
			||||||
 | 
									assert.Contains(t, result, ipHostname)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *IntegrationTestSuite) TestPingAllPeers() {
 | 
				
			||||||
 | 
						ips, err := getIPs()
 | 
				
			||||||
 | 
						assert.Nil(s.T(), err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for hostname, tailscale := range tailscales {
 | 
				
			||||||
 | 
							for peername, ip := range ips {
 | 
				
			||||||
 | 
								s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
				
			||||||
 | 
									// We currently cant ping ourselves, so skip that.
 | 
				
			||||||
 | 
									if peername != hostname {
 | 
				
			||||||
 | 
										command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
 | 
				
			||||||
 | 
										result, err := executeCommand(
 | 
				
			||||||
 | 
											&tailscale,
 | 
				
			||||||
 | 
											command,
 | 
				
			||||||
 | 
										)
 | 
				
			||||||
 | 
										assert.Nil(t, err)
 | 
				
			||||||
 | 
										fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
				
			||||||
 | 
										assert.Contains(t, result, "pong")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getIPs() (map[string]netaddr.IP, error) {
 | 
				
			||||||
	ips := make(map[string]netaddr.IP)
 | 
						ips := make(map[string]netaddr.IP)
 | 
				
			||||||
	for hostname, tailscale := range tailscales {
 | 
						for hostname, tailscale := range tailscales {
 | 
				
			||||||
		command := []string{"tailscale", "ip"}
 | 
							command := []string{"tailscale", "ip"}
 | 
				
			||||||
@ -230,17 +320,16 @@ func (s *IntegrationSuite) TestGetIpAddresses(c *check.C) {
 | 
				
			|||||||
			&tailscale,
 | 
								&tailscale,
 | 
				
			||||||
			command,
 | 
								command,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		c.Assert(err, check.IsNil)
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ip, err := netaddr.ParseIP(strings.TrimSuffix(result, "\n"))
 | 
							ip, err := netaddr.ParseIP(strings.TrimSuffix(result, "\n"))
 | 
				
			||||||
		c.Assert(err, check.IsNil)
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
		fmt.Printf("IP for %s: %s", hostname, result)
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		// c.Assert(ip.Valid(), check.IsTrue)
 | 
					 | 
				
			||||||
		c.Assert(ip.Is4(), check.Equals, true)
 | 
					 | 
				
			||||||
		c.Assert(ipPrefix.Contains(ip), check.Equals, true)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		ips[hostname] = ip
 | 
							ips[hostname] = ip
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return ips, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										31
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								machine.go
									
									
									
									
									
								
							@ -159,6 +159,10 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
 | 
					func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("func", "getPeers").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Finding peers")
 | 
				
			||||||
	machines := []Machine{}
 | 
						machines := []Machine{}
 | 
				
			||||||
	if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
 | 
						if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
 | 
				
			||||||
		m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
 | 
							m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
 | 
				
			||||||
@ -175,6 +179,11 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
 | 
				
			|||||||
		peers = append(peers, peer)
 | 
							peers = append(peers, peer)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
 | 
						sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("func", "getPeers").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msgf("Found peers: %s", tailNodesToString(peers))
 | 
				
			||||||
	return &peers, nil
 | 
						return &peers, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -238,3 +247,25 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return &hostinfo, nil
 | 
						return &hostinfo, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) notifyChangesToPeers(m *Machine) {
 | 
				
			||||||
 | 
						peers, _ := h.getPeers(*m)
 | 
				
			||||||
 | 
						for _, p := range *peers {
 | 
				
			||||||
 | 
							pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								log.Info().
 | 
				
			||||||
 | 
									Str("func", "notifyChangesToPeers").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("peer", m.Name).
 | 
				
			||||||
 | 
									Str("address", p.Addresses[0].String()).
 | 
				
			||||||
 | 
									Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
 | 
				
			||||||
 | 
								pUp.(chan []byte) <- []byte{}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								log.Info().
 | 
				
			||||||
 | 
									Str("func", "notifyChangesToPeers").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("peer", m.Name).
 | 
				
			||||||
 | 
									Msgf("Peer %s does not appear to be polling", p.Name)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -169,25 +169,7 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		for _, m := range *machines {
 | 
							for _, m := range *machines {
 | 
				
			||||||
			peers, _ := h.getPeers(m)
 | 
								h.notifyChangesToPeers(&m)
 | 
				
			||||||
			for _, p := range *peers {
 | 
					 | 
				
			||||||
				pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 | 
					 | 
				
			||||||
				if ok {
 | 
					 | 
				
			||||||
					log.Info().
 | 
					 | 
				
			||||||
						Str("func", "checkForNamespacesPendingUpdates").
 | 
					 | 
				
			||||||
						Str("machine", m.Name).
 | 
					 | 
				
			||||||
						Str("peer", m.Name).
 | 
					 | 
				
			||||||
						Str("address", p.Addresses[0].String()).
 | 
					 | 
				
			||||||
						Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
 | 
					 | 
				
			||||||
					pUp.(chan []byte) <- []byte{}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					log.Info().
 | 
					 | 
				
			||||||
						Str("func", "checkForNamespacesPendingUpdates").
 | 
					 | 
				
			||||||
						Str("machine", m.Name).
 | 
					 | 
				
			||||||
						Str("peer", m.Name).
 | 
					 | 
				
			||||||
						Msgf("Peer %s does not appear to be polling", p.Name)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	newV, err := h.getValue("namespaces_pending_updates")
 | 
						newV, err := h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										98
									
								
								poll.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								poll.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,98 @@
 | 
				
			|||||||
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
						"tailscale.com/types/wgkey"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) PollNetMapStream(
 | 
				
			||||||
 | 
						c *gin.Context,
 | 
				
			||||||
 | 
						m Machine,
 | 
				
			||||||
 | 
						req tailcfg.MapRequest,
 | 
				
			||||||
 | 
						mKey wgkey.Key,
 | 
				
			||||||
 | 
						pollData chan []byte,
 | 
				
			||||||
 | 
						update chan []byte,
 | 
				
			||||||
 | 
						cancelKeepAlive chan []byte,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
 | 
							log.Trace().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msg("Waiting for data to stream...")
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case data := <-pollData:
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Sending data received via pollData channel")
 | 
				
			||||||
 | 
								_, err := w.Write(data)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Cannot write data")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Data from pollData channel written successfully")
 | 
				
			||||||
 | 
								now := time.Now().UTC()
 | 
				
			||||||
 | 
								m.LastSeen = &now
 | 
				
			||||||
 | 
								h.db.Save(&m)
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Machine updated successfully after sending pollData")
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case <-update:
 | 
				
			||||||
 | 
								log.Debug().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msg("Received a request for update")
 | 
				
			||||||
 | 
								data, err := h.getMapResponse(mKey, req, m)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Could not get the map update")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								_, err = w.Write(*data)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Could not write the map response")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case <-c.Request.Context().Done():
 | 
				
			||||||
 | 
								log.Info().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msg("The client has closed the connection")
 | 
				
			||||||
 | 
								now := time.Now().UTC()
 | 
				
			||||||
 | 
								m.LastSeen = &now
 | 
				
			||||||
 | 
								h.db.Save(&m)
 | 
				
			||||||
 | 
								cancelKeepAlive <- []byte{}
 | 
				
			||||||
 | 
								h.clientsPolling.Delete(m.ID)
 | 
				
			||||||
 | 
								close(update)
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										17
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								utils.go
									
									
									
									
									
								
							@ -10,9 +10,11 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"golang.org/x/crypto/nacl/box"
 | 
						"golang.org/x/crypto/nacl/box"
 | 
				
			||||||
	"inet.af/netaddr"
 | 
						"inet.af/netaddr"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
	"tailscale.com/types/wgkey"
 | 
						"tailscale.com/types/wgkey"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -58,6 +60,7 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return encodeMsg(b, pubKey, privKey)
 | 
						return encodeMsg(b, pubKey, privKey)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -139,3 +142,17 @@ func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func tailNodesToString(nodes []*tailcfg.Node) string {
 | 
				
			||||||
 | 
						temp := make([]string, len(nodes))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for index, node := range nodes {
 | 
				
			||||||
 | 
							temp[index] = node.Name
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func tailMapResponseToString(resp tailcfg.MapResponse) string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user